# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "pandas",
#     "requests",
#     "tqdm",
# ]
# ///
"""Extract a Galicia-only GTFS feed from the national Renfe GTFS feeds.

Downloads each configured feed from the Spanish NAP
(https://nap.transportes.gob.es/), keeps only the stops/trips/routes that
touch the Galicia bounding box, optionally generates shapes via OSRM, and
writes one ZIP per feed (plus an optional merged ZIP with ``--merge``).
"""
from argparse import ArgumentParser
import binascii
import csv
import io
import json
import logging
import os
import shutil
import tempfile
import zipfile

import pandas as pd
import requests
from tqdm import tqdm

# Approximate bounding box for Galicia (decimal degrees, WGS84)
BOUNDS = {"SOUTH": 41.820455, "NORTH": 43.937462, "WEST": -9.437256, "EAST": -6.767578}

# Feed name -> NAP file id used in the download URL
FEEDS = {
    "general": "1098",
    "cercanias": "1130"
}


def is_in_bounds(lat: float, lon: float) -> bool:
    """Return True if (lat, lon) lies inside the Galicia bounding box."""
    return (
        BOUNDS["SOUTH"] <= lat <= BOUNDS["NORTH"]
        and BOUNDS["WEST"] <= lon <= BOUNDS["EAST"]
    )


def get_stops_in_bounds(stops_file: str):
    """Yield rows of stops.txt whose coordinates fall inside BOUNDS."""
    with open(stops_file, "r", encoding="utf-8") as f:
        stops = csv.DictReader(f)
        for stop in stops:
            lat = float(stop["stop_lat"])
            lon = float(stop["stop_lon"])
            if is_in_bounds(lat, lon):
                yield stop


def get_trip_ids_for_stops(stoptimes_file: str, stop_ids: list[str]) -> list[str]:
    """Return the ids of all trips that call at any of *stop_ids*."""
    wanted = set(stop_ids)  # O(1) membership; stop_times.txt can be huge
    trip_ids: set[str] = set()
    with open(stoptimes_file, "r", encoding="utf-8") as f:
        for stop_time in csv.DictReader(f):
            if stop_time["stop_id"] in wanted:
                trip_ids.add(stop_time["trip_id"])
    return list(trip_ids)


def get_routes_for_trips(trips_file: str, trip_ids: list[str]) -> list[str]:
    """Return the ids of all routes that contain any of *trip_ids*."""
    wanted = set(trip_ids)
    route_ids: set[str] = set()
    with open(trips_file, "r", encoding="utf-8") as f:
        for trip in csv.DictReader(f):
            if trip["trip_id"] in wanted:
                route_ids.add(trip["route_id"])
    return list(route_ids)


def get_distinct_stops_from_stop_times(
    stoptimes_file: str, trip_ids: list[str]
) -> list[str]:
    """Return every distinct stop_id referenced by any of *trip_ids*."""
    wanted = set(trip_ids)
    stop_ids: set[str] = set()
    with open(stoptimes_file, "r", encoding="utf-8") as f:
        for stop_time in csv.DictReader(f):
            if stop_time["trip_id"] in wanted:
                stop_ids.add(stop_time["stop_id"])
    return list(stop_ids)


