aboutsummaryrefslogtreecommitdiff
path: root/src/gtfs_perstop_report/stop_report.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/gtfs_perstop_report/stop_report.py')
-rw-r--r--src/gtfs_perstop_report/stop_report.py34
1 files changed, 26 insertions, 8 deletions
diff --git a/src/gtfs_perstop_report/stop_report.py b/src/gtfs_perstop_report/stop_report.py
index 3bbdf11..ef40417 100644
--- a/src/gtfs_perstop_report/stop_report.py
+++ b/src/gtfs_perstop_report/stop_report.py
@@ -13,6 +13,7 @@ from src.logger import get_logger
from src.report_writer import write_stop_json, write_stop_protobuf
from src.routes import load_routes
from src.services import get_active_services
+from src.rolling_dates import create_rolling_date_config
from src.stop_times import get_stops_for_trips, StopTime
from src.stops import get_all_stops, get_all_stops_by_code, get_numeric_code
from src.street_name import normalise_stop_name
@@ -49,6 +50,8 @@ def parse_args():
default="default",
help="Feed provider type (vitrasa, renfe, default). Default: default",
)
+ parser.add_argument('--rolling-dates', type=str,
+ help="Path to rolling dates configuration file (JSON)")
args = parser.parse_args()
if args.feed_dir and args.feed_url:
@@ -270,7 +273,7 @@ def build_trip_previous_shape_map(
def get_stop_arrivals(
- feed_dir: str, date: str, provider
+ feed_dir: str, date: str, provider, rolling_config=None
) -> Dict[str, List[Dict[str, Any]]]:
"""
Process trips for the given date and organize stop arrivals.
@@ -280,6 +283,7 @@ def get_stop_arrivals(
feed_dir: Path to the GTFS feed directory
date: Date in YYYY-MM-DD format
provider: Provider class with feed-specific formatting methods
+ rolling_config: Optional RollingDateConfig for date mapping
Returns:
Dictionary mapping stop_code to lists of arrival information.
@@ -289,14 +293,19 @@ def get_stop_arrivals(
stops = get_all_stops(feed_dir)
logger.info(f"Found {len(stops)} stops in the feed.")
- active_services = get_active_services(feed_dir, date)
+ effective_date = date
+ if rolling_config and rolling_config.is_rolling_date(date):
+ effective_date = rolling_config.get_source_date(date)
+ logger.info(f"Using source date {effective_date} for rolling date {date}")
+
+ active_services = get_active_services(feed_dir, effective_date)
if not active_services:
- logger.info("No active services found for the given date.")
+ logger.info(f"No active services found for the given date {effective_date}.")
- logger.info(f"Found {len(active_services)} active services for date {date}.")
+ logger.info(f"Found {len(active_services)} active services for date {effective_date}.")
# Also get services from the previous day to include night services (times >= 24:00)
- prev_date = (datetime.strptime(date, "%Y-%m-%d") - timedelta(days=1)).strftime(
+ prev_date = (datetime.strptime(effective_date, "%Y-%m-%d") - timedelta(days=1)).strftime(
"%Y-%m-%d"
)
prev_services = get_active_services(feed_dir, prev_date)
@@ -527,7 +536,7 @@ def get_stop_arrivals(
def process_date(
- feed_dir: str, date: str, output_dir: str, provider
+ feed_dir: str, date: str, output_dir: str, provider, rolling_config=None
) -> tuple[str, Dict[str, int]]:
"""
Process a single date and write its stop JSON files.
@@ -540,7 +549,7 @@ def process_date(
stops_by_code = get_all_stops_by_code(feed_dir)
# Get all stop arrivals for the current date
- stop_arrivals = get_stop_arrivals(feed_dir, date, provider)
+ stop_arrivals = get_stop_arrivals(feed_dir, date, provider, rolling_config)
if not stop_arrivals:
logger.warning(f"No stop arrivals found for date {date}")
@@ -622,6 +631,15 @@ def main():
return
date_list = all_dates
+ # Handle rolling dates
+ rolling_config = create_rolling_date_config(args.rolling_dates)
+ if rolling_config.has_mappings():
+ for target_date in rolling_config.get_all_mappings().keys():
+ if target_date not in date_list:
+ date_list.append(target_date)
+ # Sort dates to ensure they are processed in order
+ date_list.sort()
+
# Ensure date_list is not empty before processing
if not date_list:
logger.error("No valid dates to process.")
@@ -633,7 +651,7 @@ def main():
all_stops_summary = {}
for date in date_list:
- _, stop_summary = process_date(feed_dir, date, output_dir, provider)
+ _, stop_summary = process_date(feed_dir, date, output_dir, provider, rolling_config)
all_stops_summary[date] = stop_summary
logger.info("Finished processing all dates. Beginning with shape transformation.")