From 84db1ca075dc63ccb02da825948d95ad09f94e4d Mon Sep 17 00:00:00 2001 From: Ariel Costas Guerrero Date: Thu, 26 Mar 2026 09:53:39 +0100 Subject: Convert submodules to regular repo files, add custom feeds --- build_renfe/build_static_feed.py | 564 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 564 insertions(+) create mode 100644 build_renfe/build_static_feed.py (limited to 'build_renfe/build_static_feed.py') diff --git a/build_renfe/build_static_feed.py b/build_renfe/build_static_feed.py new file mode 100644 index 0000000..a60360f --- /dev/null +++ b/build_renfe/build_static_feed.py @@ -0,0 +1,564 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "requests", +# "tqdm", +# ] +# /// + +from argparse import ArgumentParser +import csv +import json +import logging +import os +import shutil +import tempfile +import zipfile + +import requests +from tqdm import tqdm + + +# Approximate bounding box for Galicia +BOUNDS = {"SOUTH": 41.820455, "NORTH": 43.937462, "WEST": -9.437256, "EAST": -6.767578} + +FEEDS = { + "general": "1098", + "cercanias": "1130", + "feve": "1131" +} + + +def is_in_bounds(lat: float, lon: float) -> bool: + return ( + BOUNDS["SOUTH"] <= lat <= BOUNDS["NORTH"] + and BOUNDS["WEST"] <= lon <= BOUNDS["EAST"] + ) + + +def get_stops_in_bounds(stops_file: str): + with open(stops_file, "r", encoding="utf-8") as f: + stops = csv.DictReader(f) + + for stop in stops: + lat = float(stop["stop_lat"]) + lon = float(stop["stop_lon"]) + if is_in_bounds(lat, lon): + yield stop + + +def get_trip_ids_for_stops(stoptimes_file: str, stop_ids: list[str]) -> list[str]: + trip_ids: set[str] = set() + + with open(stoptimes_file, "r", encoding="utf-8") as f: + stop_times = csv.DictReader(f) + + for stop_time in stop_times: + if stop_time["stop_id"] in stop_ids: + trip_ids.add(stop_time["trip_id"]) + + return list(trip_ids) + + +def get_routes_for_trips(trips_file: str, trip_ids: list[str]) -> list[str]: + route_ids: set[str] = set() + + with open(trips_file, "r", encoding="utf-8") as f: + trips = csv.DictReader(f) + + for trip in trips: + if trip["trip_id"] in trip_ids: + route_ids.add(trip["route_id"]) + + return list(route_ids) + + +def get_distinct_stops_from_stop_times( + stoptimes_file: str, trip_ids: list[str] +) -> list[str]: + stop_ids: set[str] = set() + + with open(stoptimes_file, "r", encoding="utf-8") as f: + stop_times = csv.DictReader(f) + + for stop_time in stop_times: + if stop_time["trip_id"] in trip_ids: + stop_ids.add(stop_time["stop_id"]) + + return list(stop_ids) + + +def get_last_stop_for_trips( + stoptimes_file: str, trip_ids: list[str] +) -> dict[str, str]: + trip_last: dict[str, str] = {} + trip_last_seq: dict[str, int] = {} + + with open(stoptimes_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + if reader.fieldnames is None: + raise Exception("Fuck you, screw you, fieldnames is None and you just get rekt") + reader.fieldnames = [name.strip() for name in reader.fieldnames] + + for stop_time in reader: + if stop_time["trip_id"] in trip_ids: + trip_id = stop_time["trip_id"] + if trip_last.get(trip_id, None) is None: + trip_last[trip_id] = "" + trip_last_seq[trip_id] = -1 + + this_stop_seq = int(stop_time["stop_sequence"]) + if this_stop_seq > trip_last_seq[trip_id]: + trip_last_seq[trip_id] = this_stop_seq + trip_last[trip_id] = stop_time["stop_id"] + + return trip_last + +def get_rows_by_ids(input_file: str, id_field: str, ids: list[str]) -> list[dict]: + rows: list[dict] = [] + + with open(input_file, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + if reader.fieldnames is None: + raise Exception("Fuck you, screw you, fieldnames is None and you just get rekt") + reader.fieldnames = [name.strip() for name in reader.fieldnames] + + for row in reader: + if row[id_field].strip() in ids: + rows.append(row) + + return rows + +# First colour is background, second is text +SERVICE_COLOURS = { + "REGIONAL": ("9A0060", "FFFFFF"), + "REG.EXP.": ("9A0060", "FFFFFF"), + + "MD": ("F85B0B", "000000"), + "AVANT": ("F85B0B", "000000"), + + "AVLO": ("05CEC6", "000000"), + "AVE": ("FFFFFF", "9A0060"), + "ALVIA": ("FFFFFF", "9A0060"), + + "INTERCITY": ("606060", "FFFFFF"), + + "TRENCELTA": ("00824A", "FFFFFF"), + + # Cercanías Ferrol-Ortigueira + "C1": ("F5333F", "FFFFFF") +} + + +def colour_route(route_short_name: str) -> tuple[str, str]: + """ + Returns the colours to be used for a route from its short name. + + :param route_short_name: The routes.txt's route_short_name + :return: A tuple containing the "route_color" (background) first and "route_text_color" (text) second + :rtype: tuple[str, str] + """ + + route_name_searched = route_short_name.strip().upper() + + if route_name_searched in SERVICE_COLOURS: + return SERVICE_COLOURS[route_name_searched] + + print("Unknown route short name:", route_short_name) + return ("000000", "FFFFFF") + + +if __name__ == "__main__": + parser = ArgumentParser( + description="Extract GTFS data for Galicia from Renfe GTFS feed." + ) + parser.add_argument( + "nap_apikey", + type=str, + help="NAP API Key (https://nap.transportes.gob.es/)" + ) + parser.add_argument( + "--osrm-url", + type=str, + help="OSRM server URL", + default="http://localhost:5050", + required=False, + ) + parser.add_argument( + "--debug", + help="Enable debug logging", + action="store_true" + ) + + args = parser.parse_args() + + try: + osrm_check = requests.head(args.osrm_url, timeout=5) + GENERATE_SHAPES = osrm_check.status_code < 500 + except requests.RequestException: + GENERATE_SHAPES = False + logging.warning("OSRM server is not reachable. Shape generation will be skipped.") + + + logging.basicConfig( + level=logging.DEBUG if args.debug else logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + ) + + for feed in FEEDS.keys(): + INPUT_GTFS_FD, INPUT_GTFS_ZIP = tempfile.mkstemp(suffix=".zip", prefix=f"renfe_galicia_in_{feed}_") + INPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_in_{feed}_") + OUTPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_out_{feed}_") + OUTPUT_GTFS_ZIP = os.path.join(os.path.dirname(__file__), f"gtfs_renfe_galicia_{feed}.zip") + + FEED_URL = f"https://nap.transportes.gob.es/api/Fichero/download/{FEEDS[feed]}" + + logging.info(f"Downloading GTFS feed '{feed}'...") + response = requests.get(FEED_URL, headers={"ApiKey": args.nap_apikey}) + with open(INPUT_GTFS_ZIP, "wb") as f: + f.write(response.content) + + # Unzip the GTFS feed + with zipfile.ZipFile(INPUT_GTFS_ZIP, "r") as zip_ref: + zip_ref.extractall(INPUT_GTFS_PATH) + + STOPS_FILE = os.path.join(INPUT_GTFS_PATH, "stops.txt") + STOP_TIMES_FILE = os.path.join(INPUT_GTFS_PATH, "stop_times.txt") + TRIPS_FILE = os.path.join(INPUT_GTFS_PATH, "trips.txt") + + all_stops_applicable = [stop for stop in get_stops_in_bounds(STOPS_FILE)] + logging.info(f"Total stops in Galicia: {len(all_stops_applicable)}") + + stop_ids = [stop["stop_id"] for stop in all_stops_applicable] + trip_ids = get_trip_ids_for_stops(STOP_TIMES_FILE, stop_ids) + + route_ids = get_routes_for_trips(TRIPS_FILE, trip_ids) + + logging.info(f"Feed parsed successfully. Stops: {len(stop_ids)}, trips: {len(trip_ids)}, routes: {len(route_ids)}") + if len(trip_ids) == 0 or len(route_ids) == 0: + logging.warning(f"No trips or routes found for feed '{feed}'. Skipping...") + shutil.rmtree(INPUT_GTFS_PATH) + shutil.rmtree(OUTPUT_GTFS_PATH) + continue + + # Copy agency.txt, calendar.txt, calendar_dates.txt as is + for filename in ["agency.txt", "calendar.txt", "calendar_dates.txt"]: + src_path = os.path.join(INPUT_GTFS_PATH, filename) + dest_path = os.path.join(OUTPUT_GTFS_PATH, filename) + if os.path.exists(src_path): + shutil.copy(src_path, dest_path) + else: + logging.debug(f"File {filename} does not exist in the input GTFS feed.") + + # Write new stops.txt with the stops in any trip that passes through Galicia + with open( + os.path.join(os.path.dirname(__file__), "stop_overrides.json"), + "r", + encoding="utf-8", + ) as f: + stop_overrides_raw: list = json.load(f) + stop_overrides = { + item["stop_id"]: item + for item in stop_overrides_raw + } + logging.debug(f"Loaded stop overrides for {len(stop_overrides)} stops.") + + deleted_stop_ids: set[str] = set() + for stop_id, override_item in stop_overrides.items(): + if override_item.get("_delete", False): + if override_item.get("feed_id", None) is None or override_item["feed_id"] == feed: + deleted_stop_ids.add(stop_id) + logging.debug(f"Stops marked for deletion in feed '{feed}': {len(deleted_stop_ids)}") + + distinct_stop_ids = get_distinct_stops_from_stop_times( + STOP_TIMES_FILE, trip_ids + ) + stops_in_trips = get_rows_by_ids(STOPS_FILE, "stop_id", distinct_stop_ids) + for stop in stops_in_trips: + stop["stop_code"] = stop["stop_id"] + if stop_overrides.get(stop["stop_id"], None) is not None: + override_item = stop_overrides[stop["stop_id"]] + + if override_item.get("feed_id", None) is not None and override_item["feed_id"] != feed: + continue + + for key, value in override_item.items(): + if key in ("stop_id", "feed_id", "_delete"): + continue + stop[key] = value + + if stop["stop_name"].startswith("Estación de tren "): + stop["stop_name"] = stop["stop_name"][17:].strip() + stop["stop_name"] = " ".join([ + word.capitalize() for word in stop["stop_name"].split(" ") if word != "de" + ]) + + stops_in_trips = [stop for stop in stops_in_trips if stop["stop_id"] not in deleted_stop_ids] + + with open( + os.path.join(OUTPUT_GTFS_PATH, "stops.txt"), + "w", + encoding="utf-8", + newline="", + ) as f: + writer = csv.DictWriter(f, fieldnames=stops_in_trips[0].keys()) + writer.writeheader() + writer.writerows(stops_in_trips) + + # Write new routes.txt with the routes that have trips in Galicia + routes_in_trips = get_rows_by_ids( + os.path.join(INPUT_GTFS_PATH, "routes.txt"), "route_id", route_ids + ) + + if feed == "feve": + feve_c1_route_ids = ["46T0001C1", "46T0002C1"] + new_route_id = "FEVE_C1" + + # Find agency_id and a template route + template_route = routes_in_trips[0] if routes_in_trips else {} + agency_id = "1" + for r in routes_in_trips: + if r["route_id"].strip() in feve_c1_route_ids: + agency_id = r.get("agency_id", "1") + template_route = r + break + + # Filter out old routes + routes_in_trips = [r for r in routes_in_trips if r["route_id"].strip() not in feve_c1_route_ids] + + # Add new route + new_route = template_route.copy() + new_route.update({ + "route_id": new_route_id, + "route_short_name": "C1", + "route_long_name": "Ferrol - Xuvia - San Sadurniño - Ortigueira", + "route_type": "2", + }) + if "agency_id" in template_route: + new_route["agency_id"] = agency_id + + routes_in_trips.append(new_route) + + for route in routes_in_trips: + route["route_color"], route["route_text_color"] = colour_route( + route["route_short_name"] + ) + with open( + os.path.join(OUTPUT_GTFS_PATH, "routes.txt"), + "w", + encoding="utf-8", + newline="", + ) as f: + writer = csv.DictWriter(f, fieldnames=routes_in_trips[0].keys()) + writer.writeheader() + writer.writerows(routes_in_trips) + + # Write new trips.txt with the trips that pass through Galicia + # Load stop_times early so we can filter deleted stops and renumber sequences + stop_times_in_galicia = get_rows_by_ids(STOP_TIMES_FILE, "trip_id", trip_ids) + stop_times_in_galicia = [st for st in stop_times_in_galicia if st["stop_id"].strip() not in deleted_stop_ids] + stop_times_in_galicia.sort(key=lambda x: (x["trip_id"], int(x["stop_sequence"].strip()))) + trip_seq_counter: dict[str, int] = {} + for st in stop_times_in_galicia: + tid = st["trip_id"] + if tid not in trip_seq_counter: + trip_seq_counter[tid] = 0 + st["stop_sequence"] = str(trip_seq_counter[tid]) + trip_seq_counter[tid] += 1 + + last_stop_in_trips: dict[str, str] = {} + trip_last_seq: dict[str, int] = {} + for st in stop_times_in_galicia: + tid = st["trip_id"] + seq = int(st["stop_sequence"]) + if seq > trip_last_seq.get(tid, -1): + trip_last_seq[tid] = seq + last_stop_in_trips[tid] = st["stop_id"].strip() + + trips_in_galicia = get_rows_by_ids(TRIPS_FILE, "trip_id", trip_ids) + + if feed == "feve": + feve_c1_route_ids = ["46T0001C1", "46T0002C1"] + new_route_id = "FEVE_C1" + for tig in trips_in_galicia: + if tig["route_id"].strip() in feve_c1_route_ids: + tig["route_id"] = new_route_id + tig["direction_id"] = "1" if tig["route_id"].strip()[6] == "2" else "0" + + stops_by_id = {stop["stop_id"]: stop for stop in stops_in_trips} + + for tig in trips_in_galicia: + if GENERATE_SHAPES: + tig["shape_id"] = f"Shape_{tig['trip_id'][0:5]}" + tig["trip_headsign"] = stops_by_id[last_stop_in_trips[tig["trip_id"]]]["stop_name"] + with open( + os.path.join(OUTPUT_GTFS_PATH, "trips.txt"), + "w", + encoding="utf-8", + newline="", + ) as f: + writer = csv.DictWriter(f, fieldnames=trips_in_galicia[0].keys()) + writer.writeheader() + writer.writerows(trips_in_galicia) + + # Write new stop_times.txt with the stop times for any trip that passes through Galicia + with open( + os.path.join(OUTPUT_GTFS_PATH, "stop_times.txt"), + "w", + encoding="utf-8", + newline="", + ) as f: + writer = csv.DictWriter(f, fieldnames=stop_times_in_galicia[0].keys()) + writer.writeheader() + writer.writerows(stop_times_in_galicia) + + logging.info("GTFS data for Galicia has been extracted successfully. Generate shapes for the trips...") + + if GENERATE_SHAPES: + shape_ids_total = len(set(f"Shape_{trip_id[0:5]}" for trip_id in trip_ids)) + shape_ids_generated: set[str] = set() + + # Pre-load stops for quick lookup + stops_dict = {stop["stop_id"]: stop for stop in stops_in_trips} + + # Group stop times by trip_id to avoid repeated file reads + stop_times_by_trip: dict[str, list[dict]] = {} + for st in stop_times_in_galicia: + tid = st["trip_id"] + if tid not in stop_times_by_trip: + stop_times_by_trip[tid] = [] + stop_times_by_trip[tid].append(st) + + OSRM_BASE_URL = f"{args.osrm_url}/route/v1/driving/" + for trip_id in tqdm(trip_ids, total=shape_ids_total, desc="Generating shapes"): + shape_id = f"Shape_{trip_id[0:5]}" + if shape_id in shape_ids_generated: + continue + + stop_seq = stop_times_by_trip.get(trip_id, []) + stop_seq.sort(key=lambda x: int(x["stop_sequence"].strip())) + + if not stop_seq: + continue + + final_shape_points = [] + i = 0 + while i < len(stop_seq) - 1: + stop_a = stops_dict[stop_seq[i]["stop_id"]] + lat_a, lon_a = float(stop_a["stop_lat"]), float(stop_a["stop_lon"]) + + if not is_in_bounds(lat_a, lon_a): + # S_i is out of bounds. Segment S_i -> S_{i+1} is straight line. + stop_b = stops_dict[stop_seq[i+1]["stop_id"]] + lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"]) + + segment_points = [[lon_a, lat_a], [lon_b, lat_b]] + if not final_shape_points: + final_shape_points.extend(segment_points) + else: + final_shape_points.extend(segment_points[1:]) + i += 1 + else: + # S_i is in bounds. Find how many subsequent stops are also in bounds. + j = i + 1 + while j < len(stop_seq): + stop_j = stops_dict[stop_seq[j]["stop_id"]] + if is_in_bounds(float(stop_j["stop_lat"]), float(stop_j["stop_lon"])): + j += 1 + else: + break + + # Stops from i to j-1 are in bounds. + if j > i + 1: + # We have at least two consecutive stops in bounds. + in_bounds_stops = stop_seq[i:j] + coordinates = [] + for st in in_bounds_stops: + s = stops_dict[st["stop_id"]] + coordinates.append(f"{s['stop_lon']},{s['stop_lat']}") + + coords_str = ";".join(coordinates) + osrm_url = f"{OSRM_BASE_URL}{coords_str}?overview=full&geometries=geojson" + + segment_points = [] + try: + response = requests.get(osrm_url, timeout=10) + if response.status_code == 200: + data = response.json() + if data.get("code") == "Ok": + segment_points = data["routes"][0]["geometry"]["coordinates"] + except Exception: + pass + + if not segment_points: + # Fallback to straight lines for this whole sub-sequence + segment_points = [] + for k in range(i, j): + s = stops_dict[stop_seq[k]["stop_id"]] + segment_points.append([float(s["stop_lon"]), float(s["stop_lat"])]) + + if not final_shape_points: + final_shape_points.extend(segment_points) + else: + final_shape_points.extend(segment_points[1:]) + + i = j - 1 # Next iteration starts from S_{j-1} + else: + # Only S_i is in bounds, S_{i+1} is out. + # Segment S_i -> S_{i+1} is straight line. + stop_b = stops_dict[stop_seq[i+1]["stop_id"]] + lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"]) + + segment_points = [[lon_a, lat_a], [lon_b, lat_b]] + if not final_shape_points: + final_shape_points.extend(segment_points) + else: + final_shape_points.extend(segment_points[1:]) + i += 1 + + shape_ids_generated.add(shape_id) + + with open( + os.path.join(OUTPUT_GTFS_PATH, "shapes.txt"), + "a", + encoding="utf-8", + newline="", + ) as f: + fieldnames = [ + "shape_id", + "shape_pt_lat", + "shape_pt_lon", + "shape_pt_sequence", + ] + writer = csv.DictWriter(f, fieldnames=fieldnames) + + if f.tell() == 0: + writer.writeheader() + + for seq, point in enumerate(final_shape_points): + writer.writerow( + { + "shape_id": shape_id, + "shape_pt_lat": point[1], + "shape_pt_lon": point[0], + "shape_pt_sequence": seq, + } + ) + else: + logging.info("Shape generation skipped as per user request.") + + # Create a ZIP archive of the output GTFS + with zipfile.ZipFile(OUTPUT_GTFS_ZIP, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in os.walk(OUTPUT_GTFS_PATH): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, OUTPUT_GTFS_PATH) + zipf.write(file_path, arcname) + + logging.info( + f"GTFS data from feed {feed} has been zipped successfully at {OUTPUT_GTFS_ZIP}." + ) + os.close(INPUT_GTFS_FD) + os.remove(INPUT_GTFS_ZIP) + shutil.rmtree(INPUT_GTFS_PATH) + shutil.rmtree(OUTPUT_GTFS_PATH) -- cgit v1.3