aboutsummaryrefslogtreecommitdiff
path: root/build_renfe/build_static_feed.py
diff options
context:
space:
mode:
Diffstat (limited to 'build_renfe/build_static_feed.py')
-rw-r--r--build_renfe/build_static_feed.py564
1 files changed, 564 insertions, 0 deletions
diff --git a/build_renfe/build_static_feed.py b/build_renfe/build_static_feed.py
new file mode 100644
index 0000000..a60360f
--- /dev/null
+++ b/build_renfe/build_static_feed.py
@@ -0,0 +1,564 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+# "requests",
+# "tqdm",
+# ]
+# ///
+
+from argparse import ArgumentParser
+import csv
+import json
+import logging
+import os
+import shutil
+import tempfile
+import zipfile
+
+import requests
+from tqdm import tqdm
+
+
+# Approximate bounding box for Galicia
+BOUNDS = {"SOUTH": 41.820455, "NORTH": 43.937462, "WEST": -9.437256, "EAST": -6.767578}
+
+FEEDS = {
+ "general": "1098",
+ "cercanias": "1130",
+ "feve": "1131"
+}
+
+
+def is_in_bounds(lat: float, lon: float) -> bool:
+ return (
+ BOUNDS["SOUTH"] <= lat <= BOUNDS["NORTH"]
+ and BOUNDS["WEST"] <= lon <= BOUNDS["EAST"]
+ )
+
+
+def get_stops_in_bounds(stops_file: str):
+ with open(stops_file, "r", encoding="utf-8") as f:
+ stops = csv.DictReader(f)
+
+ for stop in stops:
+ lat = float(stop["stop_lat"])
+ lon = float(stop["stop_lon"])
+ if is_in_bounds(lat, lon):
+ yield stop
+
+
+def get_trip_ids_for_stops(stoptimes_file: str, stop_ids: list[str]) -> list[str]:
+ trip_ids: set[str] = set()
+
+ with open(stoptimes_file, "r", encoding="utf-8") as f:
+ stop_times = csv.DictReader(f)
+
+ for stop_time in stop_times:
+ if stop_time["stop_id"] in stop_ids:
+ trip_ids.add(stop_time["trip_id"])
+
+ return list(trip_ids)
+
+
+def get_routes_for_trips(trips_file: str, trip_ids: list[str]) -> list[str]:
+ route_ids: set[str] = set()
+
+ with open(trips_file, "r", encoding="utf-8") as f:
+ trips = csv.DictReader(f)
+
+ for trip in trips:
+ if trip["trip_id"] in trip_ids:
+ route_ids.add(trip["route_id"])
+
+ return list(route_ids)
+
+
+def get_distinct_stops_from_stop_times(
+ stoptimes_file: str, trip_ids: list[str]
+) -> list[str]:
+ stop_ids: set[str] = set()
+
+ with open(stoptimes_file, "r", encoding="utf-8") as f:
+ stop_times = csv.DictReader(f)
+
+ for stop_time in stop_times:
+ if stop_time["trip_id"] in trip_ids:
+ stop_ids.add(stop_time["stop_id"])
+
+ return list(stop_ids)
+
+
+def get_last_stop_for_trips(
+ stoptimes_file: str, trip_ids: list[str]
+) -> dict[str, str]:
+ trip_last: dict[str, str] = {}
+ trip_last_seq: dict[str, int] = {}
+
+ with open(stoptimes_file, "r", encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ if reader.fieldnames is None:
+ raise Exception("Fuck you, screw you, fieldnames is None and you just get rekt")
+ reader.fieldnames = [name.strip() for name in reader.fieldnames]
+
+ for stop_time in reader:
+ if stop_time["trip_id"] in trip_ids:
+ trip_id = stop_time["trip_id"]
+ if trip_last.get(trip_id, None) is None:
+ trip_last[trip_id] = ""
+ trip_last_seq[trip_id] = -1
+
+ this_stop_seq = int(stop_time["stop_sequence"])
+ if this_stop_seq > trip_last_seq[trip_id]:
+ trip_last_seq[trip_id] = this_stop_seq
+ trip_last[trip_id] = stop_time["stop_id"]
+
+ return trip_last
+
+def get_rows_by_ids(input_file: str, id_field: str, ids: list[str]) -> list[dict]:
+ rows: list[dict] = []
+
+ with open(input_file, "r", encoding="utf-8") as f:
+ reader = csv.DictReader(f)
+ if reader.fieldnames is None:
+ raise Exception("Fuck you, screw you, fieldnames is None and you just get rekt")
+ reader.fieldnames = [name.strip() for name in reader.fieldnames]
+
+ for row in reader:
+ if row[id_field].strip() in ids:
+ rows.append(row)
+
+ return rows
+
+# First colour is background, second is text
+SERVICE_COLOURS = {
+ "REGIONAL": ("9A0060", "FFFFFF"),
+ "REG.EXP.": ("9A0060", "FFFFFF"),
+
+ "MD": ("F85B0B", "000000"),
+ "AVANT": ("F85B0B", "000000"),
+
+ "AVLO": ("05CEC6", "000000"),
+ "AVE": ("FFFFFF", "9A0060"),
+ "ALVIA": ("FFFFFF", "9A0060"),
+
+ "INTERCITY": ("606060", "FFFFFF"),
+
+ "TRENCELTA": ("00824A", "FFFFFF"),
+
+ # Cercanías Ferrol-Ortigueira
+ "C1": ("F5333F", "FFFFFF")
+}
+
+
+def colour_route(route_short_name: str) -> tuple[str, str]:
+ """
+ Returns the colours to be used for a route from its short name.
+
+ :param route_short_name: The routes.txt's route_short_name
+ :return: A tuple containing the "route_color" (background) first and "route_text_color" (text) second
+ :rtype: tuple[str, str]
+ """
+
+ route_name_searched = route_short_name.strip().upper()
+
+ if route_name_searched in SERVICE_COLOURS:
+ return SERVICE_COLOURS[route_name_searched]
+
+ print("Unknown route short name:", route_short_name)
+ return ("000000", "FFFFFF")
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(
+ description="Extract GTFS data for Galicia from Renfe GTFS feed."
+ )
+ parser.add_argument(
+ "nap_apikey",
+ type=str,
+ help="NAP API Key (https://nap.transportes.gob.es/)"
+ )
+ parser.add_argument(
+ "--osrm-url",
+ type=str,
+ help="OSRM server URL",
+ default="http://localhost:5050",
+ required=False,
+ )
+ parser.add_argument(
+ "--debug",
+ help="Enable debug logging",
+ action="store_true"
+ )
+
+ args = parser.parse_args()
+
+ try:
+ osrm_check = requests.head(args.osrm_url, timeout=5)
+ GENERATE_SHAPES = osrm_check.status_code < 500
+ except requests.RequestException:
+ GENERATE_SHAPES = False
+ logging.warning("OSRM server is not reachable. Shape generation will be skipped.")
+
+
+ logging.basicConfig(
+ level=logging.DEBUG if args.debug else logging.INFO,
+ format="%(asctime)s - %(levelname)s - %(message)s",
+ )
+
+ for feed in FEEDS.keys():
+ INPUT_GTFS_FD, INPUT_GTFS_ZIP = tempfile.mkstemp(suffix=".zip", prefix=f"renfe_galicia_in_{feed}_")
+ INPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_in_{feed}_")
+ OUTPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_out_{feed}_")
+ OUTPUT_GTFS_ZIP = os.path.join(os.path.dirname(__file__), f"gtfs_renfe_galicia_{feed}.zip")
+
+ FEED_URL = f"https://nap.transportes.gob.es/api/Fichero/download/{FEEDS[feed]}"
+
+ logging.info(f"Downloading GTFS feed '{feed}'...")
+ response = requests.get(FEED_URL, headers={"ApiKey": args.nap_apikey})
+ with open(INPUT_GTFS_ZIP, "wb") as f:
+ f.write(response.content)
+
+ # Unzip the GTFS feed
+ with zipfile.ZipFile(INPUT_GTFS_ZIP, "r") as zip_ref:
+ zip_ref.extractall(INPUT_GTFS_PATH)
+
+ STOPS_FILE = os.path.join(INPUT_GTFS_PATH, "stops.txt")
+ STOP_TIMES_FILE = os.path.join(INPUT_GTFS_PATH, "stop_times.txt")
+ TRIPS_FILE = os.path.join(INPUT_GTFS_PATH, "trips.txt")
+
+ all_stops_applicable = [stop for stop in get_stops_in_bounds(STOPS_FILE)]
+ logging.info(f"Total stops in Galicia: {len(all_stops_applicable)}")
+
+ stop_ids = [stop["stop_id"] for stop in all_stops_applicable]
+ trip_ids = get_trip_ids_for_stops(STOP_TIMES_FILE, stop_ids)
+
+ route_ids = get_routes_for_trips(TRIPS_FILE, trip_ids)
+
+ logging.info(f"Feed parsed successfully. Stops: {len(stop_ids)}, trips: {len(trip_ids)}, routes: {len(route_ids)}")
+ if len(trip_ids) == 0 or len(route_ids) == 0:
+ logging.warning(f"No trips or routes found for feed '{feed}'. Skipping...")
+ shutil.rmtree(INPUT_GTFS_PATH)
+ shutil.rmtree(OUTPUT_GTFS_PATH)
+ continue
+
+ # Copy agency.txt, calendar.txt, calendar_dates.txt as is
+ for filename in ["agency.txt", "calendar.txt", "calendar_dates.txt"]:
+ src_path = os.path.join(INPUT_GTFS_PATH, filename)
+ dest_path = os.path.join(OUTPUT_GTFS_PATH, filename)
+ if os.path.exists(src_path):
+ shutil.copy(src_path, dest_path)
+ else:
+ logging.debug(f"File {filename} does not exist in the input GTFS feed.")
+
+ # Write new stops.txt with the stops in any trip that passes through Galicia
+ with open(
+ os.path.join(os.path.dirname(__file__), "stop_overrides.json"),
+ "r",
+ encoding="utf-8",
+ ) as f:
+ stop_overrides_raw: list = json.load(f)
+ stop_overrides = {
+ item["stop_id"]: item
+ for item in stop_overrides_raw
+ }
+ logging.debug(f"Loaded stop overrides for {len(stop_overrides)} stops.")
+
+ deleted_stop_ids: set[str] = set()
+ for stop_id, override_item in stop_overrides.items():
+ if override_item.get("_delete", False):
+ if override_item.get("feed_id", None) is None or override_item["feed_id"] == feed:
+ deleted_stop_ids.add(stop_id)
+ logging.debug(f"Stops marked for deletion in feed '{feed}': {len(deleted_stop_ids)}")
+
+ distinct_stop_ids = get_distinct_stops_from_stop_times(
+ STOP_TIMES_FILE, trip_ids
+ )
+ stops_in_trips = get_rows_by_ids(STOPS_FILE, "stop_id", distinct_stop_ids)
+ for stop in stops_in_trips:
+ stop["stop_code"] = stop["stop_id"]
+ if stop_overrides.get(stop["stop_id"], None) is not None:
+ override_item = stop_overrides[stop["stop_id"]]
+
+ if override_item.get("feed_id", None) is not None and override_item["feed_id"] != feed:
+ continue
+
+ for key, value in override_item.items():
+ if key in ("stop_id", "feed_id", "_delete"):
+ continue
+ stop[key] = value
+
+ if stop["stop_name"].startswith("Estación de tren "):
+ stop["stop_name"] = stop["stop_name"][17:].strip()
+ stop["stop_name"] = " ".join([
+ word.capitalize() for word in stop["stop_name"].split(" ") if word != "de"
+ ])
+
+ stops_in_trips = [stop for stop in stops_in_trips if stop["stop_id"] not in deleted_stop_ids]
+
+ with open(
+ os.path.join(OUTPUT_GTFS_PATH, "stops.txt"),
+ "w",
+ encoding="utf-8",
+ newline="",
+ ) as f:
+ writer = csv.DictWriter(f, fieldnames=stops_in_trips[0].keys())
+ writer.writeheader()
+ writer.writerows(stops_in_trips)
+
+ # Write new routes.txt with the routes that have trips in Galicia
+ routes_in_trips = get_rows_by_ids(
+ os.path.join(INPUT_GTFS_PATH, "routes.txt"), "route_id", route_ids
+ )
+
+ if feed == "feve":
+ feve_c1_route_ids = ["46T0001C1", "46T0002C1"]
+ new_route_id = "FEVE_C1"
+
+ # Find agency_id and a template route
+ template_route = routes_in_trips[0] if routes_in_trips else {}
+ agency_id = "1"
+ for r in routes_in_trips:
+ if r["route_id"].strip() in feve_c1_route_ids:
+ agency_id = r.get("agency_id", "1")
+ template_route = r
+ break
+
+ # Filter out old routes
+ routes_in_trips = [r for r in routes_in_trips if r["route_id"].strip() not in feve_c1_route_ids]
+
+ # Add new route
+ new_route = template_route.copy()
+ new_route.update({
+ "route_id": new_route_id,
+ "route_short_name": "C1",
+ "route_long_name": "Ferrol - Xuvia - San Sadurniño - Ortigueira",
+ "route_type": "2",
+ })
+ if "agency_id" in template_route:
+ new_route["agency_id"] = agency_id
+
+ routes_in_trips.append(new_route)
+
+ for route in routes_in_trips:
+ route["route_color"], route["route_text_color"] = colour_route(
+ route["route_short_name"]
+ )
+ with open(
+ os.path.join(OUTPUT_GTFS_PATH, "routes.txt"),
+ "w",
+ encoding="utf-8",
+ newline="",
+ ) as f:
+ writer = csv.DictWriter(f, fieldnames=routes_in_trips[0].keys())
+ writer.writeheader()
+ writer.writerows(routes_in_trips)
+
+ # Write new trips.txt with the trips that pass through Galicia
+ # Load stop_times early so we can filter deleted stops and renumber sequences
+ stop_times_in_galicia = get_rows_by_ids(STOP_TIMES_FILE, "trip_id", trip_ids)
+ stop_times_in_galicia = [st for st in stop_times_in_galicia if st["stop_id"].strip() not in deleted_stop_ids]
+ stop_times_in_galicia.sort(key=lambda x: (x["trip_id"], int(x["stop_sequence"].strip())))
+ trip_seq_counter: dict[str, int] = {}
+ for st in stop_times_in_galicia:
+ tid = st["trip_id"]
+ if tid not in trip_seq_counter:
+ trip_seq_counter[tid] = 0
+ st["stop_sequence"] = str(trip_seq_counter[tid])
+ trip_seq_counter[tid] += 1
+
+ last_stop_in_trips: dict[str, str] = {}
+ trip_last_seq: dict[str, int] = {}
+ for st in stop_times_in_galicia:
+ tid = st["trip_id"]
+ seq = int(st["stop_sequence"])
+ if seq > trip_last_seq.get(tid, -1):
+ trip_last_seq[tid] = seq
+ last_stop_in_trips[tid] = st["stop_id"].strip()
+
+ trips_in_galicia = get_rows_by_ids(TRIPS_FILE, "trip_id", trip_ids)
+
+ if feed == "feve":
+ feve_c1_route_ids = ["46T0001C1", "46T0002C1"]
+ new_route_id = "FEVE_C1"
+ for tig in trips_in_galicia:
+ if tig["route_id"].strip() in feve_c1_route_ids:
+ tig["route_id"] = new_route_id
+ tig["direction_id"] = "1" if tig["route_id"].strip()[6] == "2" else "0"
+
+ stops_by_id = {stop["stop_id"]: stop for stop in stops_in_trips}
+
+ for tig in trips_in_galicia:
+ if GENERATE_SHAPES:
+ tig["shape_id"] = f"Shape_{tig['trip_id'][0:5]}"
+ tig["trip_headsign"] = stops_by_id[last_stop_in_trips[tig["trip_id"]]]["stop_name"]
+ with open(
+ os.path.join(OUTPUT_GTFS_PATH, "trips.txt"),
+ "w",
+ encoding="utf-8",
+ newline="",
+ ) as f:
+ writer = csv.DictWriter(f, fieldnames=trips_in_galicia[0].keys())
+ writer.writeheader()
+ writer.writerows(trips_in_galicia)
+
+ # Write new stop_times.txt with the stop times for any trip that passes through Galicia
+ with open(
+ os.path.join(OUTPUT_GTFS_PATH, "stop_times.txt"),
+ "w",
+ encoding="utf-8",
+ newline="",
+ ) as f:
+ writer = csv.DictWriter(f, fieldnames=stop_times_in_galicia[0].keys())
+ writer.writeheader()
+ writer.writerows(stop_times_in_galicia)
+
+ logging.info("GTFS data for Galicia has been extracted successfully. Generate shapes for the trips...")
+
+ if GENERATE_SHAPES:
+ shape_ids_total = len(set(f"Shape_{trip_id[0:5]}" for trip_id in trip_ids))
+ shape_ids_generated: set[str] = set()
+
+ # Pre-load stops for quick lookup
+ stops_dict = {stop["stop_id"]: stop for stop in stops_in_trips}
+
+ # Group stop times by trip_id to avoid repeated file reads
+ stop_times_by_trip: dict[str, list[dict]] = {}
+ for st in stop_times_in_galicia:
+ tid = st["trip_id"]
+ if tid not in stop_times_by_trip:
+ stop_times_by_trip[tid] = []
+ stop_times_by_trip[tid].append(st)
+
+ OSRM_BASE_URL = f"{args.osrm_url}/route/v1/driving/"
+ for trip_id in tqdm(trip_ids, total=shape_ids_total, desc="Generating shapes"):
+ shape_id = f"Shape_{trip_id[0:5]}"
+ if shape_id in shape_ids_generated:
+ continue
+
+ stop_seq = stop_times_by_trip.get(trip_id, [])
+ stop_seq.sort(key=lambda x: int(x["stop_sequence"].strip()))
+
+ if not stop_seq:
+ continue
+
+ final_shape_points = []
+ i = 0
+ while i < len(stop_seq) - 1:
+ stop_a = stops_dict[stop_seq[i]["stop_id"]]
+ lat_a, lon_a = float(stop_a["stop_lat"]), float(stop_a["stop_lon"])
+
+ if not is_in_bounds(lat_a, lon_a):
+ # S_i is out of bounds. Segment S_i -> S_{i+1} is straight line.
+ stop_b = stops_dict[stop_seq[i+1]["stop_id"]]
+ lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"])
+
+ segment_points = [[lon_a, lat_a], [lon_b, lat_b]]
+ if not final_shape_points:
+ final_shape_points.extend(segment_points)
+ else:
+ final_shape_points.extend(segment_points[1:])
+ i += 1
+ else:
+ # S_i is in bounds. Find how many subsequent stops are also in bounds.
+ j = i + 1
+ while j < len(stop_seq):
+ stop_j = stops_dict[stop_seq[j]["stop_id"]]
+ if is_in_bounds(float(stop_j["stop_lat"]), float(stop_j["stop_lon"])):
+ j += 1
+ else:
+ break
+
+ # Stops from i to j-1 are in bounds.
+ if j > i + 1:
+ # We have at least two consecutive stops in bounds.
+ in_bounds_stops = stop_seq[i:j]
+ coordinates = []
+ for st in in_bounds_stops:
+ s = stops_dict[st["stop_id"]]
+ coordinates.append(f"{s['stop_lon']},{s['stop_lat']}")
+
+ coords_str = ";".join(coordinates)
+ osrm_url = f"{OSRM_BASE_URL}{coords_str}?overview=full&geometries=geojson"
+
+ segment_points = []
+ try:
+ response = requests.get(osrm_url, timeout=10)
+ if response.status_code == 200:
+ data = response.json()
+ if data.get("code") == "Ok":
+ segment_points = data["routes"][0]["geometry"]["coordinates"]
+ except Exception:
+ pass
+
+ if not segment_points:
+ # Fallback to straight lines for this whole sub-sequence
+ segment_points = []
+ for k in range(i, j):
+ s = stops_dict[stop_seq[k]["stop_id"]]
+ segment_points.append([float(s["stop_lon"]), float(s["stop_lat"])])
+
+ if not final_shape_points:
+ final_shape_points.extend(segment_points)
+ else:
+ final_shape_points.extend(segment_points[1:])
+
+ i = j - 1 # Next iteration starts from S_{j-1}
+ else:
+ # Only S_i is in bounds, S_{i+1} is out.
+ # Segment S_i -> S_{i+1} is straight line.
+ stop_b = stops_dict[stop_seq[i+1]["stop_id"]]
+ lat_b, lon_b = float(stop_b["stop_lat"]), float(stop_b["stop_lon"])
+
+ segment_points = [[lon_a, lat_a], [lon_b, lat_b]]
+ if not final_shape_points:
+ final_shape_points.extend(segment_points)
+ else:
+ final_shape_points.extend(segment_points[1:])
+ i += 1
+
+ shape_ids_generated.add(shape_id)
+
+ with open(
+ os.path.join(OUTPUT_GTFS_PATH, "shapes.txt"),
+ "a",
+ encoding="utf-8",
+ newline="",
+ ) as f:
+ fieldnames = [
+ "shape_id",
+ "shape_pt_lat",
+ "shape_pt_lon",
+ "shape_pt_sequence",
+ ]
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+
+ if f.tell() == 0:
+ writer.writeheader()
+
+ for seq, point in enumerate(final_shape_points):
+ writer.writerow(
+ {
+ "shape_id": shape_id,
+ "shape_pt_lat": point[1],
+ "shape_pt_lon": point[0],
+ "shape_pt_sequence": seq,
+ }
+ )
+ else:
+ logging.info("Shape generation skipped as per user request.")
+
+ # Create a ZIP archive of the output GTFS
+ with zipfile.ZipFile(OUTPUT_GTFS_ZIP, "w", zipfile.ZIP_DEFLATED) as zipf:
+ for root, _, files in os.walk(OUTPUT_GTFS_PATH):
+ for file in files:
+ file_path = os.path.join(root, file)
+ arcname = os.path.relpath(file_path, OUTPUT_GTFS_PATH)
+ zipf.write(file_path, arcname)
+
+ logging.info(
+ f"GTFS data from feed {feed} has been zipped successfully at {OUTPUT_GTFS_ZIP}."
+ )
+ os.close(INPUT_GTFS_FD)
+ os.remove(INPUT_GTFS_ZIP)
+ shutil.rmtree(INPUT_GTFS_PATH)
+ shutil.rmtree(OUTPUT_GTFS_PATH)