aboutsummaryrefslogtreecommitdiff
path: root/src/gtfs_vigo_stops/stop_report.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/gtfs_vigo_stops/stop_report.py')
-rw-r--r--src/gtfs_vigo_stops/stop_report.py61
1 files changed, 47 insertions, 14 deletions
diff --git a/src/gtfs_vigo_stops/stop_report.py b/src/gtfs_vigo_stops/stop_report.py
index 76eb90d..f8fdc64 100644
--- a/src/gtfs_vigo_stops/stop_report.py
+++ b/src/gtfs_vigo_stops/stop_report.py
@@ -15,8 +15,9 @@ from src.routes import load_routes
from src.services import get_active_services
from src.stop_times import get_stops_for_trips, StopTime
from src.stops import get_all_stops, get_all_stops_by_code, get_numeric_code
-from src.street_name import get_street_name, normalise_stop_name
+from src.street_name import normalise_stop_name
from src.trips import get_trips_for_services, TripLine
+from src.providers import get_provider
logger = get_logger("stop_report")
@@ -43,6 +44,12 @@ def parse_args():
action="store_true",
help="Force download even if the feed hasn't been modified (only applies when using --feed-url)",
)
+ parser.add_argument(
+ "--provider",
+ type=str,
+ default="default",
+ help="Feed provider type (vitrasa, renfe, default). Default: default",
+ )
args = parser.parse_args()
if args.feed_dir and args.feed_url:
@@ -264,7 +271,7 @@ def build_trip_previous_shape_map(
return trip_previous_shape
-def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]]]:
+def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict[str, Any]]]:
"""
Process trips for the given date and organize stop arrivals.
Also includes night services from the previous day (times >= 24:00:00).
@@ -272,6 +279,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
Args:
feed_dir: Path to the GTFS feed directory
date: Date in YYYY-MM-DD format
+ provider: Provider class with feed-specific formatting methods
Returns:
Dictionary mapping stop_code to lists of arrival information.
@@ -323,11 +331,14 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
routes = load_routes(feed_dir)
logger.info(f"Loaded {len(routes)} routes from feed.")
- # Create a reverse lookup from stop_id to stop_code
+ # Create a reverse lookup from stop_id to stop_code (or stop_id as fallback)
stop_id_to_code = {}
for stop_id, stop in stops.items():
if stop.stop_code:
stop_id_to_code[stop_id] = get_numeric_code(stop.stop_code)
+ else:
+ # Fallback to stop_id if stop_code is not available (e.g., train stations)
+ stop_id_to_code[stop_id] = stop_id
# Organize data by stop_code
stop_arrivals = {}
@@ -370,7 +381,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
for name in stop_names:
street = street_cache.get(name)
if street is None:
- street = get_street_name(name) or ""
+ street = provider.extract_street_name(name)
street_cache[name] = street
if street != previous_street:
segment_names.append(street)
@@ -398,10 +409,20 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
starting_stop_name = first_stop.stop_name if first_stop else "Unknown Stop"
terminus_stop_name = last_stop.stop_name if last_stop else "Unknown Stop"
- starting_code = get_numeric_code(
- first_stop.stop_code) if first_stop else ""
- terminus_code = get_numeric_code(
- last_stop.stop_code) if last_stop else ""
+ # Get stop codes with fallback to stop_id if stop_code is empty
+ if first_stop:
+ starting_code = get_numeric_code(first_stop.stop_code)
+ if not starting_code:
+ starting_code = first_stop_time.stop_id
+ else:
+ starting_code = ""
+
+ if last_stop:
+ terminus_code = get_numeric_code(last_stop.stop_code)
+ if not terminus_code:
+ terminus_code = last_stop_time.stop_id
+ else:
+ terminus_code = ""
starting_name = normalise_stop_name(starting_stop_name)
terminus_name = normalise_stop_name(terminus_stop_name)
@@ -465,7 +486,11 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
else:
next_streets = []
- trip_id_fmt = "_".join(trip_id.split("_")[1:3])
+ # Format IDs and route using provider-specific logic
+ service_id_fmt = provider.format_service_id(service_id)
+ trip_id_fmt = provider.format_trip_id(trip_id)
+ route_fmt = provider.format_route(
+ trip_headsign, terminus_name)
# Get previous trip shape_id if available
previous_trip_shape_id = trip_previous_shape_map.get(
@@ -473,10 +498,10 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
stop_arrivals[stop_code].append(
{
- "service_id": service_id.split("_")[1],
+ "service_id": service_id_fmt,
"trip_id": trip_id_fmt,
"line": route_short_name,
- "route": trip_headsign,
+ "route": route_fmt,
"shape_id": getattr(trip, "shape_id", ""),
"stop_sequence": stop_time.stop_sequence,
"shape_dist_traveled": getattr(
@@ -507,7 +532,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
def process_date(
- feed_dir: str, date: str, output_dir: str
+ feed_dir: str, date: str, output_dir: str, provider
) -> tuple[str, Dict[str, int]]:
"""
Process a single date and write its stop JSON files.
@@ -520,7 +545,7 @@ def process_date(
stops_by_code = get_all_stops_by_code(feed_dir)
# Get all stop arrivals for the current date
- stop_arrivals = get_stop_arrivals(feed_dir, date)
+ stop_arrivals = get_stop_arrivals(feed_dir, date, provider)
if not stop_arrivals:
logger.warning(f"No stop arrivals found for date {date}")
@@ -579,6 +604,14 @@ def main():
output_dir = args.output_dir
feed_url = args.feed_url
+ # Get provider configuration
+ try:
+ provider = get_provider(args.provider)
+ logger.info(f"Using provider: {args.provider}")
+ except ValueError as e:
+ logger.error(str(e))
+ sys.exit(1)
+
if not feed_url:
feed_dir = args.feed_dir
else:
@@ -606,7 +639,7 @@ def main():
all_stops_summary = {}
for date in date_list:
- _, stop_summary = process_date(feed_dir, date, output_dir)
+ _, stop_summary = process_date(feed_dir, date, output_dir, provider)
all_stops_summary[date] = stop_summary
logger.info(