diff options
| author | Ariel Costas Guerrero <ariel@costas.dev> | 2026-04-05 22:30:15 +0200 |
|---|---|---|
| committer | Ariel Costas Guerrero <ariel@costas.dev> | 2026-04-05 22:30:27 +0200 |
| commit | 95f8e03affb17b3b4dd8cff202523f5b131972df (patch) | |
| tree | 23e31512167f1295defc9cc4639ff6f411c04a54 /build_renfe/build_static_feed.py | |
| parent | b2631a82a394af8c38224ae0722bcf728d651cfd (diff) | |
renfe: generate shapes properly and consistently
- Update OSRM container to use ALL SPAIN (sorry, Trencelta)
- Generate a shape per trip (no trying to reuse, since trains that change stop sequence got wrong shapes)
- Add more position corrections for FEVE
- Run separate generators for FEVE and Renfe, since sometimes OSRM would pick the one that shouldn't and generate a wrong shape
- Add a debug script to generate a trip's visualisation from GTFS, since I was about to lose my mind debugging this pile of crap
- Update README (before starting anything else)
Time spent: ca. 6 hours
Closes #1
Diffstat (limited to 'build_renfe/build_static_feed.py')
| -rw-r--r-- | build_renfe/build_static_feed.py | 139 |
1 files changed, 63 insertions, 76 deletions
diff --git a/build_renfe/build_static_feed.py b/build_renfe/build_static_feed.py index eb247a9..6c12c1f 100644 --- a/build_renfe/build_static_feed.py +++ b/build_renfe/build_static_feed.py @@ -16,6 +16,7 @@ import os import shutil import tempfile import zipfile +import binascii import pandas as pd @@ -181,13 +182,20 @@ if __name__ == "__main__": help="NAP API Key (https://nap.transportes.gob.es/)" ) parser.add_argument( - "--osrm-url", + "--osrm-std", type=str, - help="OSRM server URL", + help="OSRM standard server URL", default="http://localhost:5050", required=False, ) parser.add_argument( + "--osrm-narrow", + type=str, + help="OSRM narrow gauge server URL", + default="http://localhost:5051", + required=False, + ) + parser.add_argument( "--debug", help="Enable debug logging", action="store_true" @@ -201,8 +209,9 @@ if __name__ == "__main__": args = parser.parse_args() try: - osrm_check = requests.head(args.osrm_url, timeout=5) - GENERATE_SHAPES = osrm_check.status_code < 500 + osrm_check_std = requests.head(args.osrm_std, timeout=5) + osrm_check_narrow = requests.head(args.osrm_narrow, timeout=5) + GENERATE_SHAPES = osrm_check_std.status_code < 500 and osrm_check_narrow.status_code < 500 except requests.RequestException: GENERATE_SHAPES = False logging.warning("OSRM server is not reachable. Shape generation will be skipped.") @@ -214,6 +223,10 @@ if __name__ == "__main__": ) for feed in FEEDS.keys(): + def get_shape_id(trip_id: str) -> str: + trip_crc = binascii.crc32(trip_id.encode("utf-8")) + return f"Shape_{feed}_{trip_crc}_{trip_crc}" + INPUT_GTFS_FD, INPUT_GTFS_ZIP = tempfile.mkstemp(suffix=".zip", prefix=f"renfe_galicia_in_{feed}_") INPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_in_{feed}_") OUTPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_out_{feed}_") @@ -397,7 +410,7 @@ if __name__ == "__main__": for tig in trips_in_galicia: if GENERATE_SHAPES: - tig["shape_id"] = f"Shape_{tig['trip_id'][0:5]}" + tig["shape_id"] = get_shape_id(tig["trip_id"]) tig["trip_headsign"] = stops_by_id[last_stop_in_trips[tig["trip_id"]]]["stop_name"] with open( os.path.join(OUTPUT_GTFS_PATH, "trips.txt"), @@ -423,12 +436,19 @@ if __name__ == "__main__": logging.info("GTFS data for Galicia has been extracted successfully. Generate shapes for the trips...") if GENERATE_SHAPES: - shape_ids_total = len(set(f"Shape_{trip_id[0:5]}" for trip_id in trip_ids)) + shape_ids_total = len(set(trip["shape_id"] for trip in trips_in_galicia)) shape_ids_generated: set[str] = set() # Pre-load stops for quick lookup stops_dict = {stop["stop_id"]: stop for stop in stops_in_trips} + # Map trip_id to route_type for OSRM profile selection + trip_to_route_type = {tig["trip_id"]: tig.get("route_type", "2") for tig in trips_in_galicia} + # Fallback if not in trips_in_galicia (shouldn't happen) + route_id_to_type = {r["route_id"]: r["route_type"] for r in routes_in_trips} + for tig in trips_in_galicia: + trip_to_route_type[tig["trip_id"]] = route_id_to_type.get(tig["route_id"], "2") + # Group stop times by trip_id to avoid repeated file reads stop_times_by_trip: dict[str, list[dict]] = {} for st in stop_times_in_galicia: @@ -437,11 +457,21 @@ if __name__ == "__main__": stop_times_by_trip[tid] = [] stop_times_by_trip[tid].append(st) - OSRM_BASE_URL = f"{args.osrm_url}/route/v1/driving/" + shapes_file = open("shapes_debug.txt", "w", encoding="utf-8") + for trip_id in tqdm(trip_ids, total=shape_ids_total, desc="Generating shapes"): - shape_id = f"Shape_{trip_id[0:5]}" + shape_id = get_shape_id(trip_id) if shape_id in shape_ids_generated: continue + + route_type = trip_to_route_type.get(trip_id, "2") + osrm_profile = "driving" + + # If we are on feed cercanias or the 5-digit trip ID starts with 7, it's a narrow gauge train, otherwise standard gauge + if feed == "cercanias" or (len(trip_id) >= 5 and trip_id[0] == "7"): + OSRM_BASE_URL = f"{args.osrm_narrow}/route/v1/driving/" + else: + OSRM_BASE_URL = f"{args.osrm_std}/route/v1/driving/" stop_seq = stop_times_by_trip.get(trip_id, []) stop_seq.sort(key=lambda x: int(x["stop_sequence"].strip())) @@ -450,79 +480,36 @@ if __name__ == "__main__": continue final_shape_points = [] - i = 0 - while i < len(stop_seq) - 1: - stop_a = stops_dict[stop_seq[i]["stop_id"]] - lat_a, lon_a = float(stop_a["stop_lat"]), float(stop_a["stop_lon"]) + coordinates = [] + for st in stop_seq: + s = stops_dict[st["stop_id"]] + coordinates.append(f"{s['stop_lon']},{s['stop_lat']}") - if not is_in_bounds(lat_a, lon_a): - # S_i is out of bounds. Segment S_i -> S_{i+1} is straight line. - stop_b = stops_dict[stop_seq[i+1]["stop_id"]] - lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"]) + coords_str = ";".join(coordinates) + osrm_url = f"{OSRM_BASE_URL}{coords_str}?overview=full&geometries=geojson&continue_straight=false" - segment_points = [[lon_a, lat_a], [lon_b, lat_b]] - if not final_shape_points: - final_shape_points.extend(segment_points) - else: - final_shape_points.extend(segment_points[1:]) - i += 1 - else: - # S_i is in bounds. Find how many subsequent stops are also in bounds. - j = i + 1 - while j < len(stop_seq): - stop_j = stops_dict[stop_seq[j]["stop_id"]] - if is_in_bounds(float(stop_j["stop_lat"]), float(stop_j["stop_lon"])): - j += 1 - else: - break - - # Stops from i to j-1 are in bounds. - if j > i + 1: - # We have at least two consecutive stops in bounds. - in_bounds_stops = stop_seq[i:j] - coordinates = [] - for st in in_bounds_stops: - s = stops_dict[st["stop_id"]] - coordinates.append(f"{s['stop_lon']},{s['stop_lat']}") - - coords_str = ";".join(coordinates) - osrm_url = f"{OSRM_BASE_URL}{coords_str}?overview=full&geometries=geojson" + shapes_file.write(f"{trip_id} ({shape_id}): {osrm_url}\n") - segment_points = [] - try: - response = requests.get(osrm_url, timeout=10) - if response.status_code == 200: - data = response.json() - if data.get("code") == "Ok": - segment_points = data["routes"][0]["geometry"]["coordinates"] - except Exception: - pass - - if not segment_points: - # Fallback to straight lines for this whole sub-sequence - segment_points = [] - for k in range(i, j): - s = stops_dict[stop_seq[k]["stop_id"]] - segment_points.append([float(s["stop_lon"]), float(s["stop_lat"])]) - - if not final_shape_points: - final_shape_points.extend(segment_points) - else: - final_shape_points.extend(segment_points[1:]) - - i = j - 1 # Next iteration starts from S_{j-1} + try: + response = requests.get(osrm_url, timeout=10) + if response.status_code == 200: + data = response.json() + if data.get("code") == "Ok": + final_shape_points = data["routes"][0]["geometry"]["coordinates"] + logging.debug(f"OSRM success for {shape_id} ({len(coordinates)} stops): {len(final_shape_points)} points") else: - # Only S_i is in bounds, S_{i+1} is out. - # Segment S_i -> S_{i+1} is straight line. - stop_b = stops_dict[stop_seq[i+1]["stop_id"]] - lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"]) + logging.warning(f"OSRM returned error code {data.get('code')} for {shape_id} with {len(coordinates)} stops") + else: + logging.warning(f"OSRM request failed for {shape_id} with status {response.status_code}") + except Exception as e: + logging.error(f"OSRM exception for {shape_id}: {str(e)}") - segment_points = [[lon_a, lat_a], [lon_b, lat_b]] - if not final_shape_points: - final_shape_points.extend(segment_points) - else: - final_shape_points.extend(segment_points[1:]) - i += 1 + if not final_shape_points: + # Fallback to straight lines + logging.info(f"Using straight-line fallback for {shape_id}") + for st in stop_seq: + s = stops_dict[st["stop_id"]] + final_shape_points.append([float(s["stop_lon"]), float(s["stop_lat"])]) shape_ids_generated.add(shape_id) |
