aboutsummaryrefslogtreecommitdiff
path: root/src/gtfs_vigo_stops
diff options
context:
space:
mode:
authorAriel Costas Guerrero <ariel@costas.dev>2025-12-07 22:08:15 +0100
committerAriel Costas Guerrero <ariel@costas.dev>2025-12-07 22:14:54 +0100
commit5fa8d1ffeb4a3a0c5c6846de3986ec779a4fe564 (patch)
tree1da8ba51a6711121e8431eec316b9e8286a8e9d6 /src/gtfs_vigo_stops
parent8b810ceb425df619deb7153a1caa521c59050b40 (diff)
feat: implement provider-specific configurations for GTFS feed formats
Diffstat (limited to 'src/gtfs_vigo_stops')
-rw-r--r--src/gtfs_vigo_stops/src/common.py4
-rw-r--r--src/gtfs_vigo_stops/src/providers.py136
-rw-r--r--src/gtfs_vigo_stops/src/services.py16
-rw-r--r--src/gtfs_vigo_stops/stop_report.py61
4 files changed, 198 insertions, 19 deletions
diff --git a/src/gtfs_vigo_stops/src/common.py b/src/gtfs_vigo_stops/src/common.py
index 4b5cea5..fcf93d5 100644
--- a/src/gtfs_vigo_stops/src/common.py
+++ b/src/gtfs_vigo_stops/src/common.py
@@ -36,6 +36,10 @@ def get_all_feed_dates(feed_dir: str) -> List[str]:
result.append(start.strftime("%Y-%m-%d"))
start += timedelta(days=1)
return result
+ else:
+ # Return from today to 7 days ahead if no valid dates found
+ today = datetime.now()
+ return [(today + timedelta(days=i)).strftime("%Y-%m-%d") for i in range(8)]
# Fallback: use calendar_dates.txt
if os.path.exists(calendar_dates_path):
diff --git a/src/gtfs_vigo_stops/src/providers.py b/src/gtfs_vigo_stops/src/providers.py
new file mode 100644
index 0000000..f6414f6
--- /dev/null
+++ b/src/gtfs_vigo_stops/src/providers.py
@@ -0,0 +1,136 @@
+"""
+Provider-specific configuration for different GTFS feed formats.
+"""
+
+from typing import Protocol, Optional
+from src.street_name import get_street_name
+
+
+class FeedProvider(Protocol):
+ """Protocol defining provider-specific behavior for GTFS feeds."""
+
+ @staticmethod
+ def format_service_id(service_id: str) -> str:
+ """Format service_id for output."""
+ ...
+
+ @staticmethod
+ def format_trip_id(trip_id: str) -> str:
+ """Format trip_id for output."""
+ ...
+
+ @staticmethod
+ def format_route(route: str, terminus_name: str) -> str:
+ """Format route/headsign, potentially using terminus name as fallback."""
+ ...
+
+ @staticmethod
+ def extract_street_name(stop_name: str) -> str:
+ """Extract street name from stop name, or return full name."""
+ ...
+
+
+class VitrasaProvider:
+ """Provider configuration for Vitrasa (Vigo bus system)."""
+
+ @staticmethod
+ def format_service_id(service_id: str) -> str:
+ """Extract middle part from underscore-separated service_id."""
+ parts = service_id.split("_")
+ return parts[1] if len(parts) >= 2 else service_id
+
+ @staticmethod
+ def format_trip_id(trip_id: str) -> str:
+ """Extract middle parts from underscore-separated trip_id."""
+ parts = trip_id.split("_")
+ return "_".join(parts[1:3]) if len(parts) >= 3 else trip_id
+
+ @staticmethod
+ def format_route(route: str, terminus_name: str) -> str:
+ """Return route as-is for Vitrasa."""
+ return route
+
+ @staticmethod
+ def extract_street_name(stop_name: str) -> str:
+ """Extract street name from stop name using standard logic."""
+ return get_street_name(stop_name) or ""
+
+
+class RenfeProvider:
+ """Provider configuration for Renfe (Spanish rail system)."""
+
+ @staticmethod
+ def format_service_id(service_id: str) -> str:
+ """Use full service_id for Renfe (no underscores)."""
+ return service_id
+
+ @staticmethod
+ def format_trip_id(trip_id: str) -> str:
+ """Use full trip_id for Renfe (no underscores)."""
+ return trip_id
+
+ @staticmethod
+ def format_route(route: str, terminus_name: str) -> str:
+ """Use terminus name as route if route is empty."""
+ return route if route else terminus_name
+
+ @staticmethod
+ def extract_street_name(stop_name: str) -> str:
+ """Preserve full stop name for train stations."""
+ return stop_name
+
+
+class DefaultProvider:
+ """Default provider configuration for generic GTFS feeds."""
+
+ @staticmethod
+ def format_service_id(service_id: str) -> str:
+ """Try to extract from underscores, fallback to full ID."""
+ parts = service_id.split("_")
+ return parts[1] if len(parts) >= 2 else service_id
+
+ @staticmethod
+ def format_trip_id(trip_id: str) -> str:
+ """Try to extract from underscores, fallback to full ID."""
+ parts = trip_id.split("_")
+ return "_".join(parts[1:3]) if len(parts) >= 3 else trip_id
+
+ @staticmethod
+ def format_route(route: str, terminus_name: str) -> str:
+ """Use terminus name as route if route is empty."""
+ return route if route else terminus_name
+
+ @staticmethod
+ def extract_street_name(stop_name: str) -> str:
+ """Extract street name from stop name using standard logic."""
+ return get_street_name(stop_name) or ""
+
+
+# Provider registry
+PROVIDERS = {
+ "vitrasa": VitrasaProvider,
+ "renfe": RenfeProvider,
+ "default": DefaultProvider,
+}
+
+
+def get_provider(provider_name: str) -> type[FeedProvider]:
+ """
+ Get provider configuration by name.
+
+ Args:
+ provider_name: Name of the provider (case-insensitive)
+
+ Returns:
+ Provider class with configuration methods
+
+ Raises:
+ ValueError: If provider name is not recognized
+ """
+ provider_name_lower = provider_name.lower()
+ if provider_name_lower not in PROVIDERS:
+ raise ValueError(
+ f"Unknown provider: {provider_name}. "
+ f"Available providers: {', '.join(PROVIDERS.keys())}"
+ )
+ return PROVIDERS[provider_name_lower]
diff --git a/src/gtfs_vigo_stops/src/services.py b/src/gtfs_vigo_stops/src/services.py
index 9b16173..fb1110d 100644
--- a/src/gtfs_vigo_stops/src/services.py
+++ b/src/gtfs_vigo_stops/src/services.py
@@ -4,16 +4,17 @@ from src.logger import get_logger
logger = get_logger("services")
+
def get_active_services(feed_dir: str, date: str) -> list[str]:
"""
Get active services for a given date based on the 'calendar.txt' and 'calendar_dates.txt' files.
-
+
Args:
date (str): Date in 'YYYY-MM-DD' format.
-
+
Returns:
list[str]: List of active service IDs for the given date.
-
+
Raises:
ValueError: If the date format is incorrect.
"""
@@ -24,7 +25,7 @@ def get_active_services(feed_dir: str, date: str) -> list[str]:
try:
with open(os.path.join(feed_dir, 'calendar.txt'), 'r', encoding="utf-8") as calendar_file:
lines = calendar_file.readlines()
- if len(lines) >1:
+ if len(lines) > 1:
# First parse the header, get each column's index
header = lines[0].strip().split(',')
try:
@@ -36,6 +37,8 @@ def get_active_services(feed_dir: str, date: str) -> list[str]:
friday_index = header.index('friday')
saturday_index = header.index('saturday')
sunday_index = header.index('sunday')
+ start_date_index = header.index('start_date')
+ end_date_index = header.index('end_date')
except ValueError as e:
logger.error(f"Required column not found in header: {e}")
return active_services
@@ -59,8 +62,11 @@ def get_active_services(feed_dir: str, date: str) -> list[str]:
service_id = parts[service_id_index]
day_value = parts[weekday_columns[weekday]]
+ start_date = parts[start_date_index]
+ end_date = parts[end_date_index]
- if day_value == '1':
+ # Check if day of week is active AND date is within the service range
+ if day_value == '1' and start_date <= search_date <= end_date:
active_services.append(service_id)
except FileNotFoundError:
logger.warning("calendar.txt file not found.")
diff --git a/src/gtfs_vigo_stops/stop_report.py b/src/gtfs_vigo_stops/stop_report.py
index 76eb90d..f8fdc64 100644
--- a/src/gtfs_vigo_stops/stop_report.py
+++ b/src/gtfs_vigo_stops/stop_report.py
@@ -15,8 +15,9 @@ from src.routes import load_routes
from src.services import get_active_services
from src.stop_times import get_stops_for_trips, StopTime
from src.stops import get_all_stops, get_all_stops_by_code, get_numeric_code
-from src.street_name import get_street_name, normalise_stop_name
+from src.street_name import normalise_stop_name
from src.trips import get_trips_for_services, TripLine
+from src.providers import get_provider
logger = get_logger("stop_report")
@@ -43,6 +44,12 @@ def parse_args():
action="store_true",
help="Force download even if the feed hasn't been modified (only applies when using --feed-url)",
)
+ parser.add_argument(
+ "--provider",
+ type=str,
+ default="default",
+ help="Feed provider type (vitrasa, renfe, default). Default: default",
+ )
args = parser.parse_args()
if args.feed_dir and args.feed_url:
@@ -264,7 +271,7 @@ def build_trip_previous_shape_map(
return trip_previous_shape
-def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]]]:
+def get_stop_arrivals(feed_dir: str, date: str, provider) -> Dict[str, List[Dict[str, Any]]]:
"""
Process trips for the given date and organize stop arrivals.
Also includes night services from the previous day (times >= 24:00:00).
@@ -272,6 +279,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
Args:
feed_dir: Path to the GTFS feed directory
date: Date in YYYY-MM-DD format
+ provider: Provider class with feed-specific formatting methods
Returns:
Dictionary mapping stop_code to lists of arrival information.
@@ -323,11 +331,14 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
routes = load_routes(feed_dir)
logger.info(f"Loaded {len(routes)} routes from feed.")
- # Create a reverse lookup from stop_id to stop_code
+ # Create a reverse lookup from stop_id to stop_code (or stop_id as fallback)
stop_id_to_code = {}
for stop_id, stop in stops.items():
if stop.stop_code:
stop_id_to_code[stop_id] = get_numeric_code(stop.stop_code)
+ else:
+ # Fallback to stop_id if stop_code is not available (e.g., train stations)
+ stop_id_to_code[stop_id] = stop_id
# Organize data by stop_code
stop_arrivals = {}
@@ -370,7 +381,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
for name in stop_names:
street = street_cache.get(name)
if street is None:
- street = get_street_name(name) or ""
+ street = provider.extract_street_name(name)
street_cache[name] = street
if street != previous_street:
segment_names.append(street)
@@ -398,10 +409,20 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
starting_stop_name = first_stop.stop_name if first_stop else "Unknown Stop"
terminus_stop_name = last_stop.stop_name if last_stop else "Unknown Stop"
- starting_code = get_numeric_code(
- first_stop.stop_code) if first_stop else ""
- terminus_code = get_numeric_code(
- last_stop.stop_code) if last_stop else ""
+ # Get stop codes with fallback to stop_id if stop_code is empty
+ if first_stop:
+ starting_code = get_numeric_code(first_stop.stop_code)
+ if not starting_code:
+ starting_code = first_stop_time.stop_id
+ else:
+ starting_code = ""
+
+ if last_stop:
+ terminus_code = get_numeric_code(last_stop.stop_code)
+ if not terminus_code:
+ terminus_code = last_stop_time.stop_id
+ else:
+ terminus_code = ""
starting_name = normalise_stop_name(starting_stop_name)
terminus_name = normalise_stop_name(terminus_stop_name)
@@ -465,7 +486,11 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
else:
next_streets = []
- trip_id_fmt = "_".join(trip_id.split("_")[1:3])
+ # Format IDs and route using provider-specific logic
+ service_id_fmt = provider.format_service_id(service_id)
+ trip_id_fmt = provider.format_trip_id(trip_id)
+ route_fmt = provider.format_route(
+ trip_headsign, terminus_name)
# Get previous trip shape_id if available
previous_trip_shape_id = trip_previous_shape_map.get(
@@ -473,10 +498,10 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
stop_arrivals[stop_code].append(
{
- "service_id": service_id.split("_")[1],
+ "service_id": service_id_fmt,
"trip_id": trip_id_fmt,
"line": route_short_name,
- "route": trip_headsign,
+ "route": route_fmt,
"shape_id": getattr(trip, "shape_id", ""),
"stop_sequence": stop_time.stop_sequence,
"shape_dist_traveled": getattr(
@@ -507,7 +532,7 @@ def get_stop_arrivals(feed_dir: str, date: str) -> Dict[str, List[Dict[str, Any]
def process_date(
- feed_dir: str, date: str, output_dir: str
+ feed_dir: str, date: str, output_dir: str, provider
) -> tuple[str, Dict[str, int]]:
"""
Process a single date and write its stop JSON files.
@@ -520,7 +545,7 @@ def process_date(
stops_by_code = get_all_stops_by_code(feed_dir)
# Get all stop arrivals for the current date
- stop_arrivals = get_stop_arrivals(feed_dir, date)
+ stop_arrivals = get_stop_arrivals(feed_dir, date, provider)
if not stop_arrivals:
logger.warning(f"No stop arrivals found for date {date}")
@@ -579,6 +604,14 @@ def main():
output_dir = args.output_dir
feed_url = args.feed_url
+ # Get provider configuration
+ try:
+ provider = get_provider(args.provider)
+ logger.info(f"Using provider: {args.provider}")
+ except ValueError as e:
+ logger.error(str(e))
+ sys.exit(1)
+
if not feed_url:
feed_dir = args.feed_dir
else:
@@ -606,7 +639,7 @@ def main():
all_stops_summary = {}
for date in date_list:
- _, stop_summary = process_date(feed_dir, date, output_dir)
+ _, stop_summary = process_date(feed_dir, date, output_dir, provider)
all_stops_summary[date] = stop_summary
logger.info(