aboutsummaryrefslogtreecommitdiff
path: root/build_renfe/build_static_feed.py
diff options
context:
space:
mode:
authorAriel Costas Guerrero <ariel@costas.dev>2026-04-05 22:30:15 +0200
committerAriel Costas Guerrero <ariel@costas.dev>2026-04-05 22:30:27 +0200
commit95f8e03affb17b3b4dd8cff202523f5b131972df (patch)
tree23e31512167f1295defc9cc4639ff6f411c04a54 /build_renfe/build_static_feed.py
parentb2631a82a394af8c38224ae0722bcf728d651cfd (diff)
renfe: generate shapes properly and consistently
- Update OSRM container to use ALL SPAIN (sorry, Trencelta) - Generate a shape per trip (no trying to reuse, since trains that change stop sequence got wrong shapes) - Add more position corrections for FEVE - Run separate generators for FEVE and Renfe, since sometimes OSRM would pick the one that shouldn't and generate a wrong shape - Add a debug script to generate a trip's visualisation from GTFS, since I was about to lose my mind debugging this pile of crap - Update README (before starting anything else) Time spent: ca. 6 hours Closes #1
Diffstat (limited to 'build_renfe/build_static_feed.py')
-rw-r--r--build_renfe/build_static_feed.py139
1 files changed, 63 insertions, 76 deletions
diff --git a/build_renfe/build_static_feed.py b/build_renfe/build_static_feed.py
index eb247a9..6c12c1f 100644
--- a/build_renfe/build_static_feed.py
+++ b/build_renfe/build_static_feed.py
@@ -16,6 +16,7 @@ import os
import shutil
import tempfile
import zipfile
+import binascii
import pandas as pd
@@ -181,13 +182,20 @@ if __name__ == "__main__":
help="NAP API Key (https://nap.transportes.gob.es/)"
)
parser.add_argument(
- "--osrm-url",
+ "--osrm-std",
type=str,
- help="OSRM server URL",
+ help="OSRM standard server URL",
default="http://localhost:5050",
required=False,
)
parser.add_argument(
+ "--osrm-narrow",
+ type=str,
+ help="OSRM narrow gauge server URL",
+ default="http://localhost:5051",
+ required=False,
+ )
+ parser.add_argument(
"--debug",
help="Enable debug logging",
action="store_true"
@@ -201,8 +209,9 @@ if __name__ == "__main__":
args = parser.parse_args()
try:
- osrm_check = requests.head(args.osrm_url, timeout=5)
- GENERATE_SHAPES = osrm_check.status_code < 500
+ osrm_check_std = requests.head(args.osrm_std, timeout=5)
+ osrm_check_narrow = requests.head(args.osrm_narrow, timeout=5)
+ GENERATE_SHAPES = osrm_check_std.status_code < 500 and osrm_check_narrow.status_code < 500
except requests.RequestException:
GENERATE_SHAPES = False
logging.warning("OSRM server is not reachable. Shape generation will be skipped.")
@@ -214,6 +223,10 @@ if __name__ == "__main__":
)
for feed in FEEDS.keys():
+ def get_shape_id(trip_id: str) -> str:
+ trip_crc = binascii.crc32(trip_id.encode("utf-8"))
+ return f"Shape_{feed}_{trip_crc}_{trip_crc}"
+
INPUT_GTFS_FD, INPUT_GTFS_ZIP = tempfile.mkstemp(suffix=".zip", prefix=f"renfe_galicia_in_{feed}_")
INPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_in_{feed}_")
OUTPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_out_{feed}_")
@@ -397,7 +410,7 @@ if __name__ == "__main__":
for tig in trips_in_galicia:
if GENERATE_SHAPES:
- tig["shape_id"] = f"Shape_{tig['trip_id'][0:5]}"
+ tig["shape_id"] = get_shape_id(tig["trip_id"])
tig["trip_headsign"] = stops_by_id[last_stop_in_trips[tig["trip_id"]]]["stop_name"]
with open(
os.path.join(OUTPUT_GTFS_PATH, "trips.txt"),
@@ -423,12 +436,19 @@ if __name__ == "__main__":
logging.info("GTFS data for Galicia has been extracted successfully. Generate shapes for the trips...")
if GENERATE_SHAPES:
- shape_ids_total = len(set(f"Shape_{trip_id[0:5]}" for trip_id in trip_ids))
+ shape_ids_total = len(set(trip["shape_id"] for trip in trips_in_galicia))
shape_ids_generated: set[str] = set()
# Pre-load stops for quick lookup
stops_dict = {stop["stop_id"]: stop for stop in stops_in_trips}
+ # Map trip_id to route_type for OSRM profile selection
+ trip_to_route_type = {tig["trip_id"]: tig.get("route_type", "2") for tig in trips_in_galicia}
+ # Fallback if not in trips_in_galicia (shouldn't happen)
+ route_id_to_type = {r["route_id"]: r["route_type"] for r in routes_in_trips}
+ for tig in trips_in_galicia:
+ trip_to_route_type[tig["trip_id"]] = route_id_to_type.get(tig["route_id"], "2")
+
# Group stop times by trip_id to avoid repeated file reads
stop_times_by_trip: dict[str, list[dict]] = {}
for st in stop_times_in_galicia:
@@ -437,11 +457,21 @@ if __name__ == "__main__":
stop_times_by_trip[tid] = []
stop_times_by_trip[tid].append(st)
- OSRM_BASE_URL = f"{args.osrm_url}/route/v1/driving/"
+ shapes_file = open("shapes_debug.txt", "w", encoding="utf-8")
+
for trip_id in tqdm(trip_ids, total=shape_ids_total, desc="Generating shapes"):
- shape_id = f"Shape_{trip_id[0:5]}"
+ shape_id = get_shape_id(trip_id)
if shape_id in shape_ids_generated:
continue
+
+ route_type = trip_to_route_type.get(trip_id, "2")
+ osrm_profile = "driving"
+
+ # If we are on feed cercanias or the 5-digit trip ID starts with 7, it's a narrow gauge train, otherwise standard gauge
+ if feed == "cercanias" or (len(trip_id) >= 5 and trip_id[0] == "7"):
+ OSRM_BASE_URL = f"{args.osrm_narrow}/route/v1/driving/"
+ else:
+ OSRM_BASE_URL = f"{args.osrm_std}/route/v1/driving/"
stop_seq = stop_times_by_trip.get(trip_id, [])
stop_seq.sort(key=lambda x: int(x["stop_sequence"].strip()))
@@ -450,79 +480,36 @@ if __name__ == "__main__":
continue
final_shape_points = []
- i = 0
- while i < len(stop_seq) - 1:
- stop_a = stops_dict[stop_seq[i]["stop_id"]]
- lat_a, lon_a = float(stop_a["stop_lat"]), float(stop_a["stop_lon"])
+ coordinates = []
+ for st in stop_seq:
+ s = stops_dict[st["stop_id"]]
+ coordinates.append(f"{s['stop_lon']},{s['stop_lat']}")
- if not is_in_bounds(lat_a, lon_a):
- # S_i is out of bounds. Segment S_i -> S_{i+1} is straight line.
- stop_b = stops_dict[stop_seq[i+1]["stop_id"]]
- lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"])
+ coords_str = ";".join(coordinates)
+ osrm_url = f"{OSRM_BASE_URL}{coords_str}?overview=full&geometries=geojson&continue_straight=false"
- segment_points = [[lon_a, lat_a], [lon_b, lat_b]]
- if not final_shape_points:
- final_shape_points.extend(segment_points)
- else:
- final_shape_points.extend(segment_points[1:])
- i += 1
- else:
- # S_i is in bounds. Find how many subsequent stops are also in bounds.
- j = i + 1
- while j < len(stop_seq):
- stop_j = stops_dict[stop_seq[j]["stop_id"]]
- if is_in_bounds(float(stop_j["stop_lat"]), float(stop_j["stop_lon"])):
- j += 1
- else:
- break
-
- # Stops from i to j-1 are in bounds.
- if j > i + 1:
- # We have at least two consecutive stops in bounds.
- in_bounds_stops = stop_seq[i:j]
- coordinates = []
- for st in in_bounds_stops:
- s = stops_dict[st["stop_id"]]
- coordinates.append(f"{s['stop_lon']},{s['stop_lat']}")
-
- coords_str = ";".join(coordinates)
- osrm_url = f"{OSRM_BASE_URL}{coords_str}?overview=full&geometries=geojson"
+ shapes_file.write(f"{trip_id} ({shape_id}): {osrm_url}\n")
- segment_points = []
- try:
- response = requests.get(osrm_url, timeout=10)
- if response.status_code == 200:
- data = response.json()
- if data.get("code") == "Ok":
- segment_points = data["routes"][0]["geometry"]["coordinates"]
- except Exception:
- pass
-
- if not segment_points:
- # Fallback to straight lines for this whole sub-sequence
- segment_points = []
- for k in range(i, j):
- s = stops_dict[stop_seq[k]["stop_id"]]
- segment_points.append([float(s["stop_lon"]), float(s["stop_lat"])])
-
- if not final_shape_points:
- final_shape_points.extend(segment_points)
- else:
- final_shape_points.extend(segment_points[1:])
-
- i = j - 1 # Next iteration starts from S_{j-1}
+ try:
+ response = requests.get(osrm_url, timeout=10)
+ if response.status_code == 200:
+ data = response.json()
+ if data.get("code") == "Ok":
+ final_shape_points = data["routes"][0]["geometry"]["coordinates"]
+ logging.debug(f"OSRM success for {shape_id} ({len(coordinates)} stops): {len(final_shape_points)} points")
else:
- # Only S_i is in bounds, S_{i+1} is out.
- # Segment S_i -> S_{i+1} is straight line.
- stop_b = stops_dict[stop_seq[i+1]["stop_id"]]
- lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"])
+ logging.warning(f"OSRM returned error code {data.get('code')} for {shape_id} with {len(coordinates)} stops")
+ else:
+ logging.warning(f"OSRM request failed for {shape_id} with status {response.status_code}")
+ except Exception as e:
+ logging.error(f"OSRM exception for {shape_id}: {str(e)}")
- segment_points = [[lon_a, lat_a], [lon_b, lat_b]]
- if not final_shape_points:
- final_shape_points.extend(segment_points)
- else:
- final_shape_points.extend(segment_points[1:])
- i += 1
+ if not final_shape_points:
+ # Fallback to straight lines
+ logging.info(f"Using straight-line fallback for {shape_id}")
+ for st in stop_seq:
+ s = stops_dict[st["stop_id"]]
+ final_shape_points.append([float(s["stop_lon"]), float(s["stop_lat"])])
shape_ids_generated.add(shape_id)