diff options
Diffstat (limited to 'src/gtfs_vigo_stops/stop_report.py')
| -rw-r--r-- | src/gtfs_vigo_stops/stop_report.py | 158 |
1 files changed, 90 insertions, 68 deletions
diff --git a/src/gtfs_vigo_stops/stop_report.py b/src/gtfs_vigo_stops/stop_report.py index 880eaf7..7db751c 100644 --- a/src/gtfs_vigo_stops/stop_report.py +++ b/src/gtfs_vigo_stops/stop_report.py @@ -1,54 +1,66 @@ +import argparse import os import shutil import sys import traceback -import argparse -from typing import List, Dict, Any -from multiprocessing import Pool, cpu_count +from typing import Any, Dict, List +from src.common import get_all_feed_dates from src.download import download_feed_from_url from src.logger import get_logger -from src.common import get_all_feed_dates, time_to_seconds -from src.stops import get_all_stops +from src.report_writer import write_stop_json, write_stop_protobuf +from src.routes import load_routes from src.services import get_active_services +from src.stop_times import get_stops_for_trips +from src.stops import get_all_stops, get_all_stops_by_code, get_numeric_code from src.street_name import get_street_name, normalise_stop_name from src.trips import get_trips_for_services -from src.stop_times import get_stops_for_trips -from src.routes import load_routes -from src.report_writer import write_stop_json logger = get_logger("stop_report") def parse_args(): parser = argparse.ArgumentParser( - description="Generate stop-based JSON reports for a date or date range.") - parser.add_argument('--output-dir', type=str, default="./output/", - help='Directory to write reports to (default: ./output/)') - parser.add_argument('--feed-dir', type=str, - help="Path to the feed directory") - parser.add_argument('--feed-url', type=str, - help="URL to download the GTFS feed from (if not using local feed directory)") - parser.add_argument('--force-download', action='store_true', - help="Force download even if the feed hasn't been modified (only applies when using --feed-url)") + description="Generate stop-based JSON reports for a date or date range." + ) + parser.add_argument( + "--output-dir", + type=str, + default="./output/", + help="Directory to write reports to (default: ./output/)", + ) + parser.add_argument("--feed-dir", type=str, help="Path to the feed directory") + parser.add_argument( + "--feed-url", + type=str, + help="URL to download the GTFS feed from (if not using local feed directory)", + ) + parser.add_argument( + "--force-download", + action="store_true", + help="Force download even if the feed hasn't been modified (only applies when using --feed-url)", + ) args = parser.parse_args() if args.feed_dir and args.feed_url: parser.error("Specify either --feed-dir or --feed-url, not both.") if not args.feed_dir and not args.feed_url: parser.error( - "You must specify either a path to the existing feed (unzipped) or a URL to download the GTFS feed from.") + "You must specify either a path to the existing feed (unzipped) or a URL to download the GTFS feed from." + ) if args.feed_dir and not os.path.exists(args.feed_dir): parser.error(f"Feed directory does not exist: {args.feed_dir}") return args def time_to_seconds(time_str: str) -> int: - """Convert HH:MM:SS to seconds since midnight.""" + """ + Convert HH:MM:SS to seconds since midnight. + """ if not time_str: return 0 - parts = time_str.split(':') + parts = time_str.split(":") if len(parts) != 3: return 0 @@ -59,17 +71,7 @@ def time_to_seconds(time_str: str) -> int: return 0 -def get_numeric_code(stop_code: str | None) -> str: - if not stop_code: - return "" - numeric_code = ''.join(c for c in stop_code if c.isdigit()) - return str(int(numeric_code)) if numeric_code else "" - - -def get_stop_arrivals( - feed_dir: str, - date: str -) -> Dict[str, List[Dict[str, Any]]]: +def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]]]: """ Process trips for the given date and organize stop arrivals. @@ -89,16 +91,14 @@ def get_stop_arrivals( logger.info("No active services found for the given date.") return {} - logger.info( - f"Found {len(active_services)} active services for date {date}.") + logger.info(f"Found {len(active_services)} active services for date {date}.") trips = get_trips_for_services(feed_dir, active_services) total_trip_count = sum(len(trip_list) for trip_list in trips.values()) logger.info(f"Found {total_trip_count} trips for active services.") # Get all trip IDs - all_trip_ids = [trip.trip_id for trip_list in trips.values() - for trip in trip_list] + all_trip_ids = [trip.trip_id for trip_list in trips.values() for trip in trip_list] # Get stops for all trips stops_for_all_trips = get_stops_for_trips(feed_dir, all_trip_ids) @@ -121,8 +121,8 @@ def get_stop_arrivals( for trip in trip_list: # Get route information once per trip route_info = routes.get(trip.route_id, {}) - route_short_name = route_info.get('route_short_name', '') - trip_headsign = getattr(trip, 'headsign', '') or '' + route_short_name = route_info.get("route_short_name", "") + trip_headsign = getattr(trip, "headsign", "") or "" trip_id = trip.trip_id # Get stop times for this trip @@ -159,7 +159,11 @@ def get_stop_arrivals( for idx in range(len(segment_names) - 1, -1, -1): future_suffix_by_segment[idx] = future_tuple current_street = segment_names[idx] - future_tuple = (current_street,) + future_tuple if current_street is not None else future_tuple + future_tuple = ( + (current_street,) + future_tuple + if current_street is not None + else future_tuple + ) segment_future_lists: dict[int, list[str]] = {} @@ -189,48 +193,51 @@ def get_stop_arrivals( if segment_names: segment_idx = stop_to_segment_idx[i] if segment_idx not in segment_future_lists: - segment_future_lists[segment_idx] = list(future_suffix_by_segment[segment_idx]) + segment_future_lists[segment_idx] = list( + future_suffix_by_segment[segment_idx] + ) next_streets = segment_future_lists[segment_idx].copy() else: next_streets = [] trip_id_fmt = "_".join(trip_id.split("_")[1:3]) - stop_arrivals[stop_code].append({ - "trip_id": trip_id_fmt, - "service_id": service_id.split("_")[1], - "line": route_short_name, - "route": trip_headsign, - "stop_sequence": stop_time.stop_sequence, - 'shape_dist_traveled': getattr(stop_time, 'shape_dist_traveled', 0), - "next_streets": next_streets, - - "starting_code": starting_code, - "starting_name": starting_name, - "starting_time": starting_time, - - "calling_time": stop_time.departure_time, - "calling_ssm": time_to_seconds(stop_time.departure_time), - - "terminus_code": terminus_code, - "terminus_name": terminus_name, - "terminus_time": terminus_time, - }) + stop_arrivals[stop_code].append( + { + "service_id": service_id.split("_")[1], + "trip_id": trip_id_fmt, + "line": route_short_name, + "route": trip_headsign, + "shape_id": getattr(trip, "shape_id", ""), + "stop_sequence": stop_time.stop_sequence, + "shape_dist_traveled": getattr( + stop_time, "shape_dist_traveled", 0 + ), + "next_streets": next_streets, + "starting_code": starting_code, + "starting_name": starting_name, + "starting_time": starting_time, + "calling_time": stop_time.departure_time, + "calling_ssm": time_to_seconds(stop_time.departure_time), + "terminus_code": terminus_code, + "terminus_name": terminus_name, + "terminus_time": terminus_time, + } + ) # Sort each stop's arrivals by arrival time for stop_code in stop_arrivals: # Filter out entries with None arrival_seconds stop_arrivals[stop_code] = [ - item for item in stop_arrivals[stop_code] if item["calling_ssm"] is not None] + item for item in stop_arrivals[stop_code] if item["calling_ssm"] is not None + ] stop_arrivals[stop_code].sort(key=lambda x: x["calling_ssm"]) return stop_arrivals def process_date( - feed_dir: str, - date: str, - output_dir: str + feed_dir: str, date: str, output_dir: str ) -> tuple[str, Dict[str, int]]: """ Process a single date and write its stop JSON files. @@ -240,6 +247,8 @@ def process_date( try: logger.info(f"Starting stop report generation for date {date}") + stops_by_code = get_all_stops_by_code(feed_dir) + # Get all stop arrivals for the current date stop_arrivals = get_stop_arrivals(feed_dir, date) @@ -249,11 +258,25 @@ def process_date( # Write individual stop JSON files for stop_code, arrivals in stop_arrivals.items(): + # Get the stop from 'stops' by value to get the coords + stop_by_code = stops_by_code.get(stop_code) + + if stop_by_code is not None: + write_stop_protobuf( + output_dir, + date, + stop_code, + arrivals, + stop_by_code.stop_lat or 0.0, + stop_by_code.stop_lon or 0.0, + ) + write_stop_json(output_dir, date, stop_code, arrivals) # Create summary for index - stop_summary = {stop_code: len(arrivals) - for stop_code, arrivals in stop_arrivals.items()} + stop_summary = { + stop_code: len(arrivals) for stop_code, arrivals in stop_arrivals.items() + } logger.info(f"Processed {len(stop_arrivals)} stops for date {date}") return date, stop_summary @@ -271,15 +294,14 @@ def main(): feed_dir = args.feed_dir else: logger.info(f"Downloading GTFS feed from {feed_url}...") - feed_dir = download_feed_from_url( - feed_url, output_dir, args.force_download) + feed_dir = download_feed_from_url(feed_url, output_dir, args.force_download) if feed_dir is None: logger.info("Download was skipped (feed not modified). Exiting.") return all_dates = get_all_feed_dates(feed_dir) if not all_dates: - logger.error('No valid dates found in feed.') + logger.error("No valid dates found in feed.") return date_list = all_dates |