def get_last_stop_for_trips(
    stoptimes_file: str, trip_ids: list[str]
) -> dict[str, str]:
    """Map each trip_id to the stop_id with the highest stop_sequence.

    :param stoptimes_file: Path to a GTFS stop_times.txt
    :param trip_ids: Trips to consider; other trips are ignored
    :raises ValueError: If the file has no header row
    """
    wanted = set(trip_ids)
    trip_last: dict[str, str] = {}
    trip_last_seq: dict[str, int] = {}
    with open(stoptimes_file, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is None:
            raise ValueError(f"{stoptimes_file} has no header row")
        # Some feeds ship headers with stray whitespace around column names
        reader.fieldnames = [name.strip() for name in reader.fieldnames]
        for stop_time in reader:
            if stop_time["trip_id"] in wanted:
                trip_id = stop_time["trip_id"]
                if trip_last.get(trip_id, None) is None:
                    trip_last[trip_id] = ""
                    trip_last_seq[trip_id] = -1
                this_stop_seq = int(stop_time["stop_sequence"])
                if this_stop_seq > trip_last_seq[trip_id]:
                    trip_last_seq[trip_id] = this_stop_seq
                    trip_last[trip_id] = stop_time["stop_id"]
    return trip_last


def get_rows_by_ids(input_file: str, id_field: str, ids: list[str]) -> list[dict]:
    """Return all rows of *input_file* whose *id_field* (stripped) is in *ids*.

    :raises ValueError: If the file has no header row
    """
    wanted = set(ids)
    rows: list[dict] = []
    with open(input_file, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is None:
            raise ValueError(f"{input_file} has no header row")
        reader.fieldnames = [name.strip() for name in reader.fieldnames]
        for row in reader:
            if row[id_field].strip() in wanted:
                rows.append(row)
    return rows


# First colour is background, second is text
SERVICE_COLOURS = {
    "REGIONAL": ("9A0060", "FFFFFF"),
    "REG.EXP.": ("9A0060", "FFFFFF"),
    "MD": ("F85B0B", "000000"),
    "AVANT": ("F85B0B", "000000"),
    "AVLO": ("05CEC6", "000000"),
    "AVE": ("FFFFFF", "9A0060"),
    "ALVIA": ("FFFFFF", "9A0060"),
    "INTERCITY": ("606060", "FFFFFF"),
    "TRENCELTA": ("00824A", "FFFFFF"),
    # Cercanías Ferrol-Ortigueira
    "C1": ("F5333F", "FFFFFF")
}


def colour_route(route_short_name: str) -> tuple[str, str]:
    """
    Returns the colours to be used for a route from its short name.

    :param route_short_name: The routes.txt's route_short_name
    :return: A tuple containing the "route_color" (background) first and
        "route_text_color" (text) second
    :rtype: tuple[str, str]
    """
    route_name_searched = route_short_name.strip().upper()
    if route_name_searched in SERVICE_COLOURS:
        return SERVICE_COLOURS[route_name_searched]
    print("Unknown route short name:", route_short_name)
    # Black on white as a safe fallback for unmapped services
    return ("000000", "FFFFFF")


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Extract GTFS data for Galicia from Renfe GTFS feed."
    )
    parser.add_argument(
        "nap_apikey", type=str, help="NAP API Key (https://nap.transportes.gob.es/)"
    )
    parser.add_argument(
        "--osrm-std",
        type=str,
        help="OSRM standard server URL",
        default="http://localhost:5050",
        required=False,
    )
    parser.add_argument(
        "--osrm-narrow",
        type=str,
        help="OSRM narrow gauge server URL",
        default="http://localhost:5051",
        required=False,
    )
    parser.add_argument(
        "--debug", help="Enable debug logging", action="store_true"
    )
    parser.add_argument(
        "--merge",
        help="Merge the generated feeds into a single GTFS ZIP file instead of separate ones for each feed",
        action="store_true",
    )
    args = parser.parse_args()

    # Configure logging BEFORE the first logging call below: the implicit
    # root-logger setup done by logging.warning() would otherwise turn this
    # basicConfig() into a silent no-op (original ordering bug).
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    # Probe both OSRM servers; shapes are skipped if either is down.
    try:
        osrm_check_std = requests.head(args.osrm_std, timeout=5)
        osrm_check_narrow = requests.head(args.osrm_narrow, timeout=5)
        GENERATE_SHAPES = (
            osrm_check_std.status_code < 500 and osrm_check_narrow.status_code < 500
        )
    except requests.RequestException:
        GENERATE_SHAPES = False
        logging.warning("OSRM server is not reachable. Shape generation will be skipped.")

    for feed in FEEDS.keys():

        def get_shape_id(trip_id: str, feed: str = feed) -> str:
            # CRC32 keeps shape ids short and deterministic. `feed` is bound
            # as a default argument so the closure is safe against the loop
            # variable's late binding.
            trip_crc = binascii.crc32(trip_id.encode("utf-8"))
            return f"Shape_{feed}_{trip_crc}_{trip_crc}"

        INPUT_GTFS_FD, INPUT_GTFS_ZIP = tempfile.mkstemp(
            suffix=".zip", prefix=f"renfe_galicia_in_{feed}_"
        )
        INPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_in_{feed}_")
        OUTPUT_GTFS_PATH = tempfile.mkdtemp(prefix=f"renfe_galicia_out_{feed}_")
        OUTPUT_GTFS_ZIP = os.path.join(
            os.path.dirname(__file__), f"gtfs_renfe_galicia_{feed}.zip"
        )
        FEED_URL = f"https://nap.transportes.gob.es/api/Fichero/download/{FEEDS[feed]}"

        logging.info(f"Downloading GTFS feed '{feed}'...")
        response = requests.get(
            FEED_URL, headers={"ApiKey": args.nap_apikey}, timeout=300
        )
        # Fail early with a clear error on a bad API key / server error,
        # instead of writing an HTML error page into the .zip.
        response.raise_for_status()
        with open(INPUT_GTFS_ZIP, "wb") as f:
            f.write(response.content)

        # Unzip the GTFS feed
        with zipfile.ZipFile(INPUT_GTFS_ZIP, "r") as zip_ref:
            zip_ref.extractall(INPUT_GTFS_PATH)

        STOPS_FILE = os.path.join(INPUT_GTFS_PATH, "stops.txt")
        STOP_TIMES_FILE = os.path.join(INPUT_GTFS_PATH, "stop_times.txt")
        TRIPS_FILE = os.path.join(INPUT_GTFS_PATH, "trips.txt")

        all_stops_applicable = list(get_stops_in_bounds(STOPS_FILE))
        logging.info(f"Total stops in Galicia: {len(all_stops_applicable)}")

        stop_ids = [stop["stop_id"] for stop in all_stops_applicable]
        trip_ids = get_trip_ids_for_stops(STOP_TIMES_FILE, stop_ids)
        route_ids = get_routes_for_trips(TRIPS_FILE, trip_ids)
        logging.info(
            f"Feed parsed successfully. Stops: {len(stop_ids)}, trips: {len(trip_ids)}, routes: {len(route_ids)}"
        )

        if len(trip_ids) == 0 or len(route_ids) == 0:
            logging.warning(f"No trips or routes found for feed '{feed}'. Skipping...")
            # Release the temp zip fd/file too (was leaked in this branch).
            os.close(INPUT_GTFS_FD)
            os.remove(INPUT_GTFS_ZIP)
            shutil.rmtree(INPUT_GTFS_PATH)
            shutil.rmtree(OUTPUT_GTFS_PATH)
            continue

        # Copy agency.txt, calendar.txt, calendar_dates.txt as is
        for filename in ["agency.txt", "calendar.txt", "calendar_dates.txt"]:
            src_path = os.path.join(INPUT_GTFS_PATH, filename)
            dest_path = os.path.join(OUTPUT_GTFS_PATH, filename)
            if os.path.exists(src_path):
                shutil.copy(src_path, dest_path)
            else:
                # Name the missing file (original logged a literal "(unknown)")
                logging.debug(f"File {filename} does not exist in the input GTFS feed.")

        # Write new stops.txt with the stops in any trip that passes through Galicia
        with open(
            os.path.join(os.path.dirname(__file__), "stop_overrides.json"),
            "r",
            encoding="utf-8",
        ) as f:
            stop_overrides_raw: list = json.load(f)
        stop_overrides = {item["stop_id"]: item for item in stop_overrides_raw}
        logging.debug(f"Loaded stop overrides for {len(stop_overrides)} stops.")

        deleted_stop_ids: set[str] = set()
        for stop_id, override_item in stop_overrides.items():
            if override_item.get("_delete", False):
                # A missing feed_id means "delete in every feed"
                if override_item.get("feed_id", None) is None or override_item["feed_id"] == feed:
                    deleted_stop_ids.add(stop_id)
        logging.debug(f"Stops marked for deletion in feed '{feed}': {len(deleted_stop_ids)}")

        distinct_stop_ids = get_distinct_stops_from_stop_times(
            STOP_TIMES_FILE, trip_ids
        )
        stops_in_trips = get_rows_by_ids(STOPS_FILE, "stop_id", distinct_stop_ids)
        for stop in stops_in_trips:
            # Preserve the original id as the rider-facing stop_code
            stop["stop_code"] = stop["stop_id"]
            if stop_overrides.get(stop["stop_id"], None) is not None:
                override_item = stop_overrides[stop["stop_id"]]
                if (
                    override_item.get("feed_id", None) is not None
                    and override_item["feed_id"] != feed
                ):
                    # Override targets another feed: leave the stop untouched
                    # (this also skips the name clean-up below, as before).
                    continue
                for key, value in override_item.items():
                    if key in ("stop_id", "feed_id", "_delete"):
                        continue
                    stop[key] = value
            # Normalise the stop name: drop the "Estación de tren " prefix,
            # then title-case it and drop the Spanish particle "de".
            if stop["stop_name"].startswith("Estación de tren "):
                stop["stop_name"] = stop["stop_name"].removeprefix("Estación de tren ").strip()
            stop["stop_name"] = " ".join(
                word.capitalize() for word in stop["stop_name"].split(" ") if word != "de"
            )
        stops_in_trips = [
            stop for stop in stops_in_trips if stop["stop_id"] not in deleted_stop_ids
        ]
        with open(
            os.path.join(OUTPUT_GTFS_PATH, "stops.txt"),
            "w",
            encoding="utf-8",
            newline="",
        ) as f:
            writer = csv.DictWriter(f, fieldnames=stops_in_trips[0].keys())
            writer.writeheader()
            writer.writerows(stops_in_trips)

        # Write new routes.txt with the routes that have trips in Galicia
        routes_in_trips = get_rows_by_ids(
            os.path.join(INPUT_GTFS_PATH, "routes.txt"), "route_id", route_ids
        )
        if feed == "cercanias":
            # The C1 Ferrol-Ortigueira line ships as one route per direction;
            # merge both into a single synthetic route.
            cercanias_c1_route_ids = ["46T0001C1", "46T0002C1"]
            new_route_id = "FERROL_C1"
            # Find agency_id and a template route
            template_route = routes_in_trips[0] if routes_in_trips else {}
            agency_id = "1"
            for r in routes_in_trips:
                if r["route_id"].strip() in cercanias_c1_route_ids:
                    agency_id = r.get("agency_id", "1")
                    template_route = r
                    break
            # Filter out old routes
            routes_in_trips = [
                r for r in routes_in_trips
                if r["route_id"].strip() not in cercanias_c1_route_ids
            ]
            # Add new route
            new_route = template_route.copy()
            new_route.update({
                "route_id": new_route_id,
                "route_short_name": "C1",
                "route_long_name": "Ferrol - Xuvia - San Sadurniño - Ortigueira",
                "route_type": "2",
            })
            if "agency_id" in template_route:
                new_route["agency_id"] = agency_id
            routes_in_trips.append(new_route)

        for route in routes_in_trips:
            route["route_color"], route["route_text_color"] = colour_route(
                route["route_short_name"]
            )
        with open(
            os.path.join(OUTPUT_GTFS_PATH, "routes.txt"),
            "w",
            encoding="utf-8",
            newline="",
        ) as f:
            writer = csv.DictWriter(f, fieldnames=routes_in_trips[0].keys())
            writer.writeheader()
            writer.writerows(routes_in_trips)

        # Write new trips.txt with the trips that pass through Galicia
        # Load stop_times early so we can filter deleted stops and renumber sequences
        stop_times_in_galicia = get_rows_by_ids(STOP_TIMES_FILE, "trip_id", trip_ids)
        stop_times_in_galicia = [
            st for st in stop_times_in_galicia
            if st["stop_id"].strip() not in deleted_stop_ids
        ]
        stop_times_in_galicia.sort(
            key=lambda x: (x["trip_id"], int(x["stop_sequence"].strip()))
        )
        # Renumber stop_sequence per trip from 0, gap-free
        trip_seq_counter: dict[str, int] = {}
        for st in stop_times_in_galicia:
            tid = st["trip_id"]
            if tid not in trip_seq_counter:
                trip_seq_counter[tid] = 0
            st["stop_sequence"] = str(trip_seq_counter[tid])
            trip_seq_counter[tid] += 1

        # Track the last stop of each trip; it becomes the trip_headsign
        last_stop_in_trips: dict[str, str] = {}
        trip_last_seq: dict[str, int] = {}
        for st in stop_times_in_galicia:
            tid = st["trip_id"]
            seq = int(st["stop_sequence"])
            if seq > trip_last_seq.get(tid, -1):
                trip_last_seq[tid] = seq
                last_stop_in_trips[tid] = st["stop_id"].strip()

        trips_in_galicia = get_rows_by_ids(TRIPS_FILE, "trip_id", trip_ids)
        if feed == "cercanias":
            cercanias_c1_route_ids = ["46T0001C1", "46T0002C1"]
            new_route_id = "FERROL_C1"
            for tig in trips_in_galicia:
                if tig["route_id"].strip() in cercanias_c1_route_ids:
                    # The 7th character of the old route id encodes direction
                    tig["direction_id"] = "1" if tig["route_id"].strip()[6] == "2" else "0"
                    tig["route_id"] = new_route_id

        # Key on stripped ids: last_stop_in_trips stores stripped stop_ids,
        # so un-stripped keys here could raise KeyError on padded feeds.
        stops_by_id = {stop["stop_id"].strip(): stop for stop in stops_in_trips}
        for tig in trips_in_galicia:
            if GENERATE_SHAPES:
                tig["shape_id"] = get_shape_id(tig["trip_id"])
            tig["trip_headsign"] = stops_by_id[last_stop_in_trips[tig["trip_id"]]]["stop_name"]
        with open(
            os.path.join(OUTPUT_GTFS_PATH, "trips.txt"),
            "w",
            encoding="utf-8",
            newline="",
        ) as f:
            writer = csv.DictWriter(f, fieldnames=trips_in_galicia[0].keys())
            writer.writeheader()
            writer.writerows(trips_in_galicia)

        # Write new stop_times.txt with the stop times for any trip that passes through Galicia
        with open(
            os.path.join(OUTPUT_GTFS_PATH, "stop_times.txt"),
            "w",
            encoding="utf-8",
            newline="",
        ) as f:
            writer = csv.DictWriter(f, fieldnames=stop_times_in_galicia[0].keys())
            writer.writeheader()
            writer.writerows(stop_times_in_galicia)

        logging.info("GTFS data for Galicia has been extracted successfully. Generate shapes for the trips...")
        if GENERATE_SHAPES:
            shape_ids_total = len(set(trip["shape_id"] for trip in trips_in_galicia))
            shape_ids_generated: set[str] = set()

            # Pre-load stops for quick lookup
            stops_dict = {stop["stop_id"]: stop for stop in stops_in_trips}

            # Group stop times by trip_id to avoid repeated file reads
            stop_times_by_trip: dict[str, list[dict]] = {}
            for st in stop_times_in_galicia:
                stop_times_by_trip.setdefault(st["trip_id"], []).append(st)

            # Open the debug log and shapes.txt once for the whole loop
            # (the originals leaked the debug handle and re-opened shapes.txt
            # in append mode once per shape).
            with open("shapes_debug.txt", "w", encoding="utf-8") as shapes_file, open(
                os.path.join(OUTPUT_GTFS_PATH, "shapes.txt"),
                "w",
                encoding="utf-8",
                newline="",
            ) as shapes_out:
                shape_writer = csv.DictWriter(
                    shapes_out,
                    fieldnames=[
                        "shape_id",
                        "shape_pt_lat",
                        "shape_pt_lon",
                        "shape_pt_sequence",
                    ],
                )
                shape_writer.writeheader()
                for trip_id in tqdm(trip_ids, total=shape_ids_total, desc="Generating shapes"):
                    shape_id = get_shape_id(trip_id)
                    if shape_id in shape_ids_generated:
                        continue
                    # If we are on feed cercanias or the 5-digit trip ID starts with 7, it's a narrow gauge train, otherwise standard gauge
                    if feed == "cercanias" or (len(trip_id) >= 5 and trip_id[0] == "7"):
                        OSRM_BASE_URL = f"{args.osrm_narrow}/route/v1/driving/"
                    else:
                        OSRM_BASE_URL = f"{args.osrm_std}/route/v1/driving/"
                    stop_seq = stop_times_by_trip.get(trip_id, [])
                    stop_seq.sort(key=lambda x: int(x["stop_sequence"].strip()))
                    if not stop_seq:
                        continue
                    final_shape_points = []
                    coordinates = []
                    for st in stop_seq:
                        s = stops_dict[st["stop_id"]]
                        coordinates.append(f"{s['stop_lon']},{s['stop_lat']}")
                    coords_str = ";".join(coordinates)
                    osrm_url = f"{OSRM_BASE_URL}{coords_str}?overview=full&geometries=geojson&continue_straight=false"
                    shapes_file.write(f"{trip_id} ({shape_id}): {osrm_url}\n")
                    try:
                        response = requests.get(osrm_url, timeout=10)
                        if response.status_code == 200:
                            data = response.json()
                            if data.get("code") == "Ok":
                                final_shape_points = data["routes"][0]["geometry"]["coordinates"]
                                logging.debug(f"OSRM success for {shape_id} ({len(coordinates)} stops): {len(final_shape_points)} points")
                            else:
                                logging.warning(f"OSRM returned error code {data.get('code')} for {shape_id} with {len(coordinates)} stops")
                        else:
                            logging.warning(f"OSRM request failed for {shape_id} with status {response.status_code}")
                    except Exception as e:
                        logging.error(f"OSRM exception for {shape_id}: {str(e)}")
                    if not final_shape_points:
                        # Fallback to straight lines
                        logging.info(f"Using straight-line fallback for {shape_id}")
                        for st in stop_seq:
                            s = stops_dict[st["stop_id"]]
                            final_shape_points.append([float(s["stop_lon"]), float(s["stop_lat"])])
                    shape_ids_generated.add(shape_id)
                    for seq, point in enumerate(final_shape_points):
                        # GeoJSON coordinates are [lon, lat]; GTFS wants lat first
                        shape_writer.writerow(
                            {
                                "shape_id": shape_id,
                                "shape_pt_lat": point[1],
                                "shape_pt_lon": point[0],
                                "shape_pt_sequence": seq,
                            }
                        )
        else:
            logging.info("Shape generation skipped as per user request.")

        # Create a ZIP archive of the output GTFS
        with zipfile.ZipFile(OUTPUT_GTFS_ZIP, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _, files in os.walk(OUTPUT_GTFS_PATH):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, OUTPUT_GTFS_PATH)
                    zipf.write(file_path, arcname)
        logging.info(
            f"GTFS data from feed {feed} has been zipped successfully at {OUTPUT_GTFS_ZIP}."
        )

        os.close(INPUT_GTFS_FD)
        os.remove(INPUT_GTFS_ZIP)
        shutil.rmtree(INPUT_GTFS_PATH)
        shutil.rmtree(OUTPUT_GTFS_PATH)

    if args.merge:
        # Columns to keep for each GTFS file when merging.
        # Files not listed here keep all columns present in the data.
        MERGE_KEEP_COLS: dict[str, list[str]] = {
            "agency.txt": ["agency_id", "agency_name", "agency_url", "agency_timezone", "agency_lang"],
            "stops.txt": ["stop_id", "stop_code", "stop_name", "stop_lat", "stop_lon", "wheelchair_boarding"],
            "routes.txt": ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type", "route_color", "route_text_color"],
            "trips.txt": ["route_id", "service_id", "trip_id", "trip_headsign", "direction_id", "shape_id", "wheelchair_accessible"],
            "stop_times.txt": ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence", "pickup_type", "drop_off_type"],
        }
        # Default values to fill for columns that are missing or NaN after concat.
        MERGE_FILL_DEFAULTS: dict[str, dict[str, str]] = {
            "routes.txt": {"agency_id": "1071VC"},
            "trips.txt": {"direction_id": "0", "shape_id": "", "wheelchair_accessible": ""},
            "stop_times.txt": {"pickup_type": "0", "drop_off_type": "0"},
        }
        # Deduplicate rows by this column, keeping the first occurrence.
        MERGE_DEDUP_KEY: dict[str, str] = {
            "stops.txt": "stop_id",
        }

        merged_zip_path = os.path.join(os.path.dirname(__file__), "gtfs_renfe_galicia_merged.zip")
        feed_zip_paths = [
            os.path.join(os.path.dirname(__file__), f"gtfs_renfe_galicia_{feed_name}.zip")
            for feed_name in FEEDS.keys()
        ]

        frames: dict[str, list[pd.DataFrame]] = {}
        for feed_zip_path in feed_zip_paths:
            if not os.path.exists(feed_zip_path):
                # A feed with no Galician trips is skipped above and never
                # produces a zip; don't crash the merge over it.
                logging.warning(f"Feed ZIP {feed_zip_path} not found; skipping in merge.")
                continue
            with zipfile.ZipFile(feed_zip_path, "r") as feed_zip:
                for filename in feed_zip.namelist():
                    with feed_zip.open(filename) as f:
                        df = pd.read_csv(f, dtype=str, encoding="utf-8")
                        df.columns = df.columns.str.strip()
                        df = df.apply(lambda col: col.str.strip() if col.dtype == object else col)
                        frames.setdefault(filename, []).append(df)

        with zipfile.ZipFile(merged_zip_path, "w", zipfile.ZIP_DEFLATED) as merged_zip:
            for filename, dfs in frames.items():
                merged = pd.concat(dfs, ignore_index=True)
                keep = MERGE_KEEP_COLS.get(filename)
                defaults = MERGE_FILL_DEFAULTS.get(filename, {})
                if keep is not None:
                    # Create any missing kept column with its default, fill
                    # NaNs introduced by the concat, then project to `keep`.
                    for col in keep:
                        if col not in merged.columns:
                            merged[col] = defaults.get(col, "")
                    for col, val in defaults.items():
                        if col in merged.columns:
                            merged[col] = merged[col].fillna(val)
                    merged = merged[keep]
                dedup_key = MERGE_DEDUP_KEY.get(filename)
                if dedup_key:
                    merged = merged.drop_duplicates(subset=[dedup_key], keep="first")
                buf = io.StringIO()
                merged.to_csv(buf, index=False)
                merged_zip.writestr(filename, buf.getvalue())
        logging.info(f"Feeds merged successfully into {merged_zip_path}.")