diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/gtfs_vigo_stops/src/common.py | 4 | ||||
| -rw-r--r-- | src/gtfs_vigo_stops/src/providers.py | 136 | ||||
| -rw-r--r-- | src/gtfs_vigo_stops/src/services.py | 16 | ||||
| -rw-r--r-- | src/gtfs_vigo_stops/stop_report.py | 61 |
4 files changed, 198 insertions, 19 deletions
diff --git a/src/gtfs_vigo_stops/src/common.py b/src/gtfs_vigo_stops/src/common.py index 4b5cea5..fcf93d5 100644 --- a/src/gtfs_vigo_stops/src/common.py +++ b/src/gtfs_vigo_stops/src/common.py @@ -36,6 +36,10 @@ def get_all_feed_dates(feed_dir: str) -> List[str]: result.append(start.strftime("%Y-%m-%d")) start += timedelta(days=1) return result + else: + # Return from today to 7 days ahead if no valid dates found + today = datetime.now() + return [(today + timedelta(days=i)).strftime("%Y-%m-%d") for i in range(8)] # Fallback: use calendar_dates.txt if os.path.exists(calendar_dates_path): diff --git a/src/gtfs_vigo_stops/src/providers.py b/src/gtfs_vigo_stops/src/providers.py new file mode 100644 index 0000000..f6414f6 --- /dev/null +++ b/src/gtfs_vigo_stops/src/providers.py @@ -0,0 +1,136 @@ +""" +Provider-specific configuration for different GTFS feed formats. +""" + +from typing import Protocol, Optional +from src.street_name import get_street_name + + +class FeedProvider(Protocol): + """Protocol defining provider-specific behavior for GTFS feeds.""" + + @staticmethod + def format_service_id(service_id: str) -> str: + """Format service_id for output.""" + ... + + @staticmethod + def format_trip_id(trip_id: str) -> str: + """Format trip_id for output.""" + ... + + @staticmethod + def format_route(route: str, terminus_name: str) -> str: + """Format route/headsign, potentially using terminus name as fallback.""" + ... + + @staticmethod + def extract_street_name(stop_name: str) -> str: + """Extract street name from stop name, or return full name.""" + ... + + +class VitrasaProvider: + """Provider configuration for Vitrasa (Vigo bus system).""" + + @staticmethod + def format_service_id(service_id: str) -> str: + """Extract middle part from underscore-separated service_id.""" + parts = service_id.split("_") + return parts[1] if len(parts) >= 2 else service_id + + @staticmethod + def format_trip_id(trip_id: str) -> str: + """Extract middle parts from underscore-separated trip_id.""" + parts = trip_id.split("_") + return "_".join(parts[1:3]) if len(parts) >= 3 else trip_id + + @staticmethod + def format_route(route: str, terminus_name: str) -> str: + """Return route as-is for Vitrasa.""" + return route + + @staticmethod + def extract_street_name(stop_name: str) -> str: + """Extract street name from stop name using standard logic.""" + return get_street_name(stop_name) or "" + + +class RenfeProvider: + """Provider configuration for Renfe (Spanish rail system).""" + + @staticmethod + def format_service_id(service_id: str) -> str: + """Use full service_id for Renfe (no underscores).""" + return service_id + + @staticmethod + def format_trip_id(trip_id: str) -> str: + """Use full trip_id for Renfe (no underscores).""" + return trip_id + + @staticmethod + def format_route(route: str, terminus_name: str) -> str: + """Use terminus name as route if route is empty.""" + return route if route else terminus_name + + @staticmethod + def extract_street_name(stop_name: str) -> str: + """Preserve full stop name for train stations.""" + return stop_name + + +class DefaultProvider: + """Default provider configuration for generic GTFS feeds.""" + + @staticmethod + def format_service_id(service_id: str) -> str: + """Try to extract from underscores, fallback to full ID.""" + parts = service_id.split("_") + return parts[1] if len(parts) >= 2 else service_id + + @staticmethod + def format_trip_id(trip_id: str) -> str: + """Try to extract from underscores, fallback to full ID.""" + parts = trip_id.split("_") + return "_".join(parts[1:3]) if len(parts) >= 3 else trip_id + + @staticmethod + def format_route(route: str, terminus_name: str) -> str: + """Use terminus name as route if route is empty.""" + return route if route else terminus_name + + @staticmethod + def extract_street_name(stop_name: str) -> str: + """Extract street name from stop name using standard logic.""" + return get_street_name(stop_name) or "" + + +# Provider registry +PROVIDERS = { + "vitrasa": VitrasaProvider, + "renfe": RenfeProvider, + "default": DefaultProvider, +} + + +def get_provider(provider_name: str) -> type[FeedProvider]: + """ + Get provider configuration by name. + + Args: + provider_name: Name of the provider (case-insensitive) + + Returns: + Provider class with configuration methods + + Raises: + ValueError: If provider name is not recognized + """ + provider_name_lower = provider_name.lower() + if provider_name_lower not in PROVIDERS: + raise ValueError( + f"Unknown provider: {provider_name}. " + f"Available providers: {', '.join(PROVIDERS.keys())}" + ) + return PROVIDERS[provider_name_lower] diff --git a/src/gtfs_vigo_stops/src/services.py b/src/gtfs_vigo_stops/src/services.py index 9b16173..fb1110d 100644 --- a/src/gtfs_vigo_stops/src/services.py +++ b/src/gtfs_vigo_stops/src/services.py @@ -4,16 +4,17 @@ from src.logger import get_logger logger = get_logger("services") + def get_active_services(feed_dir: str, date: str) -> list[str]: """ Get active services for a given date based on the 'calendar.txt' and 'calendar_dates.txt' files. - + Args: date (str): Date in 'YYYY-MM-DD' format. - + Returns: list[str]: List of active service IDs for the given date. - + Raises: ValueError: If the date format is incorrect. """ @@ -24,7 +25,7 @@ def get_active_services(feed_dir: str, date: str) -> list[str]: try: with open(os.path.join(feed_dir, 'calendar.txt'), 'r', encoding="utf-8") as calendar_file: lines = calendar_file.readlines() - if len(lines) >1: + if len(lines) > 1: # First parse the header, get each column's index header = lines[0].strip().split(',') try: @@ -36,6 +37,8 @@ def get_active_services(feed_dir: str, date: str) -> list[str]: friday_index = header.index('friday') saturday_index = header.index('saturday') sunday_index = header.index('sunday') + start_date_index = header.index('start_date') + end_date_index = header.index('end_date') except ValueError as e: logger.error(f"Required column not found in header: {e}") return active_services @@ -59,8 +62,11 @@ def get_active_services(feed_dir: str, date: str) -> list[str]: service_id = parts[service_id_index] day_value = parts[weekday_columns[weekday]] + start_date = parts[start_date_index] + end_date = parts[end_date_index] - if day_value == '1': + # Check if day of week is active AND date is within the service range + if day_value == '1' and start_date <= search_date <= end_date: active_services.append(service_id) except FileNotFoundError: logger.warning("calendar.txt file not found.") diff --git a/src/gtfs_vigo_stops/stop_report.py b/src/gtfs_vigo_stops/stop_report.py index 76eb90d..f8fdc64 100644 --- a/src/gtfs_vigo_stops/stop_report.py +++ b/src/gtfs_vigo_stops/stop_report.py @@ -15,8 +15,9 @@ from src.routes import load_routes from src.services import get_active_services from src.stop_times import get_stops_for_trips, StopTime from src.stops import get_all_stops, get_all_stops_by_code, get_numeric_code -from src.street_name import get_street_name, normalise_stop_name +from src.street_name import normalise_stop_name from src.trips import get_trips_for_services, TripLine +from src.providers import get_provider logger = get_logger("stop_report") @@ -43,6 +44,12 @@ def parse_args(): action="store_true", help="Force download even if the feed hasn't been modified (only applies when using --feed-url)", ) + parser.add_argument( + "--provider", + type=str, + default="default", + help="Feed provider type (vitrasa, renfe, default). Default: default", + ) args = parser.parse_args() if args.feed_dir and args.feed_url: @@ -264,7 +271,7 @@ def build_trip_previous_shape_map( return trip_previous_shape -def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]]]: +def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict[str, Any]]]: """ Process trips for the given date and organize stop arrivals. Also includes night services from the previous day (times >= 24:00:00). @@ -272,6 +279,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] Args: feed_dir: Path to the GTFS feed directory date: Date in YYYY-MM-DD format + provider: Provider class with feed-specific formatting methods Returns: Dictionary mapping stop_code to lists of arrival information. @@ -323,11 +331,14 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] routes = load_routes(feed_dir) logger.info(f"Loaded {len(routes)} routes from feed.") - # Create a reverse lookup from stop_id to stop_code + # Create a reverse lookup from stop_id to stop_code (or stop_id as fallback) stop_id_to_code = {} for stop_id, stop in stops.items(): if stop.stop_code: stop_id_to_code[stop_id] = get_numeric_code(stop.stop_code) + else: + # Fallback to stop_id if stop_code is not available (e.g., train stations) + stop_id_to_code[stop_id] = stop_id # Organize data by stop_code stop_arrivals = {} @@ -370,7 +381,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] for name in stop_names: street = street_cache.get(name) if street is None: - street = get_street_name(name) or "" + street = provider.extract_street_name(name) street_cache[name] = street if street != previous_street: segment_names.append(street) @@ -398,10 +409,20 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] starting_stop_name = first_stop.stop_name if first_stop else "Unknown Stop" terminus_stop_name = last_stop.stop_name if last_stop else "Unknown Stop" - starting_code = get_numeric_code( - first_stop.stop_code) if first_stop else "" - terminus_code = get_numeric_code( - last_stop.stop_code) if last_stop else "" + # Get stop codes with fallback to stop_id if stop_code is empty + if first_stop: + starting_code = get_numeric_code(first_stop.stop_code) + if not starting_code: + starting_code = first_stop_time.stop_id + else: + starting_code = "" + + if last_stop: + terminus_code = get_numeric_code(last_stop.stop_code) + if not terminus_code: + terminus_code = last_stop_time.stop_id + else: + terminus_code = "" starting_name = normalise_stop_name(starting_stop_name) terminus_name = normalise_stop_name(terminus_stop_name) @@ -465,7 +486,11 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] else: next_streets = [] - trip_id_fmt = "_".join(trip_id.split("_")[1:3]) + # Format IDs and route using provider-specific logic + service_id_fmt = provider.format_service_id(service_id) + trip_id_fmt = provider.format_trip_id(trip_id) + route_fmt = provider.format_route( + trip_headsign, terminus_name) # Get previous trip shape_id if available previous_trip_shape_id = trip_previous_shape_map.get( @@ -473,10 +498,10 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] stop_arrivals[stop_code].append( { - "service_id": service_id.split("_")[1], + "service_id": service_id_fmt, "trip_id": trip_id_fmt, "line": route_short_name, - "route": trip_headsign, + "route": route_fmt, "shape_id": getattr(trip, "shape_id", ""), "stop_sequence": stop_time.stop_sequence, "shape_dist_traveled": getattr( @@ -507,7 +532,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any] def process_date( - feed_dir: str, date: str, output_dir: str + feed_dir: str, date: str, output_dir: str, provider ) -> tuple[str, Dict[str, int]]: """ Process a single date and write its stop JSON files. @@ -520,7 +545,7 @@ def process_date( stops_by_code = get_all_stops_by_code(feed_dir) # Get all stop arrivals for the current date - stop_arrivals = get_stop_arrivals(feed_dir, date) + stop_arrivals = get_stop_arrivals(feed_dir, date, provider) if not stop_arrivals: logger.warning(f"No stop arrivals found for date {date}") @@ -579,6 +604,14 @@ def main(): output_dir = args.output_dir feed_url = args.feed_url + # Get provider configuration + try: + provider = get_provider(args.provider) + logger.info(f"Using provider: {args.provider}") + except ValueError as e: + logger.error(str(e)) + sys.exit(1) + if not feed_url: feed_dir = args.feed_dir else: @@ -606,7 +639,7 @@ def main(): all_stops_summary = {} for date in date_list: - _, stop_summary = process_date(feed_dir, date, output_dir) + _, stop_summary = process_date(feed_dir, date, output_dir, provider) all_stops_summary[date] = stop_summary logger.info( |
