From b9bb62cf0c2af848bf02e2a74d9bd109ef570010 Mon Sep 17 00:00:00 2001 From: Ariel Costas Guerrero Date: Mon, 8 Dec 2025 12:04:25 +0100 Subject: Update formatting --- src/gtfs_perstop_report/src/common.py | 1 - src/gtfs_perstop_report/src/download.py | 96 +++++++++++++--------- src/gtfs_perstop_report/src/logger.py | 12 ++- .../src/proto/stop_schedule_pb2.py | 30 +++---- .../src/proto/stop_schedule_pb2.pyi | 69 ++++++++++++++-- src/gtfs_perstop_report/src/routes.py | 24 +++--- src/gtfs_perstop_report/src/services.py | 61 ++++++++------ src/gtfs_perstop_report/src/shapes.py | 31 +++++-- src/gtfs_perstop_report/src/stop_schedule_pb2.py | 35 ++++---- src/gtfs_perstop_report/src/stop_schedule_pb2.pyi | 34 +++++++- src/gtfs_perstop_report/src/stop_times.py | 71 +++++++++++----- src/gtfs_perstop_report/src/stops.py | 4 +- src/gtfs_perstop_report/src/street_name.py | 17 ++-- src/gtfs_perstop_report/src/trips.py | 75 +++++++++++------ src/gtfs_perstop_report/stop_report.py | 69 +++++++--------- 15 files changed, 400 insertions(+), 229 deletions(-) (limited to 'src/gtfs_perstop_report') diff --git a/src/gtfs_perstop_report/src/common.py b/src/gtfs_perstop_report/src/common.py index 22769e4..c2df785 100644 --- a/src/gtfs_perstop_report/src/common.py +++ b/src/gtfs_perstop_report/src/common.py @@ -40,7 +40,6 @@ def get_all_feed_dates(feed_dir: str) -> List[str]: if len(result) > 0: return result - # Fallback: use calendar_dates.txt if os.path.exists(calendar_dates_path): with open(calendar_dates_path, encoding="utf-8") as f: diff --git a/src/gtfs_perstop_report/src/download.py b/src/gtfs_perstop_report/src/download.py index 19125bc..4d0c620 100644 --- a/src/gtfs_perstop_report/src/download.py +++ b/src/gtfs_perstop_report/src/download.py @@ -9,39 +9,44 @@ from src.logger import get_logger logger = get_logger("download") + def _get_metadata_path(output_dir: str) -> str: """Get the path to the metadata file for storing ETag and Last-Modified info.""" - return os.path.join(output_dir, '.gtfsmetadata') + return os.path.join(output_dir, ".gtfsmetadata") + def _load_metadata(output_dir: str) -> Optional[dict]: """Load existing metadata from the output directory.""" metadata_path = _get_metadata_path(output_dir) if os.path.exists(metadata_path): try: - with open(metadata_path, 'r', encoding='utf-8') as f: + with open(metadata_path, "r", encoding="utf-8") as f: return json.load(f) except (json.JSONDecodeError, IOError) as e: logger.warning(f"Failed to load metadata from {metadata_path}: {e}") return None -def _save_metadata(output_dir: str, etag: Optional[str], last_modified: Optional[str]) -> None: + +def _save_metadata( + output_dir: str, etag: Optional[str], last_modified: Optional[str] +) -> None: """Save ETag and Last-Modified metadata to the output directory.""" metadata_path = _get_metadata_path(output_dir) - metadata = { - 'etag': etag, - 'last_modified': last_modified - } - + metadata = {"etag": etag, "last_modified": last_modified} + # Ensure output directory exists os.makedirs(output_dir, exist_ok=True) - + try: - with open(metadata_path, 'w', encoding='utf-8') as f: + with open(metadata_path, "w", encoding="utf-8") as f: json.dump(metadata, f, indent=2) except IOError as e: logger.warning(f"Failed to save metadata to {metadata_path}: {e}") -def _check_if_modified(feed_url: str, output_dir: str) -> Tuple[bool, Optional[str], Optional[str]]: + +def _check_if_modified( + feed_url: str, output_dir: str +) -> Tuple[bool, Optional[str], Optional[str]]: """ Check if the feed has been modified using conditional headers. Returns (is_modified, etag, last_modified) @@ -49,58 +54,69 @@ def _check_if_modified(feed_url: str, output_dir: str) -> Tuple[bool, Optional[s metadata = _load_metadata(output_dir) if not metadata: return True, None, None - + headers = {} - if metadata.get('etag'): - headers['If-None-Match'] = metadata['etag'] - if metadata.get('last_modified'): - headers['If-Modified-Since'] = metadata['last_modified'] - + if metadata.get("etag"): + headers["If-None-Match"] = metadata["etag"] + if metadata.get("last_modified"): + headers["If-Modified-Since"] = metadata["last_modified"] + if not headers: return True, None, None - + try: response = requests.head(feed_url, headers=headers) - + if response.status_code == 304: - logger.info("Feed has not been modified (304 Not Modified), skipping download") - return False, metadata.get('etag'), metadata.get('last_modified') + logger.info( + "Feed has not been modified (304 Not Modified), skipping download" + ) + return False, metadata.get("etag"), metadata.get("last_modified") elif response.status_code == 200: - etag = response.headers.get('ETag') - last_modified = response.headers.get('Last-Modified') + etag = response.headers.get("ETag") + last_modified = response.headers.get("Last-Modified") return True, etag, last_modified else: - logger.warning(f"Unexpected response status {response.status_code} when checking for modifications, proceeding with download") + logger.warning( + f"Unexpected response status {response.status_code} when checking for modifications, proceeding with download" + ) return True, None, None except requests.RequestException as e: - logger.warning(f"Failed to check if feed has been modified: {e}, proceeding with download") + logger.warning( + f"Failed to check if feed has been modified: {e}, proceeding with download" + ) return True, None, None -def download_feed_from_url(feed_url: str, output_dir: str = None, force_download: bool = False) -> Optional[str]: + +def download_feed_from_url( + feed_url: str, output_dir: str = None, force_download: bool = False +) -> Optional[str]: """ Download GTFS feed from URL. - + Args: feed_url: URL to download the GTFS feed from output_dir: Directory where reports will be written (used for metadata storage) force_download: If True, skip conditional download checks - + Returns: Path to the directory containing the extracted GTFS files, or None if download was skipped """ - + # Check if we need to download the feed if not force_download and output_dir: - is_modified, cached_etag, cached_last_modified = _check_if_modified(feed_url, output_dir) + is_modified, cached_etag, cached_last_modified = _check_if_modified( + feed_url, output_dir + ) if not is_modified: logger.info("Feed has not been modified, skipping download") return None - + # Create a directory in the system temporary directory - temp_dir = tempfile.mkdtemp(prefix='gtfs_vigo_') + temp_dir = tempfile.mkdtemp(prefix="gtfs_vigo_") # Create a temporary zip file in the temporary directory - zip_filename = os.path.join(temp_dir, 'gtfs_vigo.zip') + zip_filename = os.path.join(temp_dir, "gtfs_vigo.zip") headers = {} response = requests.get(feed_url, headers=headers) @@ -108,23 +124,23 @@ def download_feed_from_url(feed_url: str, output_dir: str = None, force_download if response.status_code != 200: raise Exception(f"Failed to download GTFS data: {response.status_code}") - with open(zip_filename, 'wb') as file: + with open(zip_filename, "wb") as file: file.write(response.content) - + # Extract and save metadata if output_dir is provided if output_dir: - etag = response.headers.get('ETag') - last_modified = response.headers.get('Last-Modified') + etag = response.headers.get("ETag") + last_modified = response.headers.get("Last-Modified") if etag or last_modified: _save_metadata(output_dir, etag, last_modified) # Extract the zip file - with zipfile.ZipFile(zip_filename, 'r') as zip_ref: + with zipfile.ZipFile(zip_filename, "r") as zip_ref: zip_ref.extractall(temp_dir) - + # Clean up the downloaded zip file os.remove(zip_filename) logger.info(f"GTFS feed downloaded from {feed_url} and extracted to {temp_dir}") - return temp_dir \ No newline at end of file + return temp_dir diff --git a/src/gtfs_perstop_report/src/logger.py b/src/gtfs_perstop_report/src/logger.py index 9488076..6c56787 100644 --- a/src/gtfs_perstop_report/src/logger.py +++ b/src/gtfs_perstop_report/src/logger.py @@ -1,12 +1,14 @@ """ Logging configuration for the GTFS application. """ + import logging from colorama import init, Fore, Style # Initialize Colorama (required on Windows) init(autoreset=True) + class ColorFormatter(logging.Formatter): def format(self, record: logging.LogRecord): # Base format @@ -28,16 +30,18 @@ class ColorFormatter(logging.Formatter): # Add color to the entire line formatter = logging.Formatter( - prefix + log_format + Style.RESET_ALL, "%Y-%m-%d %H:%M:%S") + prefix + log_format + Style.RESET_ALL, "%Y-%m-%d %H:%M:%S" + ) return formatter.format(record) + def get_logger(name: str) -> logging.Logger: """ Create and return a logger with the given name. - + Args: name (str): The name of the logger. - + Returns: logging.Logger: Configured logger instance. """ @@ -50,5 +54,5 @@ def get_logger(name: str) -> logging.Logger: console_handler.setLevel(logging.DEBUG) console_handler.setFormatter(ColorFormatter()) logger.addHandler(console_handler) - + return logger diff --git a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py index cb4f336..c7279c5 100644 --- a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py +++ b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py @@ -2,6 +2,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: stop_schedule.proto """Generated protocol buffer code.""" + from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -11,22 +12,21 @@ from google.protobuf import symbol_database as _symbol_database _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13stop_schedule.proto\x12\x05proto\"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\"\x83\x04\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12\"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\x85\x03\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18\" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\t\x12\x1e\n\x16previous_trip_shape_id\x18\x33 \x01(\t\";\n\x05Shape\x12\x10\n\x08shape_id\x18\x01 \x01(\t\x12 \n\x06points\x18\x03 \x03(\x0b\x32\x10.proto.Epsg25829B$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x13stop_schedule.proto\x12\x05proto"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01"\x83\x04\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\x85\x03\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\t\x12\x1e\n\x16previous_trip_shape_id\x18\x33 \x01(\t";\n\x05Shape\x12\x10\n\x08shape_id\x18\x01 \x01(\t\x12 \n\x06points\x18\x03 \x03(\x0b\x32\x10.proto.Epsg25829B$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3' +) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'stop_schedule_pb2', globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "stop_schedule_pb2", globals()) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\252\002!Costasdev.Busurbano.Backend.Types' - _EPSG25829._serialized_start=30 - _EPSG25829._serialized_end=63 - _STOPARRIVALS._serialized_start=66 - _STOPARRIVALS._serialized_end=581 - _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_start=192 - _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_end=581 - _SHAPE._serialized_start=583 - _SHAPE._serialized_end=642 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\252\002!Costasdev.Busurbano.Backend.Types" + _EPSG25829._serialized_start = 30 + _EPSG25829._serialized_end = 63 + _STOPARRIVALS._serialized_start = 66 + _STOPARRIVALS._serialized_end = 581 + _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_start = 192 + _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_end = 581 + _SHAPE._serialized_start = 583 + _SHAPE._serialized_end = 642 # @@protoc_insertion_point(module_scope) diff --git a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi index 355798f..fc55f4e 100644 --- a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi +++ b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi @@ -1,7 +1,13 @@ from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union +from typing import ( + ClassVar as _ClassVar, + Iterable as _Iterable, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) DESCRIPTOR: _descriptor.FileDescriptor @@ -11,7 +17,9 @@ class Epsg25829(_message.Message): Y_FIELD_NUMBER: _ClassVar[int] x: float y: float - def __init__(self, x: _Optional[float] = ..., y: _Optional[float] = ...) -> None: ... + def __init__( + self, x: _Optional[float] = ..., y: _Optional[float] = ... + ) -> None: ... class Shape(_message.Message): __slots__ = ["points", "shape_id"] @@ -19,12 +27,34 @@ class Shape(_message.Message): SHAPE_ID_FIELD_NUMBER: _ClassVar[int] points: _containers.RepeatedCompositeFieldContainer[Epsg25829] shape_id: str - def __init__(self, shape_id: _Optional[str] = ..., points: _Optional[_Iterable[_Union[Epsg25829, _Mapping]]] = ...) -> None: ... + def __init__( + self, + shape_id: _Optional[str] = ..., + points: _Optional[_Iterable[_Union[Epsg25829, _Mapping]]] = ..., + ) -> None: ... class StopArrivals(_message.Message): __slots__ = ["arrivals", "location", "stop_id"] class ScheduledArrival(_message.Message): - __slots__ = ["calling_ssm", "calling_time", "line", "next_streets", "previous_trip_shape_id", "route", "service_id", "shape_dist_traveled", "shape_id", "starting_code", "starting_name", "starting_time", "stop_sequence", "terminus_code", "terminus_name", "terminus_time", "trip_id"] + __slots__ = [ + "calling_ssm", + "calling_time", + "line", + "next_streets", + "previous_trip_shape_id", + "route", + "service_id", + "shape_dist_traveled", + "shape_id", + "starting_code", + "starting_name", + "starting_time", + "stop_sequence", + "terminus_code", + "terminus_name", + "terminus_time", + "trip_id", + ] CALLING_SSM_FIELD_NUMBER: _ClassVar[int] CALLING_TIME_FIELD_NUMBER: _ClassVar[int] LINE_FIELD_NUMBER: _ClassVar[int] @@ -59,11 +89,38 @@ class StopArrivals(_message.Message): terminus_name: str terminus_time: str trip_id: str - def __init__(self, service_id: _Optional[str] = ..., trip_id: _Optional[str] = ..., line: _Optional[str] = ..., route: _Optional[str] = ..., shape_id: _Optional[str] = ..., shape_dist_traveled: _Optional[float] = ..., stop_sequence: _Optional[int] = ..., next_streets: _Optional[_Iterable[str]] = ..., starting_code: _Optional[str] = ..., starting_name: _Optional[str] = ..., starting_time: _Optional[str] = ..., calling_time: _Optional[str] = ..., calling_ssm: _Optional[int] = ..., terminus_code: _Optional[str] = ..., terminus_name: _Optional[str] = ..., terminus_time: _Optional[str] = ..., previous_trip_shape_id: _Optional[str] = ...) -> None: ... + def __init__( + self, + service_id: _Optional[str] = ..., + trip_id: _Optional[str] = ..., + line: _Optional[str] = ..., + route: _Optional[str] = ..., + shape_id: _Optional[str] = ..., + shape_dist_traveled: _Optional[float] = ..., + stop_sequence: _Optional[int] = ..., + next_streets: _Optional[_Iterable[str]] = ..., + starting_code: _Optional[str] = ..., + starting_name: _Optional[str] = ..., + starting_time: _Optional[str] = ..., + calling_time: _Optional[str] = ..., + calling_ssm: _Optional[int] = ..., + terminus_code: _Optional[str] = ..., + terminus_name: _Optional[str] = ..., + terminus_time: _Optional[str] = ..., + previous_trip_shape_id: _Optional[str] = ..., + ) -> None: ... + ARRIVALS_FIELD_NUMBER: _ClassVar[int] LOCATION_FIELD_NUMBER: _ClassVar[int] STOP_ID_FIELD_NUMBER: _ClassVar[int] arrivals: _containers.RepeatedCompositeFieldContainer[StopArrivals.ScheduledArrival] location: Epsg25829 stop_id: str - def __init__(self, stop_id: _Optional[str] = ..., location: _Optional[_Union[Epsg25829, _Mapping]] = ..., arrivals: _Optional[_Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]]] = ...) -> None: ... + def __init__( + self, + stop_id: _Optional[str] = ..., + location: _Optional[_Union[Epsg25829, _Mapping]] = ..., + arrivals: _Optional[ + _Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]] + ] = ..., + ) -> None: ... diff --git a/src/gtfs_perstop_report/src/routes.py b/src/gtfs_perstop_report/src/routes.py index e67a1a4..06cf0e5 100644 --- a/src/gtfs_perstop_report/src/routes.py +++ b/src/gtfs_perstop_report/src/routes.py @@ -1,12 +1,14 @@ """ Module for loading and querying GTFS routes data. """ + import os import csv from src.logger import get_logger logger = get_logger("routes") + def load_routes(feed_dir: str) -> dict[str, dict[str, str]]: """ Load routes data from the GTFS feed. @@ -16,24 +18,26 @@ def load_routes(feed_dir: str) -> dict[str, dict[str, str]]: containing route_short_name and route_color. """ routes: dict[str, dict[str, str]] = {} - routes_file_path = os.path.join(feed_dir, 'routes.txt') + routes_file_path = os.path.join(feed_dir, "routes.txt") try: - with open(routes_file_path, 'r', encoding='utf-8') as routes_file: + with open(routes_file_path, "r", encoding="utf-8") as routes_file: reader = csv.DictReader(routes_file) header = reader.fieldnames or [] - if 'route_color' not in header: - logger.warning("Column 'route_color' not found in routes.txt. Defaulting to black (#000000).") + if "route_color" not in header: + logger.warning( + "Column 'route_color' not found in routes.txt. Defaulting to black (#000000)." + ) for row in reader: - route_id = row['route_id'] - if 'route_color' in row and row['route_color']: - route_color = row['route_color'] + route_id = row["route_id"] + if "route_color" in row and row["route_color"]: + route_color = row["route_color"] else: - route_color = '000000' + route_color = "000000" routes[route_id] = { - 'route_short_name': row['route_short_name'], - 'route_color': route_color + "route_short_name": row["route_short_name"], + "route_color": route_color, } except FileNotFoundError: raise FileNotFoundError(f"Routes file not found at {routes_file_path}") diff --git a/src/gtfs_perstop_report/src/services.py b/src/gtfs_perstop_report/src/services.py index fb1110d..d456e43 100644 --- a/src/gtfs_perstop_report/src/services.py +++ b/src/gtfs_perstop_report/src/services.py @@ -19,26 +19,28 @@ def get_active_services(feed_dir: str, date: str) -> list[str]: ValueError: If the date format is incorrect. """ search_date = date.replace("-", "").replace(":", "").replace("/", "") - weekday = datetime.datetime.strptime(date, '%Y-%m-%d').weekday() + weekday = datetime.datetime.strptime(date, "%Y-%m-%d").weekday() active_services: list[str] = [] try: - with open(os.path.join(feed_dir, 'calendar.txt'), 'r', encoding="utf-8") as calendar_file: + with open( + os.path.join(feed_dir, "calendar.txt"), "r", encoding="utf-8" + ) as calendar_file: lines = calendar_file.readlines() if len(lines) > 1: # First parse the header, get each column's index - header = lines[0].strip().split(',') + header = lines[0].strip().split(",") try: - service_id_index = header.index('service_id') - monday_index = header.index('monday') - tuesday_index = header.index('tuesday') - wednesday_index = header.index('wednesday') - thursday_index = header.index('thursday') - friday_index = header.index('friday') - saturday_index = header.index('saturday') - sunday_index = header.index('sunday') - start_date_index = header.index('start_date') - end_date_index = header.index('end_date') + service_id_index = header.index("service_id") + monday_index = header.index("monday") + tuesday_index = header.index("tuesday") + wednesday_index = header.index("wednesday") + thursday_index = header.index("thursday") + friday_index = header.index("friday") + saturday_index = header.index("saturday") + sunday_index = header.index("sunday") + start_date_index = header.index("start_date") + end_date_index = header.index("end_date") except ValueError as e: logger.error(f"Required column not found in header: {e}") return active_services @@ -50,14 +52,15 @@ def get_active_services(feed_dir: str, date: str) -> list[str]: 3: thursday_index, 4: friday_index, 5: saturday_index, - 6: sunday_index + 6: sunday_index, } for idx, line in enumerate(lines[1:], 1): - parts = line.strip().split(',') + parts = line.strip().split(",") if len(parts) < len(header): logger.warning( - f"Skipping malformed line in calendar.txt line {idx+1}: {line.strip()}") + f"Skipping malformed line in calendar.txt line {idx + 1}: {line.strip()}" + ) continue service_id = parts[service_id_index] @@ -66,24 +69,27 @@ def get_active_services(feed_dir: str, date: str) -> list[str]: end_date = parts[end_date_index] # Check if day of week is active AND date is within the service range - if day_value == '1' and start_date <= search_date <= end_date: + if day_value == "1" and start_date <= search_date <= end_date: active_services.append(service_id) except FileNotFoundError: logger.warning("calendar.txt file not found.") try: - with open(os.path.join(feed_dir, 'calendar_dates.txt'), 'r', encoding="utf-8") as calendar_dates_file: + with open( + os.path.join(feed_dir, "calendar_dates.txt"), "r", encoding="utf-8" + ) as calendar_dates_file: lines = calendar_dates_file.readlines() if len(lines) <= 1: logger.warning( - "calendar_dates.txt file is empty or has only header line, not processing.") + "calendar_dates.txt file is empty or has only header line, not processing." + ) return active_services - header = lines[0].strip().split(',') + header = lines[0].strip().split(",") try: - service_id_index = header.index('service_id') - date_index = header.index('date') - exception_type_index = header.index('exception_type') + service_id_index = header.index("service_id") + date_index = header.index("date") + exception_type_index = header.index("exception_type") except ValueError as e: logger.error(f"Required column not found in header: {e}") return active_services @@ -91,20 +97,21 @@ def get_active_services(feed_dir: str, date: str) -> list[str]: # Now read the rest of the file, find all services where 'date' matches the search_date # Start from 1 to skip header for idx, line in enumerate(lines[1:], 1): - parts = line.strip().split(',') + parts = line.strip().split(",") if len(parts) < len(header): logger.warning( - f"Skipping malformed line in calendar_dates.txt line {idx+1}: {line.strip()}") + f"Skipping malformed line in calendar_dates.txt line {idx + 1}: {line.strip()}" + ) continue service_id = parts[service_id_index] date_value = parts[date_index] exception_type = parts[exception_type_index] - if date_value == search_date and exception_type == '1': + if date_value == search_date and exception_type == "1": active_services.append(service_id) - if date_value == search_date and exception_type == '2': + if date_value == search_date and exception_type == "2": if service_id in active_services: active_services.remove(service_id) except FileNotFoundError: diff --git a/src/gtfs_perstop_report/src/shapes.py b/src/gtfs_perstop_report/src/shapes.py index f49832a..a308999 100644 --- a/src/gtfs_perstop_report/src/shapes.py +++ b/src/gtfs_perstop_report/src/shapes.py @@ -36,13 +36,24 @@ def process_shapes(feed_dir: str, out_dir: str) -> None: try: shape = Shape( shape_id=row["shape_id"], - shape_pt_lat=float(row["shape_pt_lat"]) if row.get("shape_pt_lat") else None, - shape_pt_lon=float(row["shape_pt_lon"]) if row.get("shape_pt_lon") else None, - shape_pt_position=int(row["shape_pt_position"]) if row.get("shape_pt_position") else None, - shape_dist_traveled=float(row["shape_dist_traveled"]) if row.get("shape_dist_traveled") else None, + shape_pt_lat=float(row["shape_pt_lat"]) + if row.get("shape_pt_lat") + else None, + shape_pt_lon=float(row["shape_pt_lon"]) + if row.get("shape_pt_lon") + else None, + shape_pt_position=int(row["shape_pt_position"]) + if row.get("shape_pt_position") + else None, + shape_dist_traveled=float(row["shape_dist_traveled"]) + if row.get("shape_dist_traveled") + else None, ) - if shape.shape_pt_lat is not None and shape.shape_pt_lon is not None: + if ( + shape.shape_pt_lat is not None + and shape.shape_pt_lon is not None + ): shape_pt_25829_x, shape_pt_25829_y = transformer.transform( shape.shape_pt_lon, shape.shape_pt_lat ) @@ -55,18 +66,22 @@ def process_shapes(feed_dir: str, out_dir: str) -> None: except Exception as e: logger.warning( f"Error parsing stops.txt line {row_num}: {e} - line data: {row}" - ) + ) except FileNotFoundError: logger.error(f"File not found: {file_path}") except Exception as e: logger.error(f"Error reading stops.txt: {e}") - # Write shapes to Protobuf files from src.proto.stop_schedule_pb2 import Epsg25829, Shape as PbShape for shape_id, shape_points in shapes.items(): - points = sorted(shape_points, key=lambda sp: sp.shape_pt_position if sp.shape_pt_position is not None else 0) + points = sorted( + shape_points, + key=lambda sp: sp.shape_pt_position + if sp.shape_pt_position is not None + else 0, + ) pb_shape = PbShape( shape_id=shape_id, diff --git a/src/gtfs_perstop_report/src/stop_schedule_pb2.py b/src/gtfs_perstop_report/src/stop_schedule_pb2.py index 285b057..76a1da4 100644 --- a/src/gtfs_perstop_report/src/stop_schedule_pb2.py +++ b/src/gtfs_perstop_report/src/stop_schedule_pb2.py @@ -4,38 +4,37 @@ # source: stop_schedule.proto # Protobuf Python Version: 6.33.0 """Generated protocol buffer code.""" + from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + _runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 33, - 0, - '', - 'stop_schedule.proto' + _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "stop_schedule.proto" ) # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13stop_schedule.proto\x12\x05proto\"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\"\xe3\x03\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12\"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\xe5\x02\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18\" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\tB$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x13stop_schedule.proto\x12\x05proto"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01"\xe3\x03\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\xe5\x02\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\tB$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'stop_schedule_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "stop_schedule_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: - _globals['DESCRIPTOR']._loaded_options = None - _globals['DESCRIPTOR']._serialized_options = b'\252\002!Costasdev.Busurbano.Backend.Types' - _globals['_EPSG25829']._serialized_start=30 - _globals['_EPSG25829']._serialized_end=63 - _globals['_STOPARRIVALS']._serialized_start=66 - _globals['_STOPARRIVALS']._serialized_end=549 - _globals['_STOPARRIVALS_SCHEDULEDARRIVAL']._serialized_start=192 - _globals['_STOPARRIVALS_SCHEDULEDARRIVAL']._serialized_end=549 + _globals["DESCRIPTOR"]._loaded_options = None + _globals[ + "DESCRIPTOR" + ]._serialized_options = b"\252\002!Costasdev.Busurbano.Backend.Types" + _globals["_EPSG25829"]._serialized_start = 30 + _globals["_EPSG25829"]._serialized_end = 63 + _globals["_STOPARRIVALS"]._serialized_start = 66 + _globals["_STOPARRIVALS"]._serialized_end = 549 + _globals["_STOPARRIVALS_SCHEDULEDARRIVAL"]._serialized_start = 192 + _globals["_STOPARRIVALS_SCHEDULEDARRIVAL"]._serialized_end = 549 # @@protoc_insertion_point(module_scope) diff --git a/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi b/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi index aa42cdb..c8d7f36 100644 --- a/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi +++ b/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi @@ -12,7 +12,9 @@ class Epsg25829(_message.Message): Y_FIELD_NUMBER: _ClassVar[int] x: float y: float - def __init__(self, x: _Optional[float] = ..., y: _Optional[float] = ...) -> None: ... + def __init__( + self, x: _Optional[float] = ..., y: _Optional[float] = ... + ) -> None: ... class StopArrivals(_message.Message): __slots__ = () @@ -50,11 +52,37 @@ class StopArrivals(_message.Message): terminus_code: str terminus_name: str terminus_time: str - def __init__(self, service_id: _Optional[str] = ..., trip_id: _Optional[str] = ..., line: _Optional[str] = ..., route: _Optional[str] = ..., shape_id: _Optional[str] = ..., shape_dist_traveled: _Optional[float] = ..., stop_sequence: _Optional[int] = ..., next_streets: _Optional[_Iterable[str]] = ..., starting_code: _Optional[str] = ..., starting_name: _Optional[str] = ..., starting_time: _Optional[str] = ..., calling_time: _Optional[str] = ..., calling_ssm: _Optional[int] = ..., terminus_code: _Optional[str] = ..., terminus_name: _Optional[str] = ..., terminus_time: _Optional[str] = ...) -> None: ... + def __init__( + self, + service_id: _Optional[str] = ..., + trip_id: _Optional[str] = ..., + line: _Optional[str] = ..., + route: _Optional[str] = ..., + shape_id: _Optional[str] = ..., + shape_dist_traveled: _Optional[float] = ..., + stop_sequence: _Optional[int] = ..., + next_streets: _Optional[_Iterable[str]] = ..., + starting_code: _Optional[str] = ..., + starting_name: _Optional[str] = ..., + starting_time: _Optional[str] = ..., + calling_time: _Optional[str] = ..., + calling_ssm: _Optional[int] = ..., + terminus_code: _Optional[str] = ..., + terminus_name: _Optional[str] = ..., + terminus_time: _Optional[str] = ..., + ) -> None: ... + STOP_ID_FIELD_NUMBER: _ClassVar[int] LOCATION_FIELD_NUMBER: _ClassVar[int] ARRIVALS_FIELD_NUMBER: _ClassVar[int] stop_id: str location: Epsg25829 arrivals: _containers.RepeatedCompositeFieldContainer[StopArrivals.ScheduledArrival] - def __init__(self, stop_id: _Optional[str] = ..., location: _Optional[_Union[Epsg25829, _Mapping]] = ..., arrivals: _Optional[_Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]]] = ...) -> None: ... + def __init__( + self, + stop_id: _Optional[str] = ..., + location: _Optional[_Union[Epsg25829, _Mapping]] = ..., + arrivals: _Optional[ + _Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]] + ] = ..., + ) -> None: ... diff --git a/src/gtfs_perstop_report/src/stop_times.py b/src/gtfs_perstop_report/src/stop_times.py index f3c3f25..c48f505 100644 --- a/src/gtfs_perstop_report/src/stop_times.py +++ b/src/gtfs_perstop_report/src/stop_times.py @@ -1,6 +1,7 @@ """ Functions for handling GTFS stop_times data. """ + import csv import os from src.logger import get_logger @@ -9,13 +10,25 @@ logger = get_logger("stop_times") STOP_TIMES_BY_FEED: dict[str, dict[str, list["StopTime"]]] = {} -STOP_TIMES_BY_REQUEST: dict[tuple[str, frozenset[str]], dict[str, list["StopTime"]]] = {} +STOP_TIMES_BY_REQUEST: dict[ + tuple[str, frozenset[str]], dict[str, list["StopTime"]] +] = {} + class StopTime: """ Class representing a stop time entry in the GTFS data. """ - def __init__(self, trip_id: str, arrival_time: str, departure_time: str, stop_id: str, stop_sequence: int, shape_dist_traveled: float | None): + + def __init__( + self, + trip_id: str, + arrival_time: str, + departure_time: str, + stop_id: str, + stop_sequence: int, + shape_dist_traveled: float | None, + ): self.trip_id = trip_id self.arrival_time = arrival_time self.departure_time = departure_time @@ -36,47 +49,63 @@ def _load_stop_times_for_feed(feed_dir: str) -> dict[str, list[StopTime]]: stops: dict[str, list[StopTime]] = {} try: - with open(os.path.join(feed_dir, 'stop_times.txt'), 'r', encoding="utf-8", newline='') as stop_times_file: + with open( + os.path.join(feed_dir, "stop_times.txt"), "r", encoding="utf-8", newline="" + ) as stop_times_file: reader = csv.DictReader(stop_times_file) if reader.fieldnames is None: logger.error("stop_times.txt missing header row.") STOP_TIMES_BY_FEED[feed_dir] = {} return STOP_TIMES_BY_FEED[feed_dir] - required_columns = ['trip_id', 'arrival_time', 'departure_time', 'stop_id', 'stop_sequence'] - missing_columns = [col for col in required_columns if col not in reader.fieldnames] + required_columns = [ + "trip_id", + "arrival_time", + "departure_time", + "stop_id", + "stop_sequence", + ] + missing_columns = [ + col for col in required_columns if col not in reader.fieldnames + ] if missing_columns: logger.error(f"Required columns not found in header: {missing_columns}") STOP_TIMES_BY_FEED[feed_dir] = {} return STOP_TIMES_BY_FEED[feed_dir] - has_shape_dist = 'shape_dist_traveled' in reader.fieldnames + has_shape_dist = "shape_dist_traveled" in reader.fieldnames if not has_shape_dist: - logger.warning("Column 'shape_dist_traveled' not found in stop_times.txt. Distances will be set to None.") + logger.warning( + "Column 'shape_dist_traveled' not found in stop_times.txt. Distances will be set to None." + ) for row in reader: - trip_id = row['trip_id'] + trip_id = row["trip_id"] if trip_id not in stops: stops[trip_id] = [] dist = None - if has_shape_dist and row['shape_dist_traveled']: + if has_shape_dist and row["shape_dist_traveled"]: try: - dist = float(row['shape_dist_traveled']) + dist = float(row["shape_dist_traveled"]) except ValueError: pass try: - stops[trip_id].append(StopTime( - trip_id=trip_id, - arrival_time=row['arrival_time'], - departure_time=row['departure_time'], - stop_id=row['stop_id'], - stop_sequence=int(row['stop_sequence']), - shape_dist_traveled=dist - )) + stops[trip_id].append( + StopTime( + trip_id=trip_id, + arrival_time=row["arrival_time"], + departure_time=row["departure_time"], + stop_id=row["stop_id"], + stop_sequence=int(row["stop_sequence"]), + shape_dist_traveled=dist, + ) + ) except ValueError as e: - logger.warning(f"Error parsing stop_sequence for trip {trip_id}: {e}") + logger.warning( + f"Error parsing stop_sequence for trip {trip_id}: {e}" + ) for trip_stop_times in stops.values(): trip_stop_times.sort(key=lambda st: st.stop_sequence) @@ -89,7 +118,9 @@ def _load_stop_times_for_feed(feed_dir: str) -> dict[str, list[StopTime]]: return stops -def get_stops_for_trips(feed_dir: str, trip_ids: list[str]) -> dict[str, list[StopTime]]: +def get_stops_for_trips( + feed_dir: str, trip_ids: list[str] +) -> dict[str, list[StopTime]]: """ Get stops for a list of trip IDs based on the cached 'stop_times.txt' data. """ diff --git a/src/gtfs_perstop_report/src/stops.py b/src/gtfs_perstop_report/src/stops.py index bb54fa4..fb95cf2 100644 --- a/src/gtfs_perstop_report/src/stops.py +++ b/src/gtfs_perstop_report/src/stops.py @@ -36,9 +36,7 @@ def get_all_stops_by_code(feed_dir: str) -> Dict[str, Stop]: all_stops = get_all_stops(feed_dir) for stop in all_stops.values(): - stop_25829_x, stop_25829_y = transformer.transform( - stop.stop_lon, stop.stop_lat - ) + stop_25829_x, stop_25829_y = transformer.transform(stop.stop_lon, stop.stop_lat) stop.stop_25829_x = stop_25829_x stop.stop_25829_y = stop_25829_y diff --git a/src/gtfs_perstop_report/src/street_name.py b/src/gtfs_perstop_report/src/street_name.py index ec6b5b6..81d419b 100644 --- a/src/gtfs_perstop_report/src/street_name.py +++ b/src/gtfs_perstop_report/src/street_name.py @@ -3,7 +3,8 @@ import re re_remove_quotation_marks = re.compile(r'[""”]', re.IGNORECASE) re_anything_before_stopcharacters_with_parentheses = re.compile( - r'^(.*?)(?:,|\s\s|\s-\s| \d| S\/N|\s\()', re.IGNORECASE) + r"^(.*?)(?:,|\s\s|\s-\s| \d| S\/N|\s\()", re.IGNORECASE +) NAME_REPLACEMENTS = { @@ -17,15 +18,13 @@ NAME_REPLACEMENTS = { " do ": " ", " da ": " ", " das ": " ", - "Riós": "Ríos" + "Riós": "Ríos", } def get_street_name(original_name: str) -> str: - original_name = re.sub(re_remove_quotation_marks, - '', original_name).strip() - match = re.match( - re_anything_before_stopcharacters_with_parentheses, original_name) + original_name = re.sub(re_remove_quotation_marks, "", original_name).strip() + match = re.match(re_anything_before_stopcharacters_with_parentheses, original_name) if match: street_name = match.group(1) else: @@ -41,9 +40,9 @@ def get_street_name(original_name: str) -> str: def normalise_stop_name(original_name: str | None) -> str: if original_name is None: - return '' - stop_name = re.sub(re_remove_quotation_marks, '', original_name).strip() + return "" + stop_name = re.sub(re_remove_quotation_marks, "", original_name).strip() - stop_name = stop_name.replace(' ', ', ') + stop_name = stop_name.replace(" ", ", ") return stop_name diff --git a/src/gtfs_perstop_report/src/trips.py b/src/gtfs_perstop_report/src/trips.py index 0cedd26..0de632a 100644 --- a/src/gtfs_perstop_report/src/trips.py +++ b/src/gtfs_perstop_report/src/trips.py @@ -1,16 +1,28 @@ """ Functions for handling GTFS trip data. """ + import os from src.logger import get_logger logger = get_logger("trips") + class TripLine: """ Class representing a trip line in the GTFS data. """ - def __init__(self, route_id: str, service_id: str, trip_id: str, headsign: str, direction_id: int, shape_id: str|None = None, block_id: str|None = None): + + def __init__( + self, + route_id: str, + service_id: str, + trip_id: str, + headsign: str, + direction_id: int, + shape_id: str | None = None, + block_id: str | None = None, + ): self.route_id = route_id self.service_id = service_id self.trip_id = trip_id @@ -28,15 +40,17 @@ class TripLine: TRIPS_BY_SERVICE_ID: dict[str, dict[str, list[TripLine]]] = {} -def get_trips_for_services(feed_dir: str, service_ids: list[str]) -> dict[str, list[TripLine]]: +def get_trips_for_services( + feed_dir: str, service_ids: list[str] +) -> dict[str, list[TripLine]]: """ Get trips for a list of service IDs based on the 'trips.txt' file. Uses caching to avoid reading and parsing the file multiple times. - + Args: feed_dir (str): Directory containing the GTFS feed files. service_ids (list[str]): List of service IDs to find trips for. - + Returns: dict[str, list[TripLine]]: Dictionary mapping service IDs to lists of trip objects. """ @@ -44,52 +58,58 @@ def get_trips_for_services(feed_dir: str, service_ids: list[str]) -> dict[str, l if feed_dir in TRIPS_BY_SERVICE_ID: logger.debug(f"Using cached trips data for {feed_dir}") # Return only the trips for the requested service IDs - return {service_id: TRIPS_BY_SERVICE_ID[feed_dir].get(service_id, []) - for service_id in service_ids} - + return { + service_id: TRIPS_BY_SERVICE_ID[feed_dir].get(service_id, []) + for service_id in service_ids + } + trips: dict[str, list[TripLine]] = {} try: - with open(os.path.join(feed_dir, 'trips.txt'), 'r', encoding="utf-8") as trips_file: + with open( + os.path.join(feed_dir, "trips.txt"), "r", encoding="utf-8" + ) as trips_file: lines = trips_file.readlines() if len(lines) <= 1: logger.warning( - "trips.txt file is empty or has only header line, not processing.") + "trips.txt file is empty or has only header line, not processing." + ) return trips - header = lines[0].strip().split(',') + header = lines[0].strip().split(",") try: - service_id_index = header.index('service_id') - trip_id_index = header.index('trip_id') - route_id_index = header.index('route_id') - headsign_index = header.index('trip_headsign') - direction_id_index = header.index('direction_id') + service_id_index = header.index("service_id") + trip_id_index = header.index("trip_id") + route_id_index = header.index("route_id") + headsign_index = header.index("trip_headsign") + direction_id_index = header.index("direction_id") except ValueError as e: logger.error(f"Required column not found in header: {e}") return trips # Check if shape_id column exists shape_id_index = None - if 'shape_id' in header: - shape_id_index = header.index('shape_id') + if "shape_id" in header: + shape_id_index = header.index("shape_id") else: logger.warning("shape_id column not found in trips.txt") # Check if block_id column exists block_id_index = None - if 'block_id' in header: - block_id_index = header.index('block_id') + if "block_id" in header: + block_id_index = header.index("block_id") else: logger.info("block_id column not found in trips.txt") # Initialize cache for this feed directory TRIPS_BY_SERVICE_ID[feed_dir] = {} - + for line in lines[1:]: - parts = line.strip().split(',') + parts = line.strip().split(",") if len(parts) < len(header): logger.warning( - f"Skipping malformed line in trips.txt: {line.strip()}") + f"Skipping malformed line in trips.txt: {line.strip()}" + ) continue service_id = parts[service_id_index] @@ -115,19 +135,20 @@ def get_trips_for_services(feed_dir: str, service_ids: list[str]) -> dict[str, l trip_id=trip_id, headsign=parts[headsign_index], direction_id=int( - parts[direction_id_index] if parts[direction_id_index] else -1), + parts[direction_id_index] if parts[direction_id_index] else -1 + ), shape_id=shape_id, - block_id=block_id + block_id=block_id, ) - + TRIPS_BY_SERVICE_ID[feed_dir][service_id].append(trip_line) - + # Also build the result for the requested service IDs if service_id in service_ids: if service_id not in trips: trips[service_id] = [] trips[service_id].append(trip_line) - + except FileNotFoundError: logger.warning("trips.txt file not found.") diff --git a/src/gtfs_perstop_report/stop_report.py b/src/gtfs_perstop_report/stop_report.py index f8fdc64..3bbdf11 100644 --- a/src/gtfs_perstop_report/stop_report.py +++ b/src/gtfs_perstop_report/stop_report.py @@ -32,8 +32,7 @@ def parse_args(): default="./output/", help="Directory to write reports to (default: ./output/)", ) - parser.add_argument("--feed-dir", type=str, - help="Path to the feed directory") + parser.add_argument("--feed-dir", type=str, help="Path to the feed directory") parser.add_argument( "--feed-url", type=str, @@ -244,12 +243,9 @@ def build_trip_previous_shape_map( if shift_key not in trips_by_shift: trips_by_shift[shift_key] = [] - trips_by_shift[shift_key].append(( - trip, - trip_number, - first_stop.stop_id, - last_stop.stop_id - )) + trips_by_shift[shift_key].append( + (trip, trip_number, first_stop.stop_id, last_stop.stop_id) + ) # For each shift, sort trips by trip number and link consecutive trips for shift_key, shift_trips in trips_by_shift.items(): # Sort by trip number @@ -262,16 +258,20 @@ def build_trip_previous_shape_map( # Check if trips are consecutive (trip numbers differ by 1), # if previous trip's terminus matches current trip's start, # and if both trips have valid shape IDs - if (current_num == prev_num + 1 and - prev_end_stop == current_start_stop and - prev_trip.shape_id and - current_trip.shape_id): + if ( + current_num == prev_num + 1 + and prev_end_stop == current_start_stop + and prev_trip.shape_id + and current_trip.shape_id + ): trip_previous_shape[current_trip.trip_id] = prev_trip.shape_id return trip_previous_shape -def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict[str, Any]]]: +def get_stop_arrivals( + feed_dir: str, date: str, provider +) -> Dict[str, List[Dict[str, Any]]]: """ Process trips for the given date and organize stop arrivals. Also includes night services from the previous day (times >= 24:00:00). @@ -293,15 +293,16 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict if not active_services: logger.info("No active services found for the given date.") - logger.info( - f"Found {len(active_services)} active services for date {date}.") + logger.info(f"Found {len(active_services)} active services for date {date}.") # Also get services from the previous day to include night services (times >= 24:00) - prev_date = (datetime.strptime(date, "%Y-%m-%d") - - timedelta(days=1)).strftime("%Y-%m-%d") + prev_date = (datetime.strptime(date, "%Y-%m-%d") - timedelta(days=1)).strftime( + "%Y-%m-%d" + ) prev_services = get_active_services(feed_dir, prev_date) logger.info( - f"Found {len(prev_services)} active services for previous date {prev_date} (for night services).") + f"Found {len(prev_services)} active services for previous date {prev_date} (for night services)." + ) all_services = list(set(active_services + prev_services)) @@ -314,18 +315,17 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict logger.info(f"Found {total_trip_count} trips for active services.") # Get all trip IDs - all_trip_ids = [trip.trip_id for trip_list in trips.values() - for trip in trip_list] + all_trip_ids = [trip.trip_id for trip_list in trips.values() for trip in trip_list] # Get stops for all trips stops_for_all_trips = get_stops_for_trips(feed_dir, all_trip_ids) logger.info(f"Precomputed stops for {len(stops_for_all_trips)} trips.") # Build mapping from trip_id to previous trip's shape_id - trip_previous_shape_map = build_trip_previous_shape_map( - trips, stops_for_all_trips) + trip_previous_shape_map = build_trip_previous_shape_map(trips, stops_for_all_trips) logger.info( - f"Built previous trip shape mapping for {len(trip_previous_shape_map)} trips.") + f"Built previous trip shape mapping for {len(trip_previous_shape_map)} trips." + ) # Load routes information routes = load_routes(feed_dir) @@ -389,8 +389,7 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict stop_to_segment_idx.append(len(segment_names) - 1) # Precompute future street transitions per segment - future_suffix_by_segment: list[tuple[str, ...]] = [ - ()] * len(segment_names) + future_suffix_by_segment: list[tuple[str, ...]] = [()] * len(segment_names) future_tuple: tuple[str, ...] = () for idx in range(len(segment_names) - 1, -1, -1): future_suffix_by_segment[idx] = future_tuple @@ -437,7 +436,7 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict passes.append("previous") for mode in passes: - is_current_mode = (mode == "current") + is_current_mode = mode == "current" for i, (stop_time, _) in enumerate(trip_stop_pairs): # Skip the last stop of the trip (terminus) to avoid duplication @@ -457,11 +456,9 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict continue # Normalize times for display on current day (e.g. 25:30 -> 01:30) - final_starting_time = normalize_gtfs_time( - starting_time) + final_starting_time = normalize_gtfs_time(starting_time) final_calling_time = normalize_gtfs_time(dep_time) - final_terminus_time = normalize_gtfs_time( - terminus_time) + final_terminus_time = normalize_gtfs_time(terminus_time) # SSM should be small (early morning) final_calling_ssm = time_to_seconds(final_calling_time) else: @@ -489,12 +486,10 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict # Format IDs and route using provider-specific logic service_id_fmt = provider.format_service_id(service_id) trip_id_fmt = provider.format_trip_id(trip_id) - route_fmt = provider.format_route( - trip_headsign, terminus_name) + route_fmt = provider.format_route(trip_headsign, terminus_name) # Get previous trip shape_id if available - previous_trip_shape_id = trip_previous_shape_map.get( - trip_id, "") + previous_trip_shape_id = trip_previous_shape_map.get(trip_id, "") stop_arrivals[stop_code].append( { @@ -616,8 +611,7 @@ def main(): feed_dir = args.feed_dir else: logger.info(f"Downloading GTFS feed from {feed_url}...") - feed_dir = download_feed_from_url( - feed_url, output_dir, args.force_download) + feed_dir = download_feed_from_url(feed_url, output_dir, args.force_download) if feed_dir is None: logger.info("Download was skipped (feed not modified). Exiting.") return @@ -642,8 +636,7 @@ def main(): _, stop_summary = process_date(feed_dir, date, output_dir, provider) all_stops_summary[date] = stop_summary - logger.info( - "Finished processing all dates. Beginning with shape transformation.") + logger.info("Finished processing all dates. Beginning with shape transformation.") # Process shapes, converting each coordinate to EPSG:25829 and saving as Protobuf process_shapes(feed_dir, output_dir) -- cgit v1.3