From 5fa8d1ffeb4a3a0c5c6846de3986ec779a4fe564 Mon Sep 17 00:00:00 2001 From: Ariel Costas Guerrero Date: Sun, 7 Dec 2025 22:08:15 +0100 Subject: feat: implement provider-specific configurations for GTFS feed formats --- src/gtfs_vigo_stops/stop_report.py | 61 +++++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 14 deletions(-) (limited to 'src/gtfs_vigo_stops/stop_report.py') diff --git a/src/gtfs_vigo_stops/stop_report.py b/src/gtfs_vigo_stops/stop_report.py index 76eb90d..f8fdc64 100644 --- a/src/gtfs_vigo_stops/stop_report.py +++ b/src/gtfs_vigo_stops/stop_report.py @@ -15,8 +15,9 @@ from src.routes import load_routes from src.services import get_active_services from src.stop_times import get_stops_for_trips, StopTime from src.stops import get_all_stops, get_all_stops_by_code, get_numeric_code -from src.street_name import get_street_name, normalise_stop_name +from src.street_name import normalise_stop_name from src.trips import get_trips_for_services, TripLine +from src.providers import get_provider logger = get_logger("stop_report") @@ -43,6 +44,12 @@ def parse_args(): action="store_true", help="Force download even if the feed hasn't been modified (only applies when using --feed-url)", ) + parser.add_argument( + "--provider", + type=str, + default="default", + help="Feed provider type (vitrasa, renfe, default). Default: default", + ) args = parser.parse_args() if args.feed_dir and args.feed_url: @@ -264,7 +271,7 @@ def build_trip_previous_shape_map( return trip_previous_shape -def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]]]: +def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict[str, Any]]]: """ Process trips for the given date and organize stop arrivals. Also includes night services from the previous day (times >= 24:00:00). @@ -272,6 +279,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] Args: feed_dir: Path to the GTFS feed directory date: Date in YYYY-MM-DD format + provider: Provider class with feed-specific formatting methods Returns: Dictionary mapping stop_code to lists of arrival information. @@ -323,11 +331,14 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] routes = load_routes(feed_dir) logger.info(f"Loaded {len(routes)} routes from feed.") - # Create a reverse lookup from stop_id to stop_code + # Create a reverse lookup from stop_id to stop_code (or stop_id as fallback) stop_id_to_code = {} for stop_id, stop in stops.items(): if stop.stop_code: stop_id_to_code[stop_id] = get_numeric_code(stop.stop_code) + else: + # Fallback to stop_id if stop_code is not available (e.g., train stations) + stop_id_to_code[stop_id] = stop_id # Organize data by stop_code stop_arrivals = {} @@ -370,7 +381,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] for name in stop_names: street = street_cache.get(name) if street is None: - street = get_street_name(name) or "" + street = provider.extract_street_name(name) street_cache[name] = street if street != previous_street: segment_names.append(street) @@ -398,10 +409,20 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] starting_stop_name = first_stop.stop_name if first_stop else "Unknown Stop" terminus_stop_name = last_stop.stop_name if last_stop else "Unknown Stop" - starting_code = get_numeric_code( - first_stop.stop_code) if first_stop else "" - terminus_code = get_numeric_code( - last_stop.stop_code) if last_stop else "" + # Get stop codes with fallback to stop_id if stop_code is empty + if first_stop: + starting_code = get_numeric_code(first_stop.stop_code) + if not starting_code: + starting_code = first_stop_time.stop_id + else: + starting_code = "" + + if last_stop: + terminus_code = get_numeric_code(last_stop.stop_code) + if not terminus_code: + terminus_code = last_stop_time.stop_id + else: + terminus_code = "" starting_name = normalise_stop_name(starting_stop_name) terminus_name = normalise_stop_name(terminus_stop_name) @@ -465,7 +486,11 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] else: next_streets = [] - trip_id_fmt = "_".join(trip_id.split("_")[1:3]) + # Format IDs and route using provider-specific logic + service_id_fmt = provider.format_service_id(service_id) + trip_id_fmt = provider.format_trip_id(trip_id) + route_fmt = provider.format_route( + trip_headsign, terminus_name) # Get previous trip shape_id if available previous_trip_shape_id = trip_previous_shape_map.get( @@ -473,10 +498,10 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] stop_arrivals[stop_code].append( { - "service_id": service_id.split("_")[1], + "service_id": service_id_fmt, "trip_id": trip_id_fmt, "line": route_short_name, - "route": trip_headsign, + "route": route_fmt, "shape_id": getattr(trip, "shape_id", ""), "stop_sequence": stop_time.stop_sequence, "shape_dist_traveled": getattr( @@ -507,7 +532,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] def process_date( - feed_dir: str, date: str, output_dir: str + feed_dir: str, date: str, output_dir: str, provider ) -> tuple[str, Dict[str, int]]: """ Process a single date and write its stop JSON files. @@ -520,7 +545,7 @@ def process_date( stops_by_code = get_all_stops_by_code(feed_dir) # Get all stop arrivals for the current date - stop_arrivals = get_stop_arrivals(feed_dir, date) + stop_arrivals = get_stop_arrivals(feed_dir, date, provider) if not stop_arrivals: logger.warning(f"No stop arrivals found for date {date}") @@ -579,6 +604,14 @@ def main(): output_dir = args.output_dir feed_url = args.feed_url + # Get provider configuration + try: + provider = get_provider(args.provider) + logger.info(f"Using provider: {args.provider}") + except ValueError as e: + logger.error(str(e)) + sys.exit(1) + if not feed_url: feed_dir = args.feed_dir else: @@ -606,7 +639,7 @@ def main(): all_stops_summary = {} for date in date_list: - _, stop_summary = process_date(feed_dir, date, output_dir) + _, stop_summary = process_date(feed_dir, date, output_dir, provider) all_stops_summary[date] = stop_summary logger.info( -- cgit v1.3