path: root/src/gtfs_perstop_report/stop_report.py
author Ariel Costas Guerrero <ariel@costas.dev> 2025-12-08 01:37:10 +0100
committer Ariel Costas Guerrero <ariel@costas.dev> 2025-12-08 01:37:10 +0100
commit 107295575e3a7c37911ae192baf426b0003975a4 (patch)
tree ceb528a428716d5313517a0fa72fcac3ea1360fb /src/gtfs_perstop_report/stop_report.py
parent 3b3fd2f6880eaa9170b480d41d43311925483bea (diff)
Refactor code structure for improved readability and maintainability; removed redundant code blocks and optimized functions.
Diffstat (limited to 'src/gtfs_perstop_report/stop_report.py')
-rw-r--r-- src/gtfs_perstop_report/stop_report.py 666
1 files changed, 666 insertions, 0 deletions
diff --git a/src/gtfs_perstop_report/stop_report.py b/src/gtfs_perstop_report/stop_report.py
new file mode 100644
index 0000000..f8fdc64
--- /dev/null
+++ b/src/gtfs_perstop_report/stop_report.py
@@ -0,0 +1,666 @@
+import argparse
+import os
+import shutil
+import sys
+import time
+import traceback
+from typing import Any, Dict, List, Optional, Tuple
+
+from src.shapes import process_shapes
+from src.common import get_all_feed_dates
+from src.download import download_feed_from_url
+from src.logger import get_logger
+from src.report_writer import write_stop_json, write_stop_protobuf
+from src.routes import load_routes
+from src.services import get_active_services
+from src.stop_times import get_stops_for_trips, StopTime
+from src.stops import get_all_stops, get_all_stops_by_code, get_numeric_code
+from src.street_name import normalise_stop_name
+from src.trips import get_trips_for_services, TripLine
+from src.providers import get_provider
+
+logger = get_logger("stop_report")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Generate stop-based JSON reports for a date or date range."
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="./output/",
+ help="Directory to write reports to (default: ./output/)",
+ )
+ parser.add_argument("--feed-dir", type=str,
+ help="Path to the feed directory")
+ parser.add_argument(
+ "--feed-url",
+ type=str,
+ help="URL to download the GTFS feed from (if not using local feed directory)",
+ )
+ parser.add_argument(
+ "--force-download",
+ action="store_true",
+ help="Force download even if the feed hasn't been modified (only applies when using --feed-url)",
+ )
+ parser.add_argument(
+ "--provider",
+ type=str,
+ default="default",
+ help="Feed provider type (vitrasa, renfe, default). Default: default",
+ )
+ args = parser.parse_args()
+
+ if args.feed_dir and args.feed_url:
+ parser.error("Specify either --feed-dir or --feed-url, not both.")
+ if not args.feed_dir and not args.feed_url:
+ parser.error(
+ "You must specify either a path to the existing feed (unzipped) or a URL to download the GTFS feed from."
+ )
+ if args.feed_dir and not os.path.exists(args.feed_dir):
+ parser.error(f"Feed directory does not exist: {args.feed_dir}")
+ return args
+
+
+def time_to_seconds(time_str: str) -> int:
+ """
+ Convert HH:MM:SS to seconds since midnight.
+ Handles GTFS times that can exceed 24 hours (e.g., 25:30:00 for 1:30 AM next day).
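+
+    Example (illustrative): "25:30:00" -> 25 * 3600 + 30 * 60 = 91800.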
+ """
+ if not time_str:
+ return 0
+
+ parts = time_str.split(":")
+ if len(parts) != 3:
+ return 0
+
+ try:
+ hours, minutes, seconds = map(int, parts)
+ return hours * 3600 + minutes * 60 + seconds
+ except ValueError:
+ return 0
+
+
+def normalize_gtfs_time(time_str: str) -> str:
+ """
+ Normalize GTFS time format to standard HH:MM:SS (0-23 hours).
+ Converts times like 25:30:00 to 01:30:00.
+
+ Args:
+ time_str: Time in HH:MM:SS format, possibly with hours >= 24
+
+ Returns:
+ Normalized time string in HH:MM:SS format
+ """
+ if not time_str:
+ return time_str
+
+ parts = time_str.split(":")
+ if len(parts) != 3:
+ return time_str
+
+ try:
+ hours, minutes, seconds = map(int, parts)
+ normalized_hours = hours % 24
+ return f"{normalized_hours:02d}:{minutes:02d}:{seconds:02d}"
+ except ValueError:
+ return time_str
+
+
+def format_gtfs_time(time_str: str) -> str:
+ """
+ Format GTFS time to HH:MM:SS, preserving hours >= 24.
+ """
+ if not time_str:
+ return time_str
+
+ parts = time_str.split(":")
+ if len(parts) != 3:
+ return time_str
+
+ try:
+ hours, minutes, seconds = map(int, parts)
+ return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
+ except ValueError:
+ return time_str
+
+
+def is_next_day_service(time_str: str) -> bool:
+ """
+ Check if a GTFS time represents a service on the next day (hours >= 24).
+
+ Args:
+ time_str: Time in HH:MM:SS format
+
+ Returns:
+ True if the time is >= 24:00:00, False otherwise
+ """
+ if not time_str:
+ return False
+
+ parts = time_str.split(":")
+ if len(parts) != 3:
+ return False
+
+ try:
+ hours = int(parts[0])
+ return hours >= 24
+ except ValueError:
+ return False
+
+
+def parse_trip_id_components(trip_id: str) -> Optional[Tuple[str, str, int]]:
+ """
+ Parse a trip ID in format XXXYYY-Z or XXXYYY_Z where:
+ - XXX = line number (e.g., 003)
+ - YYY = shift/internal ID (e.g., 001)
+ - Z = trip number (e.g., 12)
+
+ Supported formats:
+ 1. ..._XXXYYY_Z (e.g. "C1 01SNA00_001001_18")
+ 2. ..._XXXYYY-Z (e.g. "VIGO_20241122_003001-12")
+
+ Returns tuple of (line, shift_id, trip_number) or None if parsing fails.
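+
+    Illustrative results for the examples above:
+        "C1 01SNA00_001001_18"    -> ("001", "001", 18)
+        "VIGO_20241122_003001-12" -> ("003", "001", 12)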
+ """
+ try:
+ parts = trip_id.split("_")
+ if len(parts) < 2:
+ return None
+
+        # Try format 1: ..._XXXYYY_Z
+        # The second-to-last part must be 6 digits (XXXYYY) and the last part numeric
+        shift_part = parts[-2]
+        trip_num_str = parts[-1]
+        if len(shift_part) == 6 and shift_part.isdigit() and trip_num_str.isdigit():
+            line = shift_part[:3]
+            shift_id = shift_part[3:6]
+            trip_number = int(trip_num_str)
+            return (line, shift_id, trip_number)
+
+ # Try format 2: ..._XXXYYY-Z
+ # The trip ID is the last part in format XXXYYY-Z
+ trip_part = parts[-1]
+
+ if "-" in trip_part:
+ shift_part, trip_num_str = trip_part.split("-", 1)
+
+ # shift_part should be 6 digits: XXXYYY
+ if len(shift_part) == 6 and shift_part.isdigit():
+ line = shift_part[:3] # First 3 digits
+ shift_id = shift_part[3:6] # Next 3 digits
+ trip_number = int(trip_num_str)
+ return (line, shift_id, trip_number)
+
+ return None
+ except (ValueError, IndexError):
+ return None
+
+
+def build_trip_previous_shape_map(
+ trips: Dict[str, List[TripLine]],
+ stops_for_all_trips: Dict[str, List[StopTime]],
+) -> Dict[str, Optional[str]]:
+ """
+ Build a mapping from trip_id to previous_trip_shape_id.
+
+    Links trips based on the trip ID structure (XXXYYY-Z): trips sharing the same
+    XXX (line) and YYY (shift) with consecutive Z (trip numbers) are connected
+    when the terminus of trip N matches the start of trip N+1.
+
+ Args:
+ trips: Dictionary of service_id -> list of trips
+ stops_for_all_trips: Dictionary of trip_id -> list of stop times
+
+ Returns:
+ Dictionary mapping trip_id to previous_trip_shape_id (or None)
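+
+    For example (illustrative): trips "003001-12" and "003001-13" on the same
+    shift are linked when trip 12 terminates at the stop where trip 13 starts;
+    trip 13's entry then points at trip 12's shape_id.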
+ """
+ trip_previous_shape: Dict[str, Optional[str]] = {}
+
+ # Collect all trips across all services
+ all_trips_list: List[TripLine] = []
+ for trip_list in trips.values():
+ all_trips_list.extend(trip_list)
+
+ # Group trips by shift ID (line + shift combination)
+ trips_by_shift: Dict[str, List[Tuple[TripLine, int, str, str]]] = {}
+
+ for trip in all_trips_list:
+ parsed = parse_trip_id_components(trip.trip_id)
+ if not parsed:
+ continue
+
+ line, shift_id, trip_number = parsed
+ shift_key = f"{line}{shift_id}"
+
+ trip_stops = stops_for_all_trips.get(trip.trip_id)
+ if not trip_stops or len(trip_stops) < 2:
+ continue
+
+ first_stop = trip_stops[0]
+ last_stop = trip_stops[-1]
+
+ if shift_key not in trips_by_shift:
+ trips_by_shift[shift_key] = []
+
+ trips_by_shift[shift_key].append((
+ trip,
+ trip_number,
+ first_stop.stop_id,
+ last_stop.stop_id
+ ))
+ # For each shift, sort trips by trip number and link consecutive trips
+ for shift_key, shift_trips in trips_by_shift.items():
+ # Sort by trip number
+ shift_trips.sort(key=lambda x: x[1])
+ # Link consecutive trips if their stops match
+ for i in range(1, len(shift_trips)):
+ current_trip, current_num, current_start_stop, _ = shift_trips[i]
+ prev_trip, prev_num, _, prev_end_stop = shift_trips[i - 1]
+
+ # Check if trips are consecutive (trip numbers differ by 1),
+ # if previous trip's terminus matches current trip's start,
+ # and if both trips have valid shape IDs
+ if (current_num == prev_num + 1 and
+ prev_end_stop == current_start_stop and
+ prev_trip.shape_id and
+ current_trip.shape_id):
+ trip_previous_shape[current_trip.trip_id] = prev_trip.shape_id
+
+ return trip_previous_shape
+
+
+def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict[str, Any]]]:
+ """
+ Process trips for the given date and organize stop arrivals.
+ Also includes night services from the previous day (times >= 24:00:00).
+
+ Args:
+ feed_dir: Path to the GTFS feed directory
+ date: Date in YYYY-MM-DD format
+ provider: Provider class with feed-specific formatting methods
+
+ Returns:
+ Dictionary mapping stop_code to lists of arrival information.
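+
+        For example (illustrative), one entry might look like:
+            {"1234": [{"line": "C1", "calling_time": "08:15:00", "calling_ssm": 29700, ...}]}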
+ """
+ from datetime import datetime, timedelta
+
+ stops = get_all_stops(feed_dir)
+ logger.info(f"Found {len(stops)} stops in the feed.")
+
+ active_services = get_active_services(feed_dir, date)
+ if not active_services:
+ logger.info("No active services found for the given date.")
+
+ logger.info(
+ f"Found {len(active_services)} active services for date {date}.")
+
+ # Also get services from the previous day to include night services (times >= 24:00)
+ prev_date = (datetime.strptime(date, "%Y-%m-%d") -
+ timedelta(days=1)).strftime("%Y-%m-%d")
+ prev_services = get_active_services(feed_dir, prev_date)
+ logger.info(
+ f"Found {len(prev_services)} active services for previous date {prev_date} (for night services).")
+
+ all_services = list(set(active_services + prev_services))
+
+ if not all_services:
+ logger.info("No active services found for current or previous date.")
+ return {}
+
+ trips = get_trips_for_services(feed_dir, all_services)
+ total_trip_count = sum(len(trip_list) for trip_list in trips.values())
+ logger.info(f"Found {total_trip_count} trips for active services.")
+
+ # Get all trip IDs
+ all_trip_ids = [trip.trip_id for trip_list in trips.values()
+ for trip in trip_list]
+
+ # Get stops for all trips
+ stops_for_all_trips = get_stops_for_trips(feed_dir, all_trip_ids)
+ logger.info(f"Precomputed stops for {len(stops_for_all_trips)} trips.")
+
+ # Build mapping from trip_id to previous trip's shape_id
+ trip_previous_shape_map = build_trip_previous_shape_map(
+ trips, stops_for_all_trips)
+ logger.info(
+ f"Built previous trip shape mapping for {len(trip_previous_shape_map)} trips.")
+
+ # Load routes information
+ routes = load_routes(feed_dir)
+ logger.info(f"Loaded {len(routes)} routes from feed.")
+
+ # Create a reverse lookup from stop_id to stop_code (or stop_id as fallback)
+ stop_id_to_code = {}
+ for stop_id, stop in stops.items():
+ if stop.stop_code:
+ stop_id_to_code[stop_id] = get_numeric_code(stop.stop_code)
+ else:
+ # Fallback to stop_id if stop_code is not available (e.g., train stations)
+ stop_id_to_code[stop_id] = stop_id
+
+ # Organize data by stop_code
+ stop_arrivals = {}
+
+ active_services_set = set(active_services)
+ prev_services_set = set(prev_services)
+
+ for service_id, trip_list in trips.items():
+ is_active = service_id in active_services_set
+ is_prev = service_id in prev_services_set
+
+ if not is_active and not is_prev:
+ continue
+
+ for trip in trip_list:
+ # Get route information once per trip
+ route_info = routes.get(trip.route_id, {})
+ route_short_name = route_info.get("route_short_name", "")
+ trip_headsign = getattr(trip, "headsign", "") or ""
+ trip_id = trip.trip_id
+
+ # Get stop times for this trip
+ trip_stops = stops_for_all_trips.get(trip.trip_id, [])
+ if not trip_stops:
+ continue
+
+ # Pair stop_times with stop metadata once to avoid repeated lookups
+ trip_stop_pairs = []
+ stop_names = []
+ for stop_time in trip_stops:
+ stop = stops.get(stop_time.stop_id)
+ trip_stop_pairs.append((stop_time, stop))
+ stop_names.append(stop.stop_name if stop else "Unknown Stop")
+
+ # Memoize street names per stop name for this trip and build segments
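+            # Consecutive stops that resolve to the same street collapse into a
+            # single segment; stop_to_segment_idx maps each stop to its segment index.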
+ street_cache: dict[str, str] = {}
+ segment_names: list[str] = []
+ stop_to_segment_idx: list[int] = []
+ previous_street: str | None = None
+ for name in stop_names:
+ street = street_cache.get(name)
+ if street is None:
+ street = provider.extract_street_name(name)
+ street_cache[name] = street
+ if street != previous_street:
+ segment_names.append(street)
+ previous_street = street
+ stop_to_segment_idx.append(len(segment_names) - 1)
+
+ # Precompute future street transitions per segment
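+            # future_suffix_by_segment[i] collects the street names of the segments
+            # that come after segment i (the segment's own street is excluded).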
+ future_suffix_by_segment: list[tuple[str, ...]] = [
+ ()] * len(segment_names)
+ future_tuple: tuple[str, ...] = ()
+ for idx in range(len(segment_names) - 1, -1, -1):
+ future_suffix_by_segment[idx] = future_tuple
+ current_street = segment_names[idx]
+ future_tuple = (
+ (current_street,) + future_tuple
+ if current_street is not None
+ else future_tuple
+ )
+
+ segment_future_lists: dict[int, list[str]] = {}
+
+ first_stop_time, first_stop = trip_stop_pairs[0]
+ last_stop_time, last_stop = trip_stop_pairs[-1]
+
+ starting_stop_name = first_stop.stop_name if first_stop else "Unknown Stop"
+ terminus_stop_name = last_stop.stop_name if last_stop else "Unknown Stop"
+
+ # Get stop codes with fallback to stop_id if stop_code is empty
+ if first_stop:
+ starting_code = get_numeric_code(first_stop.stop_code)
+ if not starting_code:
+ starting_code = first_stop_time.stop_id
+ else:
+ starting_code = ""
+
+ if last_stop:
+ terminus_code = get_numeric_code(last_stop.stop_code)
+ if not terminus_code:
+ terminus_code = last_stop_time.stop_id
+ else:
+ terminus_code = ""
+
+ starting_name = normalise_stop_name(starting_stop_name)
+ terminus_name = normalise_stop_name(terminus_stop_name)
+ starting_time = first_stop_time.departure_time
+ terminus_time = last_stop_time.arrival_time
+
+ # Determine processing passes for this trip
+ passes = []
+ if is_active:
+ passes.append("current")
+ if is_prev:
+ passes.append("previous")
+
+ for mode in passes:
+ is_current_mode = (mode == "current")
+
+ for i, (stop_time, _) in enumerate(trip_stop_pairs):
+ # Skip the last stop of the trip (terminus) to avoid duplication
+ if i == len(trip_stop_pairs) - 1:
+ continue
+
+ stop_code = stop_id_to_code.get(stop_time.stop_id)
+
+ if not stop_code:
+ continue # Skip stops without a code
+
+ dep_time = stop_time.departure_time
+
+ if not is_current_mode:
+ # Previous day service: only include if calling_time >= 24:00:00 (night services rolling to this day)
+ if not is_next_day_service(dep_time):
+ continue
+
+ # Normalize times for display on current day (e.g. 25:30 -> 01:30)
+ final_starting_time = normalize_gtfs_time(
+ starting_time)
+ final_calling_time = normalize_gtfs_time(dep_time)
+ final_terminus_time = normalize_gtfs_time(
+ terminus_time)
+ # SSM should be small (early morning)
+ final_calling_ssm = time_to_seconds(final_calling_time)
+ else:
+ # Current day service: include ALL times
+ # Keep times as is (e.g. 25:30 stays 25:30)
+ final_starting_time = format_gtfs_time(starting_time)
+ final_calling_time = format_gtfs_time(dep_time)
+ final_terminus_time = format_gtfs_time(terminus_time)
+ # SSM should be large if > 24:00
+ final_calling_ssm = time_to_seconds(dep_time)
+
+ if stop_code not in stop_arrivals:
+ stop_arrivals[stop_code] = []
+
+ if segment_names:
+ segment_idx = stop_to_segment_idx[i]
+ if segment_idx not in segment_future_lists:
+ segment_future_lists[segment_idx] = list(
+ future_suffix_by_segment[segment_idx]
+ )
+ next_streets = segment_future_lists[segment_idx].copy()
+ else:
+ next_streets = []
+
+ # Format IDs and route using provider-specific logic
+ service_id_fmt = provider.format_service_id(service_id)
+ trip_id_fmt = provider.format_trip_id(trip_id)
+ route_fmt = provider.format_route(
+ trip_headsign, terminus_name)
+
+ # Get previous trip shape_id if available
+ previous_trip_shape_id = trip_previous_shape_map.get(
+ trip_id, "")
+
+ stop_arrivals[stop_code].append(
+ {
+ "service_id": service_id_fmt,
+ "trip_id": trip_id_fmt,
+ "line": route_short_name,
+ "route": route_fmt,
+ "shape_id": getattr(trip, "shape_id", ""),
+ "stop_sequence": stop_time.stop_sequence,
+ "shape_dist_traveled": getattr(
+ stop_time, "shape_dist_traveled", 0
+ ),
+ "next_streets": [s for s in next_streets if s != ""],
+ "starting_code": starting_code,
+ "starting_name": starting_name,
+ "starting_time": final_starting_time,
+ "calling_time": final_calling_time,
+ "calling_ssm": final_calling_ssm,
+ "terminus_code": terminus_code,
+ "terminus_name": terminus_name,
+ "terminus_time": final_terminus_time,
+ "previous_trip_shape_id": previous_trip_shape_id,
+ }
+ )
+
+    # Sort each stop's arrivals by calling time
+    for stop_code in stop_arrivals:
+        # Filter out entries with no calling_ssm value
+ stop_arrivals[stop_code] = [
+ item for item in stop_arrivals[stop_code] if item["calling_ssm"] is not None
+ ]
+ stop_arrivals[stop_code].sort(key=lambda x: x["calling_ssm"])
+
+ return stop_arrivals
+
+
+def process_date(
+ feed_dir: str, date: str, output_dir: str, provider
+) -> tuple[str, Dict[str, int]]:
+ """
+ Process a single date and write its stop JSON files.
+ Returns summary data for index generation.
+ """
+ logger = get_logger(f"stop_report_{date}")
+ try:
+ logger.info(f"Starting stop report generation for date {date}")
+
+ stops_by_code = get_all_stops_by_code(feed_dir)
+
+ # Get all stop arrivals for the current date
+ stop_arrivals = get_stop_arrivals(feed_dir, date, provider)
+
+ if not stop_arrivals:
+ logger.warning(f"No stop arrivals found for date {date}")
+ return date, {}
+
+ logger.info(
+ f"Writing stop reports for {len(stop_arrivals)} stops for date {date}"
+ )
+
+ # Write individual stop JSON files
+ writing_start_time = time.perf_counter()
+ for stop_code, arrivals in stop_arrivals.items():
+ write_stop_json(output_dir, date, stop_code, arrivals)
+ writing_end_time = time.perf_counter()
+ writing_elapsed = writing_end_time - writing_start_time
+
+ logger.info(
+ f"Finished writing stop JSON reports for date {date} in {writing_elapsed:.2f}s"
+ )
+
+        # Write individual stop protobuf files
+ writing_start_time = time.perf_counter()
+ for stop_code, arrivals in stop_arrivals.items():
+ stop_by_code = stops_by_code.get(stop_code)
+
+ if stop_by_code is not None:
+ write_stop_protobuf(
+ output_dir,
+ date,
+ stop_code,
+ arrivals,
+ stop_by_code.stop_25829_x or 0.0,
+ stop_by_code.stop_25829_y or 0.0,
+ )
+
+ writing_end_time = time.perf_counter()
+ writing_elapsed = writing_end_time - writing_start_time
+
+ logger.info(
+ f"Finished writing stop protobuf reports for date {date} in {writing_elapsed:.2f}s"
+ )
+
+ logger.info(f"Processed {len(stop_arrivals)} stops for date {date}")
+
+ stop_summary = {
+ stop_code: len(arrivals) for stop_code, arrivals in stop_arrivals.items()
+ }
+ return date, stop_summary
+ except Exception as e:
+ logger.error(f"Error processing date {date}: {e}")
+ raise
+
+
+def main():
+ args = parse_args()
+ output_dir = args.output_dir
+ feed_url = args.feed_url
+
+ # Get provider configuration
+ try:
+ provider = get_provider(args.provider)
+ logger.info(f"Using provider: {args.provider}")
+ except ValueError as e:
+ logger.error(str(e))
+ sys.exit(1)
+
+ if not feed_url:
+ feed_dir = args.feed_dir
+ else:
+ logger.info(f"Downloading GTFS feed from {feed_url}...")
+ feed_dir = download_feed_from_url(
+ feed_url, output_dir, args.force_download)
+ if feed_dir is None:
+ logger.info("Download was skipped (feed not modified). Exiting.")
+ return
+
+    date_list = get_all_feed_dates(feed_dir)
+    if not date_list:
+        logger.error("No valid dates found in feed.")
+        return
+
+ logger.info(f"Processing {len(date_list)} dates")
+
+ # Dictionary to store summary data for index files
+ all_stops_summary = {}
+
+ for date in date_list:
+ _, stop_summary = process_date(feed_dir, date, output_dir, provider)
+ all_stops_summary[date] = stop_summary
+
+ logger.info(
+ "Finished processing all dates. Beginning with shape transformation.")
+
+ # Process shapes, converting each coordinate to EPSG:25829 and saving as Protobuf
+ process_shapes(feed_dir, output_dir)
+
+ logger.info("Finished processing shapes.")
+
+ if feed_url:
+ if os.path.exists(feed_dir):
+ shutil.rmtree(feed_dir)
+ logger.info(f"Removed temporary feed directory: {feed_dir}")
+
+
+if __name__ == "__main__":
+ try:
+ main()
+ except Exception as e:
+ logger = get_logger("stop_report")
+ logger.critical(f"An unexpected error occurred: {e}", exc_info=True)
+ traceback.print_exc()
+ sys.exit(1)