about | summary | refs | log | tree | commit | diff
path: root/src/gtfs_perstop_report
diff options
context:
space:
mode:
Diffstat (limited to 'src/gtfs_perstop_report')
-rw-r--r-- src/gtfs_perstop_report/src/common.py | 1
-rw-r--r-- src/gtfs_perstop_report/src/download.py | 96
-rw-r--r-- src/gtfs_perstop_report/src/logger.py | 12
-rw-r--r-- src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py | 30
-rw-r--r-- src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi | 69
-rw-r--r-- src/gtfs_perstop_report/src/routes.py | 24
-rw-r--r-- src/gtfs_perstop_report/src/services.py | 61
-rw-r--r-- src/gtfs_perstop_report/src/shapes.py | 31
-rw-r--r-- src/gtfs_perstop_report/src/stop_schedule_pb2.py | 35
-rw-r--r-- src/gtfs_perstop_report/src/stop_schedule_pb2.pyi | 34
-rw-r--r-- src/gtfs_perstop_report/src/stop_times.py | 71
-rw-r--r-- src/gtfs_perstop_report/src/stops.py | 4
-rw-r--r-- src/gtfs_perstop_report/src/street_name.py | 17
-rw-r--r-- src/gtfs_perstop_report/src/trips.py | 75
-rw-r--r-- src/gtfs_perstop_report/stop_report.py | 69
15 files changed, 400 insertions, 229 deletions
diff --git a/src/gtfs_perstop_report/src/common.py b/src/gtfs_perstop_report/src/common.py
index 22769e4..c2df785 100644
--- a/src/gtfs_perstop_report/src/common.py
+++ b/src/gtfs_perstop_report/src/common.py
@@ -40,7 +40,6 @@ def get_all_feed_dates(feed_dir: str) -> List[str]:
if len(result) > 0:
return result
-
# Fallback: use calendar_dates.txt
if os.path.exists(calendar_dates_path):
with open(calendar_dates_path, encoding="utf-8") as f:
diff --git a/src/gtfs_perstop_report/src/download.py b/src/gtfs_perstop_report/src/download.py
index 19125bc..4d0c620 100644
--- a/src/gtfs_perstop_report/src/download.py
+++ b/src/gtfs_perstop_report/src/download.py
@@ -9,39 +9,44 @@ from src.logger import get_logger
logger = get_logger("download")
+
def _get_metadata_path(output_dir: str) -> str:
"""Get the path to the metadata file for storing ETag and Last-Modified info."""
- return os.path.join(output_dir, '.gtfsmetadata')
+ return os.path.join(output_dir, ".gtfsmetadata")
+
def _load_metadata(output_dir: str) -> Optional[dict]:
"""Load existing metadata from the output directory."""
metadata_path = _get_metadata_path(output_dir)
if os.path.exists(metadata_path):
try:
- with open(metadata_path, 'r', encoding='utf-8') as f:
+ with open(metadata_path, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError) as e:
logger.warning(f"Failed to load metadata from {metadata_path}: {e}")
return None
-def _save_metadata(output_dir: str, etag: Optional[str], last_modified: Optional[str]) -> None:
+
+def _save_metadata(
+ output_dir: str, etag: Optional[str], last_modified: Optional[str]
+) -> None:
"""Save ETag and Last-Modified metadata to the output directory."""
metadata_path = _get_metadata_path(output_dir)
- metadata = {
- 'etag': etag,
- 'last_modified': last_modified
- }
-
+ metadata = {"etag": etag, "last_modified": last_modified}
+
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
-
+
try:
- with open(metadata_path, 'w', encoding='utf-8') as f:
+ with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
except IOError as e:
logger.warning(f"Failed to save metadata to {metadata_path}: {e}")
-def _check_if_modified(feed_url: str, output_dir: str) -> Tuple[bool, Optional[str], Optional[str]]:
+
+def _check_if_modified(
+ feed_url: str, output_dir: str
+) -> Tuple[bool, Optional[str], Optional[str]]:
"""
Check if the feed has been modified using conditional headers.
Returns (is_modified, etag, last_modified)
@@ -49,58 +54,69 @@ def _check_if_modified(feed_url: str, output_dir: str) -> Tuple[bool, Optional[s
metadata = _load_metadata(output_dir)
if not metadata:
return True, None, None
-
+
headers = {}
- if metadata.get('etag'):
- headers['If-None-Match'] = metadata['etag']
- if metadata.get('last_modified'):
- headers['If-Modified-Since'] = metadata['last_modified']
-
+ if metadata.get("etag"):
+ headers["If-None-Match"] = metadata["etag"]
+ if metadata.get("last_modified"):
+ headers["If-Modified-Since"] = metadata["last_modified"]
+
if not headers:
return True, None, None
-
+
try:
response = requests.head(feed_url, headers=headers)
-
+
if response.status_code == 304:
- logger.info("Feed has not been modified (304 Not Modified), skipping download")
- return False, metadata.get('etag'), metadata.get('last_modified')
+ logger.info(
+ "Feed has not been modified (304 Not Modified), skipping download"
+ )
+ return False, metadata.get("etag"), metadata.get("last_modified")
elif response.status_code == 200:
- etag = response.headers.get('ETag')
- last_modified = response.headers.get('Last-Modified')
+ etag = response.headers.get("ETag")
+ last_modified = response.headers.get("Last-Modified")
return True, etag, last_modified
else:
- logger.warning(f"Unexpected response status {response.status_code} when checking for modifications, proceeding with download")
+ logger.warning(
+ f"Unexpected response status {response.status_code} when checking for modifications, proceeding with download"
+ )
return True, None, None
except requests.RequestException as e:
- logger.warning(f"Failed to check if feed has been modified: {e}, proceeding with download")
+ logger.warning(
+ f"Failed to check if feed has been modified: {e}, proceeding with download"
+ )
return True, None, None
-def download_feed_from_url(feed_url: str, output_dir: str = None, force_download: bool = False) -> Optional[str]:
+
+def download_feed_from_url(
+ feed_url: str, output_dir: str = None, force_download: bool = False
+) -> Optional[str]:
"""
Download GTFS feed from URL.
-
+
Args:
feed_url: URL to download the GTFS feed from
output_dir: Directory where reports will be written (used for metadata storage)
force_download: If True, skip conditional download checks
-
+
Returns:
Path to the directory containing the extracted GTFS files, or None if download was skipped
"""
-
+
# Check if we need to download the feed
if not force_download and output_dir:
- is_modified, cached_etag, cached_last_modified = _check_if_modified(feed_url, output_dir)
+ is_modified, cached_etag, cached_last_modified = _check_if_modified(
+ feed_url, output_dir
+ )
if not is_modified:
logger.info("Feed has not been modified, skipping download")
return None
-
+
# Create a directory in the system temporary directory
- temp_dir = tempfile.mkdtemp(prefix='gtfs_vigo_')
+ temp_dir = tempfile.mkdtemp(prefix="gtfs_vigo_")
# Create a temporary zip file in the temporary directory
- zip_filename = os.path.join(temp_dir, 'gtfs_vigo.zip')
+ zip_filename = os.path.join(temp_dir, "gtfs_vigo.zip")
headers = {}
response = requests.get(feed_url, headers=headers)
@@ -108,23 +124,23 @@ def download_feed_from_url(feed_url: str, output_dir: str = None, force_download
if response.status_code != 200:
raise Exception(f"Failed to download GTFS data: {response.status_code}")
- with open(zip_filename, 'wb') as file:
+ with open(zip_filename, "wb") as file:
file.write(response.content)
-
+
# Extract and save metadata if output_dir is provided
if output_dir:
- etag = response.headers.get('ETag')
- last_modified = response.headers.get('Last-Modified')
+ etag = response.headers.get("ETag")
+ last_modified = response.headers.get("Last-Modified")
if etag or last_modified:
_save_metadata(output_dir, etag, last_modified)
# Extract the zip file
- with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
+ with zipfile.ZipFile(zip_filename, "r") as zip_ref:
zip_ref.extractall(temp_dir)
-
+
# Clean up the downloaded zip file
os.remove(zip_filename)
logger.info(f"GTFS feed downloaded from {feed_url} and extracted to {temp_dir}")
- return temp_dir
\ No newline at end of file
+ return temp_dir
diff --git a/src/gtfs_perstop_report/src/logger.py b/src/gtfs_perstop_report/src/logger.py
index 9488076..6c56787 100644
--- a/src/gtfs_perstop_report/src/logger.py
+++ b/src/gtfs_perstop_report/src/logger.py
@@ -1,12 +1,14 @@
"""
Logging configuration for the GTFS application.
"""
+
import logging
from colorama import init, Fore, Style
# Initialize Colorama (required on Windows)
init(autoreset=True)
+
class ColorFormatter(logging.Formatter):
def format(self, record: logging.LogRecord):
# Base format
@@ -28,16 +30,18 @@ class ColorFormatter(logging.Formatter):
# Add color to the entire line
formatter = logging.Formatter(
- prefix + log_format + Style.RESET_ALL, "%Y-%m-%d %H:%M:%S")
+ prefix + log_format + Style.RESET_ALL, "%Y-%m-%d %H:%M:%S"
+ )
return formatter.format(record)
+
def get_logger(name: str) -> logging.Logger:
"""
Create and return a logger with the given name.
-
+
Args:
name (str): The name of the logger.
-
+
Returns:
logging.Logger: Configured logger instance.
"""
@@ -50,5 +54,5 @@ def get_logger(name: str) -> logging.Logger:
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(ColorFormatter())
logger.addHandler(console_handler)
-
+
return logger
diff --git a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py
index cb4f336..c7279c5 100644
--- a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py
+++ b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.py
@@ -2,6 +2,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: stop_schedule.proto
"""Generated protocol buffer code."""
+
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -11,22 +12,21 @@ from google.protobuf import symbol_database as _symbol_database
_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13stop_schedule.proto\x12\x05proto\"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\"\x83\x04\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12\"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\x85\x03\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18\" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\t\x12\x1e\n\x16previous_trip_shape_id\x18\x33 \x01(\t\";\n\x05Shape\x12\x10\n\x08shape_id\x18\x01 \x01(\t\x12 \n\x06points\x18\x03 \x03(\x0b\x32\x10.proto.Epsg25829B$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x13stop_schedule.proto\x12\x05proto"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01"\x83\x04\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\x85\x03\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\t\x12\x1e\n\x16previous_trip_shape_id\x18\x33 \x01(\t";\n\x05Shape\x12\x10\n\x08shape_id\x18\x01 \x01(\t\x12 \n\x06points\x18\x03 \x03(\x0b\x32\x10.proto.Epsg25829B$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3'
+)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'stop_schedule_pb2', globals())
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "stop_schedule_pb2", globals())
if _descriptor._USE_C_DESCRIPTORS == False:
-
- DESCRIPTOR._options = None
- DESCRIPTOR._serialized_options = b'\252\002!Costasdev.Busurbano.Backend.Types'
- _EPSG25829._serialized_start=30
- _EPSG25829._serialized_end=63
- _STOPARRIVALS._serialized_start=66
- _STOPARRIVALS._serialized_end=581
- _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_start=192
- _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_end=581
- _SHAPE._serialized_start=583
- _SHAPE._serialized_end=642
+ DESCRIPTOR._options = None
+ DESCRIPTOR._serialized_options = b"\252\002!Costasdev.Busurbano.Backend.Types"
+ _EPSG25829._serialized_start = 30
+ _EPSG25829._serialized_end = 63
+ _STOPARRIVALS._serialized_start = 66
+ _STOPARRIVALS._serialized_end = 581
+ _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_start = 192
+ _STOPARRIVALS_SCHEDULEDARRIVAL._serialized_end = 581
+ _SHAPE._serialized_start = 583
+ _SHAPE._serialized_end = 642
# @@protoc_insertion_point(module_scope)
diff --git a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi
index 355798f..fc55f4e 100644
--- a/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi
+++ b/src/gtfs_perstop_report/src/proto/stop_schedule_pb2.pyi
@@ -1,7 +1,13 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
-from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
+from typing import (
+ ClassVar as _ClassVar,
+ Iterable as _Iterable,
+ Mapping as _Mapping,
+ Optional as _Optional,
+ Union as _Union,
+)
DESCRIPTOR: _descriptor.FileDescriptor
@@ -11,7 +17,9 @@ class Epsg25829(_message.Message):
Y_FIELD_NUMBER: _ClassVar[int]
x: float
y: float
- def __init__(self, x: _Optional[float] = ..., y: _Optional[float] = ...) -> None: ...
+ def __init__(
+ self, x: _Optional[float] = ..., y: _Optional[float] = ...
+ ) -> None: ...
class Shape(_message.Message):
__slots__ = ["points", "shape_id"]
@@ -19,12 +27,34 @@ class Shape(_message.Message):
SHAPE_ID_FIELD_NUMBER: _ClassVar[int]
points: _containers.RepeatedCompositeFieldContainer[Epsg25829]
shape_id: str
- def __init__(self, shape_id: _Optional[str] = ..., points: _Optional[_Iterable[_Union[Epsg25829, _Mapping]]] = ...) -> None: ...
+ def __init__(
+ self,
+ shape_id: _Optional[str] = ...,
+ points: _Optional[_Iterable[_Union[Epsg25829, _Mapping]]] = ...,
+ ) -> None: ...
class StopArrivals(_message.Message):
__slots__ = ["arrivals", "location", "stop_id"]
class ScheduledArrival(_message.Message):
- __slots__ = ["calling_ssm", "calling_time", "line", "next_streets", "previous_trip_shape_id", "route", "service_id", "shape_dist_traveled", "shape_id", "starting_code", "starting_name", "starting_time", "stop_sequence", "terminus_code", "terminus_name", "terminus_time", "trip_id"]
+ __slots__ = [
+ "calling_ssm",
+ "calling_time",
+ "line",
+ "next_streets",
+ "previous_trip_shape_id",
+ "route",
+ "service_id",
+ "shape_dist_traveled",
+ "shape_id",
+ "starting_code",
+ "starting_name",
+ "starting_time",
+ "stop_sequence",
+ "terminus_code",
+ "terminus_name",
+ "terminus_time",
+ "trip_id",
+ ]
CALLING_SSM_FIELD_NUMBER: _ClassVar[int]
CALLING_TIME_FIELD_NUMBER: _ClassVar[int]
LINE_FIELD_NUMBER: _ClassVar[int]
@@ -59,11 +89,38 @@ class StopArrivals(_message.Message):
terminus_name: str
terminus_time: str
trip_id: str
- def __init__(self, service_id: _Optional[str] = ..., trip_id: _Optional[str] = ..., line: _Optional[str] = ..., route: _Optional[str] = ..., shape_id: _Optional[str] = ..., shape_dist_traveled: _Optional[float] = ..., stop_sequence: _Optional[int] = ..., next_streets: _Optional[_Iterable[str]] = ..., starting_code: _Optional[str] = ..., starting_name: _Optional[str] = ..., starting_time: _Optional[str] = ..., calling_time: _Optional[str] = ..., calling_ssm: _Optional[int] = ..., terminus_code: _Optional[str] = ..., terminus_name: _Optional[str] = ..., terminus_time: _Optional[str] = ..., previous_trip_shape_id: _Optional[str] = ...) -> None: ...
+ def __init__(
+ self,
+ service_id: _Optional[str] = ...,
+ trip_id: _Optional[str] = ...,
+ line: _Optional[str] = ...,
+ route: _Optional[str] = ...,
+ shape_id: _Optional[str] = ...,
+ shape_dist_traveled: _Optional[float] = ...,
+ stop_sequence: _Optional[int] = ...,
+ next_streets: _Optional[_Iterable[str]] = ...,
+ starting_code: _Optional[str] = ...,
+ starting_name: _Optional[str] = ...,
+ starting_time: _Optional[str] = ...,
+ calling_time: _Optional[str] = ...,
+ calling_ssm: _Optional[int] = ...,
+ terminus_code: _Optional[str] = ...,
+ terminus_name: _Optional[str] = ...,
+ terminus_time: _Optional[str] = ...,
+ previous_trip_shape_id: _Optional[str] = ...,
+ ) -> None: ...
+
ARRIVALS_FIELD_NUMBER: _ClassVar[int]
LOCATION_FIELD_NUMBER: _ClassVar[int]
STOP_ID_FIELD_NUMBER: _ClassVar[int]
arrivals: _containers.RepeatedCompositeFieldContainer[StopArrivals.ScheduledArrival]
location: Epsg25829
stop_id: str
- def __init__(self, stop_id: _Optional[str] = ..., location: _Optional[_Union[Epsg25829, _Mapping]] = ..., arrivals: _Optional[_Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]]] = ...) -> None: ...
+ def __init__(
+ self,
+ stop_id: _Optional[str] = ...,
+ location: _Optional[_Union[Epsg25829, _Mapping]] = ...,
+ arrivals: _Optional[
+ _Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]]
+ ] = ...,
+ ) -> None: ...
diff --git a/src/gtfs_perstop_report/src/routes.py b/src/gtfs_perstop_report/src/routes.py
index e67a1a4..06cf0e5 100644
--- a/src/gtfs_perstop_report/src/routes.py
+++ b/src/gtfs_perstop_report/src/routes.py
@@ -1,12 +1,14 @@
"""
Module for loading and querying GTFS routes data.
"""
+
import os
import csv
from src.logger import get_logger
logger = get_logger("routes")
+
def load_routes(feed_dir: str) -> dict[str, dict[str, str]]:
"""
Load routes data from the GTFS feed.
@@ -16,24 +18,26 @@ def load_routes(feed_dir: str) -> dict[str, dict[str, str]]:
containing route_short_name and route_color.
"""
routes: dict[str, dict[str, str]] = {}
- routes_file_path = os.path.join(feed_dir, 'routes.txt')
+ routes_file_path = os.path.join(feed_dir, "routes.txt")
try:
- with open(routes_file_path, 'r', encoding='utf-8') as routes_file:
+ with open(routes_file_path, "r", encoding="utf-8") as routes_file:
reader = csv.DictReader(routes_file)
header = reader.fieldnames or []
- if 'route_color' not in header:
- logger.warning("Column 'route_color' not found in routes.txt. Defaulting to black (#000000).")
+ if "route_color" not in header:
+ logger.warning(
+ "Column 'route_color' not found in routes.txt. Defaulting to black (#000000)."
+ )
for row in reader:
- route_id = row['route_id']
- if 'route_color' in row and row['route_color']:
- route_color = row['route_color']
+ route_id = row["route_id"]
+ if "route_color" in row and row["route_color"]:
+ route_color = row["route_color"]
else:
- route_color = '000000'
+ route_color = "000000"
routes[route_id] = {
- 'route_short_name': row['route_short_name'],
- 'route_color': route_color
+ "route_short_name": row["route_short_name"],
+ "route_color": route_color,
}
except FileNotFoundError:
raise FileNotFoundError(f"Routes file not found at {routes_file_path}")
diff --git a/src/gtfs_perstop_report/src/services.py b/src/gtfs_perstop_report/src/services.py
index fb1110d..d456e43 100644
--- a/src/gtfs_perstop_report/src/services.py
+++ b/src/gtfs_perstop_report/src/services.py
@@ -19,26 +19,28 @@ def get_active_services(feed_dir: str, date: str) -> list[str]:
ValueError: If the date format is incorrect.
"""
search_date = date.replace("-", "").replace(":", "").replace("/", "")
- weekday = datetime.datetime.strptime(date, '%Y-%m-%d').weekday()
+ weekday = datetime.datetime.strptime(date, "%Y-%m-%d").weekday()
active_services: list[str] = []
try:
- with open(os.path.join(feed_dir, 'calendar.txt'), 'r', encoding="utf-8") as calendar_file:
+ with open(
+ os.path.join(feed_dir, "calendar.txt"), "r", encoding="utf-8"
+ ) as calendar_file:
lines = calendar_file.readlines()
if len(lines) > 1:
# First parse the header, get each column's index
- header = lines[0].strip().split(',')
+ header = lines[0].strip().split(",")
try:
- service_id_index = header.index('service_id')
- monday_index = header.index('monday')
- tuesday_index = header.index('tuesday')
- wednesday_index = header.index('wednesday')
- thursday_index = header.index('thursday')
- friday_index = header.index('friday')
- saturday_index = header.index('saturday')
- sunday_index = header.index('sunday')
- start_date_index = header.index('start_date')
- end_date_index = header.index('end_date')
+ service_id_index = header.index("service_id")
+ monday_index = header.index("monday")
+ tuesday_index = header.index("tuesday")
+ wednesday_index = header.index("wednesday")
+ thursday_index = header.index("thursday")
+ friday_index = header.index("friday")
+ saturday_index = header.index("saturday")
+ sunday_index = header.index("sunday")
+ start_date_index = header.index("start_date")
+ end_date_index = header.index("end_date")
except ValueError as e:
logger.error(f"Required column not found in header: {e}")
return active_services
@@ -50,14 +52,15 @@ def get_active_services(feed_dir: str, date: str) -> list[str]:
3: thursday_index,
4: friday_index,
5: saturday_index,
- 6: sunday_index
+ 6: sunday_index,
}
for idx, line in enumerate(lines[1:], 1):
- parts = line.strip().split(',')
+ parts = line.strip().split(",")
if len(parts) < len(header):
logger.warning(
- f"Skipping malformed line in calendar.txt line {idx+1}: {line.strip()}")
+ f"Skipping malformed line in calendar.txt line {idx + 1}: {line.strip()}"
+ )
continue
service_id = parts[service_id_index]
@@ -66,24 +69,27 @@ def get_active_services(feed_dir: str, date: str) -> list[str]:
end_date = parts[end_date_index]
# Check if day of week is active AND date is within the service range
- if day_value == '1' and start_date <= search_date <= end_date:
+ if day_value == "1" and start_date <= search_date <= end_date:
active_services.append(service_id)
except FileNotFoundError:
logger.warning("calendar.txt file not found.")
try:
- with open(os.path.join(feed_dir, 'calendar_dates.txt'), 'r', encoding="utf-8") as calendar_dates_file:
+ with open(
+ os.path.join(feed_dir, "calendar_dates.txt"), "r", encoding="utf-8"
+ ) as calendar_dates_file:
lines = calendar_dates_file.readlines()
if len(lines) <= 1:
logger.warning(
- "calendar_dates.txt file is empty or has only header line, not processing.")
+ "calendar_dates.txt file is empty or has only header line, not processing."
+ )
return active_services
- header = lines[0].strip().split(',')
+ header = lines[0].strip().split(",")
try:
- service_id_index = header.index('service_id')
- date_index = header.index('date')
- exception_type_index = header.index('exception_type')
+ service_id_index = header.index("service_id")
+ date_index = header.index("date")
+ exception_type_index = header.index("exception_type")
except ValueError as e:
logger.error(f"Required column not found in header: {e}")
return active_services
@@ -91,20 +97,21 @@ def get_active_services(feed_dir: str, date: str) -> list[str]:
# Now read the rest of the file, find all services where 'date' matches the search_date
# Start from 1 to skip header
for idx, line in enumerate(lines[1:], 1):
- parts = line.strip().split(',')
+ parts = line.strip().split(",")
if len(parts) < len(header):
logger.warning(
- f"Skipping malformed line in calendar_dates.txt line {idx+1}: {line.strip()}")
+ f"Skipping malformed line in calendar_dates.txt line {idx + 1}: {line.strip()}"
+ )
continue
service_id = parts[service_id_index]
date_value = parts[date_index]
exception_type = parts[exception_type_index]
- if date_value == search_date and exception_type == '1':
+ if date_value == search_date and exception_type == "1":
active_services.append(service_id)
- if date_value == search_date and exception_type == '2':
+ if date_value == search_date and exception_type == "2":
if service_id in active_services:
active_services.remove(service_id)
except FileNotFoundError:
diff --git a/src/gtfs_perstop_report/src/shapes.py b/src/gtfs_perstop_report/src/shapes.py
index f49832a..a308999 100644
--- a/src/gtfs_perstop_report/src/shapes.py
+++ b/src/gtfs_perstop_report/src/shapes.py
@@ -36,13 +36,24 @@ def process_shapes(feed_dir: str, out_dir: str) -> None:
try:
shape = Shape(
shape_id=row["shape_id"],
- shape_pt_lat=float(row["shape_pt_lat"]) if row.get("shape_pt_lat") else None,
- shape_pt_lon=float(row["shape_pt_lon"]) if row.get("shape_pt_lon") else None,
- shape_pt_position=int(row["shape_pt_position"]) if row.get("shape_pt_position") else None,
- shape_dist_traveled=float(row["shape_dist_traveled"]) if row.get("shape_dist_traveled") else None,
+ shape_pt_lat=float(row["shape_pt_lat"])
+ if row.get("shape_pt_lat")
+ else None,
+ shape_pt_lon=float(row["shape_pt_lon"])
+ if row.get("shape_pt_lon")
+ else None,
+ shape_pt_position=int(row["shape_pt_position"])
+ if row.get("shape_pt_position")
+ else None,
+ shape_dist_traveled=float(row["shape_dist_traveled"])
+ if row.get("shape_dist_traveled")
+ else None,
)
- if shape.shape_pt_lat is not None and shape.shape_pt_lon is not None:
+ if (
+ shape.shape_pt_lat is not None
+ and shape.shape_pt_lon is not None
+ ):
shape_pt_25829_x, shape_pt_25829_y = transformer.transform(
shape.shape_pt_lon, shape.shape_pt_lat
)
@@ -55,18 +66,22 @@ def process_shapes(feed_dir: str, out_dir: str) -> None:
except Exception as e:
logger.warning(
f"Error parsing stops.txt line {row_num}: {e} - line data: {row}"
- )
+ )
except FileNotFoundError:
logger.error(f"File not found: {file_path}")
except Exception as e:
logger.error(f"Error reading stops.txt: {e}")
-
# Write shapes to Protobuf files
from src.proto.stop_schedule_pb2 import Epsg25829, Shape as PbShape
for shape_id, shape_points in shapes.items():
- points = sorted(shape_points, key=lambda sp: sp.shape_pt_position if sp.shape_pt_position is not None else 0)
+ points = sorted(
+ shape_points,
+ key=lambda sp: sp.shape_pt_position
+ if sp.shape_pt_position is not None
+ else 0,
+ )
pb_shape = PbShape(
shape_id=shape_id,
diff --git a/src/gtfs_perstop_report/src/stop_schedule_pb2.py b/src/gtfs_perstop_report/src/stop_schedule_pb2.py
index 285b057..76a1da4 100644
--- a/src/gtfs_perstop_report/src/stop_schedule_pb2.py
+++ b/src/gtfs_perstop_report/src/stop_schedule_pb2.py
@@ -4,38 +4,37 @@
# source: stop_schedule.proto
# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
+
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
+
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC,
- 6,
- 33,
- 0,
- '',
- 'stop_schedule.proto'
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "stop_schedule.proto"
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13stop_schedule.proto\x12\x05proto\"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\"\xe3\x03\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12\"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\xe5\x02\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18\" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\tB$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x13stop_schedule.proto\x12\x05proto"!\n\tEpsg25829\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01"\xe3\x03\n\x0cStopArrivals\x12\x0f\n\x07stop_id\x18\x01 \x01(\t\x12"\n\x08location\x18\x03 \x01(\x0b\x32\x10.proto.Epsg25829\x12\x36\n\x08\x61rrivals\x18\x05 \x03(\x0b\x32$.proto.StopArrivals.ScheduledArrival\x1a\xe5\x02\n\x10ScheduledArrival\x12\x12\n\nservice_id\x18\x01 \x01(\t\x12\x0f\n\x07trip_id\x18\x02 \x01(\t\x12\x0c\n\x04line\x18\x03 \x01(\t\x12\r\n\x05route\x18\x04 \x01(\t\x12\x10\n\x08shape_id\x18\x05 \x01(\t\x12\x1b\n\x13shape_dist_traveled\x18\x06 \x01(\x01\x12\x15\n\rstop_sequence\x18\x0b \x01(\r\x12\x14\n\x0cnext_streets\x18\x0c \x03(\t\x12\x15\n\rstarting_code\x18\x15 \x01(\t\x12\x15\n\rstarting_name\x18\x16 \x01(\t\x12\x15\n\rstarting_time\x18\x17 \x01(\t\x12\x14\n\x0c\x63\x61lling_time\x18! \x01(\t\x12\x13\n\x0b\x63\x61lling_ssm\x18" \x01(\r\x12\x15\n\rterminus_code\x18) \x01(\t\x12\x15\n\rterminus_name\x18* \x01(\t\x12\x15\n\rterminus_time\x18+ \x01(\tB$\xaa\x02!Costasdev.Busurbano.Backend.Typesb\x06proto3'
+)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'stop_schedule_pb2', _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "stop_schedule_pb2", _globals)
if not _descriptor._USE_C_DESCRIPTORS:
- _globals['DESCRIPTOR']._loaded_options = None
- _globals['DESCRIPTOR']._serialized_options = b'\252\002!Costasdev.Busurbano.Backend.Types'
- _globals['_EPSG25829']._serialized_start=30
- _globals['_EPSG25829']._serialized_end=63
- _globals['_STOPARRIVALS']._serialized_start=66
- _globals['_STOPARRIVALS']._serialized_end=549
- _globals['_STOPARRIVALS_SCHEDULEDARRIVAL']._serialized_start=192
- _globals['_STOPARRIVALS_SCHEDULEDARRIVAL']._serialized_end=549
+ _globals["DESCRIPTOR"]._loaded_options = None
+ _globals[
+ "DESCRIPTOR"
+ ]._serialized_options = b"\252\002!Costasdev.Busurbano.Backend.Types"
+ _globals["_EPSG25829"]._serialized_start = 30
+ _globals["_EPSG25829"]._serialized_end = 63
+ _globals["_STOPARRIVALS"]._serialized_start = 66
+ _globals["_STOPARRIVALS"]._serialized_end = 549
+ _globals["_STOPARRIVALS_SCHEDULEDARRIVAL"]._serialized_start = 192
+ _globals["_STOPARRIVALS_SCHEDULEDARRIVAL"]._serialized_end = 549
# @@protoc_insertion_point(module_scope)
diff --git a/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi b/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi
index aa42cdb..c8d7f36 100644
--- a/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi
+++ b/src/gtfs_perstop_report/src/stop_schedule_pb2.pyi
@@ -12,7 +12,9 @@ class Epsg25829(_message.Message):
Y_FIELD_NUMBER: _ClassVar[int]
x: float
y: float
- def __init__(self, x: _Optional[float] = ..., y: _Optional[float] = ...) -> None: ...
+ def __init__(
+ self, x: _Optional[float] = ..., y: _Optional[float] = ...
+ ) -> None: ...
class StopArrivals(_message.Message):
__slots__ = ()
@@ -50,11 +52,37 @@ class StopArrivals(_message.Message):
terminus_code: str
terminus_name: str
terminus_time: str
- def __init__(self, service_id: _Optional[str] = ..., trip_id: _Optional[str] = ..., line: _Optional[str] = ..., route: _Optional[str] = ..., shape_id: _Optional[str] = ..., shape_dist_traveled: _Optional[float] = ..., stop_sequence: _Optional[int] = ..., next_streets: _Optional[_Iterable[str]] = ..., starting_code: _Optional[str] = ..., starting_name: _Optional[str] = ..., starting_time: _Optional[str] = ..., calling_time: _Optional[str] = ..., calling_ssm: _Optional[int] = ..., terminus_code: _Optional[str] = ..., terminus_name: _Optional[str] = ..., terminus_time: _Optional[str] = ...) -> None: ...
+ def __init__(
+ self,
+ service_id: _Optional[str] = ...,
+ trip_id: _Optional[str] = ...,
+ line: _Optional[str] = ...,
+ route: _Optional[str] = ...,
+ shape_id: _Optional[str] = ...,
+ shape_dist_traveled: _Optional[float] = ...,
+ stop_sequence: _Optional[int] = ...,
+ next_streets: _Optional[_Iterable[str]] = ...,
+ starting_code: _Optional[str] = ...,
+ starting_name: _Optional[str] = ...,
+ starting_time: _Optional[str] = ...,
+ calling_time: _Optional[str] = ...,
+ calling_ssm: _Optional[int] = ...,
+ terminus_code: _Optional[str] = ...,
+ terminus_name: _Optional[str] = ...,
+ terminus_time: _Optional[str] = ...,
+ ) -> None: ...
+
STOP_ID_FIELD_NUMBER: _ClassVar[int]
LOCATION_FIELD_NUMBER: _ClassVar[int]
ARRIVALS_FIELD_NUMBER: _ClassVar[int]
stop_id: str
location: Epsg25829
arrivals: _containers.RepeatedCompositeFieldContainer[StopArrivals.ScheduledArrival]
- def __init__(self, stop_id: _Optional[str] = ..., location: _Optional[_Union[Epsg25829, _Mapping]] = ..., arrivals: _Optional[_Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]]] = ...) -> None: ...
+ def __init__(
+ self,
+ stop_id: _Optional[str] = ...,
+ location: _Optional[_Union[Epsg25829, _Mapping]] = ...,
+ arrivals: _Optional[
+ _Iterable[_Union[StopArrivals.ScheduledArrival, _Mapping]]
+ ] = ...,
+ ) -> None: ...
diff --git a/src/gtfs_perstop_report/src/stop_times.py b/src/gtfs_perstop_report/src/stop_times.py
index f3c3f25..c48f505 100644
--- a/src/gtfs_perstop_report/src/stop_times.py
+++ b/src/gtfs_perstop_report/src/stop_times.py
@@ -1,6 +1,7 @@
"""
Functions for handling GTFS stop_times data.
"""
+
import csv
import os
from src.logger import get_logger
@@ -9,13 +10,25 @@ logger = get_logger("stop_times")
STOP_TIMES_BY_FEED: dict[str, dict[str, list["StopTime"]]] = {}
-STOP_TIMES_BY_REQUEST: dict[tuple[str, frozenset[str]], dict[str, list["StopTime"]]] = {}
+STOP_TIMES_BY_REQUEST: dict[
+ tuple[str, frozenset[str]], dict[str, list["StopTime"]]
+] = {}
+
class StopTime:
"""
Class representing a stop time entry in the GTFS data.
"""
- def __init__(self, trip_id: str, arrival_time: str, departure_time: str, stop_id: str, stop_sequence: int, shape_dist_traveled: float | None):
+
+ def __init__(
+ self,
+ trip_id: str,
+ arrival_time: str,
+ departure_time: str,
+ stop_id: str,
+ stop_sequence: int,
+ shape_dist_traveled: float | None,
+ ):
self.trip_id = trip_id
self.arrival_time = arrival_time
self.departure_time = departure_time
@@ -36,47 +49,63 @@ def _load_stop_times_for_feed(feed_dir: str) -> dict[str, list[StopTime]]:
stops: dict[str, list[StopTime]] = {}
try:
- with open(os.path.join(feed_dir, 'stop_times.txt'), 'r', encoding="utf-8", newline='') as stop_times_file:
+ with open(
+ os.path.join(feed_dir, "stop_times.txt"), "r", encoding="utf-8", newline=""
+ ) as stop_times_file:
reader = csv.DictReader(stop_times_file)
if reader.fieldnames is None:
logger.error("stop_times.txt missing header row.")
STOP_TIMES_BY_FEED[feed_dir] = {}
return STOP_TIMES_BY_FEED[feed_dir]
- required_columns = ['trip_id', 'arrival_time', 'departure_time', 'stop_id', 'stop_sequence']
- missing_columns = [col for col in required_columns if col not in reader.fieldnames]
+ required_columns = [
+ "trip_id",
+ "arrival_time",
+ "departure_time",
+ "stop_id",
+ "stop_sequence",
+ ]
+ missing_columns = [
+ col for col in required_columns if col not in reader.fieldnames
+ ]
if missing_columns:
logger.error(f"Required columns not found in header: {missing_columns}")
STOP_TIMES_BY_FEED[feed_dir] = {}
return STOP_TIMES_BY_FEED[feed_dir]
- has_shape_dist = 'shape_dist_traveled' in reader.fieldnames
+ has_shape_dist = "shape_dist_traveled" in reader.fieldnames
if not has_shape_dist:
- logger.warning("Column 'shape_dist_traveled' not found in stop_times.txt. Distances will be set to None.")
+ logger.warning(
+ "Column 'shape_dist_traveled' not found in stop_times.txt. Distances will be set to None."
+ )
for row in reader:
- trip_id = row['trip_id']
+ trip_id = row["trip_id"]
if trip_id not in stops:
stops[trip_id] = []
dist = None
- if has_shape_dist and row['shape_dist_traveled']:
+ if has_shape_dist and row["shape_dist_traveled"]:
try:
- dist = float(row['shape_dist_traveled'])
+ dist = float(row["shape_dist_traveled"])
except ValueError:
pass
try:
- stops[trip_id].append(StopTime(
- trip_id=trip_id,
- arrival_time=row['arrival_time'],
- departure_time=row['departure_time'],
- stop_id=row['stop_id'],
- stop_sequence=int(row['stop_sequence']),
- shape_dist_traveled=dist
- ))
+ stops[trip_id].append(
+ StopTime(
+ trip_id=trip_id,
+ arrival_time=row["arrival_time"],
+ departure_time=row["departure_time"],
+ stop_id=row["stop_id"],
+ stop_sequence=int(row["stop_sequence"]),
+ shape_dist_traveled=dist,
+ )
+ )
except ValueError as e:
- logger.warning(f"Error parsing stop_sequence for trip {trip_id}: {e}")
+ logger.warning(
+ f"Error parsing stop_sequence for trip {trip_id}: {e}"
+ )
for trip_stop_times in stops.values():
trip_stop_times.sort(key=lambda st: st.stop_sequence)
@@ -89,7 +118,9 @@ def _load_stop_times_for_feed(feed_dir: str) -> dict[str, list[StopTime]]:
return stops
-def get_stops_for_trips(feed_dir: str, trip_ids: list[str]) -> dict[str, list[StopTime]]:
+def get_stops_for_trips(
+ feed_dir: str, trip_ids: list[str]
+) -> dict[str, list[StopTime]]:
"""
Get stops for a list of trip IDs based on the cached 'stop_times.txt' data.
"""
diff --git a/src/gtfs_perstop_report/src/stops.py b/src/gtfs_perstop_report/src/stops.py
index bb54fa4..fb95cf2 100644
--- a/src/gtfs_perstop_report/src/stops.py
+++ b/src/gtfs_perstop_report/src/stops.py
@@ -36,9 +36,7 @@ def get_all_stops_by_code(feed_dir: str) -> Dict[str, Stop]:
all_stops = get_all_stops(feed_dir)
for stop in all_stops.values():
- stop_25829_x, stop_25829_y = transformer.transform(
- stop.stop_lon, stop.stop_lat
- )
+ stop_25829_x, stop_25829_y = transformer.transform(stop.stop_lon, stop.stop_lat)
stop.stop_25829_x = stop_25829_x
stop.stop_25829_y = stop_25829_y
diff --git a/src/gtfs_perstop_report/src/street_name.py b/src/gtfs_perstop_report/src/street_name.py
index ec6b5b6..81d419b 100644
--- a/src/gtfs_perstop_report/src/street_name.py
+++ b/src/gtfs_perstop_report/src/street_name.py
@@ -3,7 +3,8 @@ import re
re_remove_quotation_marks = re.compile(r'[""”]', re.IGNORECASE)
re_anything_before_stopcharacters_with_parentheses = re.compile(
- r'^(.*?)(?:,|\s\s|\s-\s| \d| S\/N|\s\()', re.IGNORECASE)
+ r"^(.*?)(?:,|\s\s|\s-\s| \d| S\/N|\s\()", re.IGNORECASE
+)
NAME_REPLACEMENTS = {
@@ -17,15 +18,13 @@ NAME_REPLACEMENTS = {
" do ": " ",
" da ": " ",
" das ": " ",
- "Riós": "Ríos"
+ "Riós": "Ríos",
}
def get_street_name(original_name: str) -> str:
- original_name = re.sub(re_remove_quotation_marks,
- '', original_name).strip()
- match = re.match(
- re_anything_before_stopcharacters_with_parentheses, original_name)
+ original_name = re.sub(re_remove_quotation_marks, "", original_name).strip()
+ match = re.match(re_anything_before_stopcharacters_with_parentheses, original_name)
if match:
street_name = match.group(1)
else:
@@ -41,9 +40,9 @@ def get_street_name(original_name: str) -> str:
def normalise_stop_name(original_name: str | None) -> str:
if original_name is None:
- return ''
- stop_name = re.sub(re_remove_quotation_marks, '', original_name).strip()
+ return ""
+ stop_name = re.sub(re_remove_quotation_marks, "", original_name).strip()
- stop_name = stop_name.replace(' ', ', ')
+ stop_name = stop_name.replace(" ", ", ")
return stop_name
diff --git a/src/gtfs_perstop_report/src/trips.py b/src/gtfs_perstop_report/src/trips.py
index 0cedd26..0de632a 100644
--- a/src/gtfs_perstop_report/src/trips.py
+++ b/src/gtfs_perstop_report/src/trips.py
@@ -1,16 +1,28 @@
"""
Functions for handling GTFS trip data.
"""
+
import os
from src.logger import get_logger
logger = get_logger("trips")
+
class TripLine:
"""
Class representing a trip line in the GTFS data.
"""
- def __init__(self, route_id: str, service_id: str, trip_id: str, headsign: str, direction_id: int, shape_id: str|None = None, block_id: str|None = None):
+
+ def __init__(
+ self,
+ route_id: str,
+ service_id: str,
+ trip_id: str,
+ headsign: str,
+ direction_id: int,
+ shape_id: str | None = None,
+ block_id: str | None = None,
+ ):
self.route_id = route_id
self.service_id = service_id
self.trip_id = trip_id
@@ -28,15 +40,17 @@ class TripLine:
TRIPS_BY_SERVICE_ID: dict[str, dict[str, list[TripLine]]] = {}
-def get_trips_for_services(feed_dir: str, service_ids: list[str]) -> dict[str, list[TripLine]]:
+def get_trips_for_services(
+ feed_dir: str, service_ids: list[str]
+) -> dict[str, list[TripLine]]:
"""
Get trips for a list of service IDs based on the 'trips.txt' file.
Uses caching to avoid reading and parsing the file multiple times.
-
+
Args:
feed_dir (str): Directory containing the GTFS feed files.
service_ids (list[str]): List of service IDs to find trips for.
-
+
Returns:
dict[str, list[TripLine]]: Dictionary mapping service IDs to lists of trip objects.
"""
@@ -44,52 +58,58 @@ def get_trips_for_services(feed_dir: str, service_ids: list[str]) -> dict[str, l
if feed_dir in TRIPS_BY_SERVICE_ID:
logger.debug(f"Using cached trips data for {feed_dir}")
# Return only the trips for the requested service IDs
- return {service_id: TRIPS_BY_SERVICE_ID[feed_dir].get(service_id, [])
- for service_id in service_ids}
-
+ return {
+ service_id: TRIPS_BY_SERVICE_ID[feed_dir].get(service_id, [])
+ for service_id in service_ids
+ }
+
trips: dict[str, list[TripLine]] = {}
try:
- with open(os.path.join(feed_dir, 'trips.txt'), 'r', encoding="utf-8") as trips_file:
+ with open(
+ os.path.join(feed_dir, "trips.txt"), "r", encoding="utf-8"
+ ) as trips_file:
lines = trips_file.readlines()
if len(lines) <= 1:
logger.warning(
- "trips.txt file is empty or has only header line, not processing.")
+ "trips.txt file is empty or has only header line, not processing."
+ )
return trips
- header = lines[0].strip().split(',')
+ header = lines[0].strip().split(",")
try:
- service_id_index = header.index('service_id')
- trip_id_index = header.index('trip_id')
- route_id_index = header.index('route_id')
- headsign_index = header.index('trip_headsign')
- direction_id_index = header.index('direction_id')
+ service_id_index = header.index("service_id")
+ trip_id_index = header.index("trip_id")
+ route_id_index = header.index("route_id")
+ headsign_index = header.index("trip_headsign")
+ direction_id_index = header.index("direction_id")
except ValueError as e:
logger.error(f"Required column not found in header: {e}")
return trips
# Check if shape_id column exists
shape_id_index = None
- if 'shape_id' in header:
- shape_id_index = header.index('shape_id')
+ if "shape_id" in header:
+ shape_id_index = header.index("shape_id")
else:
logger.warning("shape_id column not found in trips.txt")
# Check if block_id column exists
block_id_index = None
- if 'block_id' in header:
- block_id_index = header.index('block_id')
+ if "block_id" in header:
+ block_id_index = header.index("block_id")
else:
logger.info("block_id column not found in trips.txt")
# Initialize cache for this feed directory
TRIPS_BY_SERVICE_ID[feed_dir] = {}
-
+
for line in lines[1:]:
- parts = line.strip().split(',')
+ parts = line.strip().split(",")
if len(parts) < len(header):
logger.warning(
- f"Skipping malformed line in trips.txt: {line.strip()}")
+ f"Skipping malformed line in trips.txt: {line.strip()}"
+ )
continue
service_id = parts[service_id_index]
@@ -115,19 +135,20 @@ def get_trips_for_services(feed_dir: str, service_ids: list[str]) -> dict[str, l
trip_id=trip_id,
headsign=parts[headsign_index],
direction_id=int(
- parts[direction_id_index] if parts[direction_id_index] else -1),
+ parts[direction_id_index] if parts[direction_id_index] else -1
+ ),
shape_id=shape_id,
- block_id=block_id
+ block_id=block_id,
)
-
+
TRIPS_BY_SERVICE_ID[feed_dir][service_id].append(trip_line)
-
+
# Also build the result for the requested service IDs
if service_id in service_ids:
if service_id not in trips:
trips[service_id] = []
trips[service_id].append(trip_line)
-
+
except FileNotFoundError:
logger.warning("trips.txt file not found.")
diff --git a/src/gtfs_perstop_report/stop_report.py b/src/gtfs_perstop_report/stop_report.py
index f8fdc64..3bbdf11 100644
--- a/src/gtfs_perstop_report/stop_report.py
+++ b/src/gtfs_perstop_report/stop_report.py
@@ -32,8 +32,7 @@ def parse_args():
default="./output/",
help="Directory to write reports to (default: ./output/)",
)
- parser.add_argument("--feed-dir", type=str,
- help="Path to the feed directory")
+ parser.add_argument("--feed-dir", type=str, help="Path to the feed directory")
parser.add_argument(
"--feed-url",
type=str,
@@ -244,12 +243,9 @@ def build_trip_previous_shape_map(
if shift_key not in trips_by_shift:
trips_by_shift[shift_key] = []
- trips_by_shift[shift_key].append((
- trip,
- trip_number,
- first_stop.stop_id,
- last_stop.stop_id
- ))
+ trips_by_shift[shift_key].append(
+ (trip, trip_number, first_stop.stop_id, last_stop.stop_id)
+ )
# For each shift, sort trips by trip number and link consecutive trips
for shift_key, shift_trips in trips_by_shift.items():
# Sort by trip number
@@ -262,16 +258,20 @@ def build_trip_previous_shape_map(
# Check if trips are consecutive (trip numbers differ by 1),
# if previous trip's terminus matches current trip's start,
# and if both trips have valid shape IDs
- if (current_num == prev_num + 1 and
- prev_end_stop == current_start_stop and
- prev_trip.shape_id and
- current_trip.shape_id):
+ if (
+ current_num == prev_num + 1
+ and prev_end_stop == current_start_stop
+ and prev_trip.shape_id
+ and current_trip.shape_id
+ ):
trip_previous_shape[current_trip.trip_id] = prev_trip.shape_id
return trip_previous_shape
-def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict[str, Any]]]:
+def get_stop_arrivals(
+ feed_dir: str, date: str, provider
+) -> Dict[str, List[Dict[str, Any]]]:
"""
Process trips for the given date and organize stop arrivals.
Also includes night services from the previous day (times >= 24:00:00).
@@ -293,15 +293,16 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict
if not active_services:
logger.info("No active services found for the given date.")
- logger.info(
- f"Found {len(active_services)} active services for date {date}.")
+ logger.info(f"Found {len(active_services)} active services for date {date}.")
# Also get services from the previous day to include night services (times >= 24:00)
- prev_date = (datetime.strptime(date, "%Y-%m-%d") -
- timedelta(days=1)).strftime("%Y-%m-%d")
+ prev_date = (datetime.strptime(date, "%Y-%m-%d") - timedelta(days=1)).strftime(
+ "%Y-%m-%d"
+ )
prev_services = get_active_services(feed_dir, prev_date)
logger.info(
- f"Found {len(prev_services)} active services for previous date {prev_date} (for night services).")
+ f"Found {len(prev_services)} active services for previous date {prev_date} (for night services)."
+ )
all_services = list(set(active_services + prev_services))
@@ -314,18 +315,17 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict
logger.info(f"Found {total_trip_count} trips for active services.")
# Get all trip IDs
- all_trip_ids = [trip.trip_id for trip_list in trips.values()
- for trip in trip_list]
+ all_trip_ids = [trip.trip_id for trip_list in trips.values() for trip in trip_list]
# Get stops for all trips
stops_for_all_trips = get_stops_for_trips(feed_dir, all_trip_ids)
logger.info(f"Precomputed stops for {len(stops_for_all_trips)} trips.")
# Build mapping from trip_id to previous trip's shape_id
- trip_previous_shape_map = build_trip_previous_shape_map(
- trips, stops_for_all_trips)
+ trip_previous_shape_map = build_trip_previous_shape_map(trips, stops_for_all_trips)
logger.info(
- f"Built previous trip shape mapping for {len(trip_previous_shape_map)} trips.")
+ f"Built previous trip shape mapping for {len(trip_previous_shape_map)} trips."
+ )
# Load routes information
routes = load_routes(feed_dir)
@@ -389,8 +389,7 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict
stop_to_segment_idx.append(len(segment_names) - 1)
# Precompute future street transitions per segment
- future_suffix_by_segment: list[tuple[str, ...]] = [
- ()] * len(segment_names)
+ future_suffix_by_segment: list[tuple[str, ...]] = [()] * len(segment_names)
future_tuple: tuple[str, ...] = ()
for idx in range(len(segment_names) - 1, -1, -1):
future_suffix_by_segment[idx] = future_tuple
@@ -437,7 +436,7 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict
passes.append("previous")
for mode in passes:
- is_current_mode = (mode == "current")
+ is_current_mode = mode == "current"
for i, (stop_time, _) in enumerate(trip_stop_pairs):
# Skip the last stop of the trip (terminus) to avoid duplication
@@ -457,11 +456,9 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict
continue
# Normalize times for display on current day (e.g. 25:30 -> 01:30)
- final_starting_time = normalize_gtfs_time(
- starting_time)
+ final_starting_time = normalize_gtfs_time(starting_time)
final_calling_time = normalize_gtfs_time(dep_time)
- final_terminus_time = normalize_gtfs_time(
- terminus_time)
+ final_terminus_time = normalize_gtfs_time(terminus_time)
# SSM should be small (early morning)
final_calling_ssm = time_to_seconds(final_calling_time)
else:
@@ -489,12 +486,10 @@ def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict
# Format IDs and route using provider-specific logic
service_id_fmt = provider.format_service_id(service_id)
trip_id_fmt = provider.format_trip_id(trip_id)
- route_fmt = provider.format_route(
- trip_headsign, terminus_name)
+ route_fmt = provider.format_route(trip_headsign, terminus_name)
# Get previous trip shape_id if available
- previous_trip_shape_id = trip_previous_shape_map.get(
- trip_id, "")
+ previous_trip_shape_id = trip_previous_shape_map.get(trip_id, "")
stop_arrivals[stop_code].append(
{
@@ -616,8 +611,7 @@ def main():
feed_dir = args.feed_dir
else:
logger.info(f"Downloading GTFS feed from {feed_url}...")
- feed_dir = download_feed_from_url(
- feed_url, output_dir, args.force_download)
+ feed_dir = download_feed_from_url(feed_url, output_dir, args.force_download)
if feed_dir is None:
logger.info("Download was skipped (feed not modified). Exiting.")
return
@@ -642,8 +636,7 @@ def main():
_, stop_summary = process_date(feed_dir, date, output_dir, provider)
all_stops_summary[date] = stop_summary
- logger.info(
- "Finished processing all dates. Beginning with shape transformation.")
+ logger.info("Finished processing all dates. Beginning with shape transformation.")
# Process shapes, converting each coordinate to EPSG:25829 and saving as Protobuf
process_shapes(feed_dir, output_dir)