aboutsummaryrefslogtreecommitdiff
path: root/src/delay_collector/database.py
blob: b3f58f0b388d4d15346a06359aeb8a0beaf0f795 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Database module for storing delay observations."""

import os
from datetime import datetime
from typing import List, Dict

import psycopg2
from psycopg2.extras import execute_values


def get_connection():
    """
    Open a new PostgreSQL connection configured from the environment.

    Each setting is read from an environment variable and falls back to a
    default when the variable is unset:
        DB_HOST (default "localhost"), DB_PORT (default "5432"),
        DB_NAME (default "busurbano"), DB_USER (default "postgres"),
        DB_PASSWORD (default "").

    Returns:
        The connection object returned by ``psycopg2.connect``. The caller
        is responsible for closing it.
    """
    settings = {
        "host": os.getenv("DB_HOST", "localhost"),
        # The port arrives as text from the environment; connect gets an int.
        "port": int(os.getenv("DB_PORT", "5432")),
        "database": os.getenv("DB_NAME", "busurbano"),
        "user": os.getenv("DB_USER", "postgres"),
        "password": os.getenv("DB_PASSWORD", ""),
    }
    return psycopg2.connect(**settings)


def insert_observations(
    observations: List[Dict],
    stop_code: int,
    observed_at: datetime
) -> int:
    """
    Insert delay observations into the database.

    All rows are written with a single batched INSERT and committed together.

    Args:
        observations: List of observation dicts with keys:
            line, route, service_id, trip_id, running, scheduled_minutes, real_time_minutes
        stop_code: The stop code where observations were made
        observed_at: The datetime when observations were collected

    Returns:
        Number of records inserted (0 when ``observations`` is empty, in
        which case no database connection is opened).

    Raises:
        KeyError: If an observation dict is missing one of the required keys.
            This is detected before any connection is opened.

    Errors raised while connecting or executing the statement propagate to
    the caller; the connection is closed in every case.
    """
    if not observations:
        return 0

    # Build the rows first so malformed input fails fast, before a
    # connection is opened.
    records = [
        (
            observed_at,
            stop_code,
            obs["line"],
            obs["route"],
            obs["service_id"],
            obs["trip_id"],
            obs["running"],
            obs["scheduled_minutes"],
            obs["real_time_minutes"],
        )
        for obs in observations
    ]

    # The single %s placeholder is expanded by execute_values into the
    # list of row tuples.
    insert_sql = """
        INSERT INTO delay_observations (
            observed_at,
            stop_code,
            line,
            route,
            service_id,
            trip_id,
            running,
            scheduled_minutes,
            real_time_minutes
        ) VALUES %s
    """

    conn = get_connection()
    try:
        # The cursor is managed by its own context manager so that a failure
        # in conn.cursor() cannot reach cleanup code referencing an unbound
        # cursor name, which would mask the real error and skip conn.close().
        with conn.cursor() as cursor:
            execute_values(cursor, insert_sql, records)
        conn.commit()
        return len(records)
    finally:
        conn.close()


def get_statistics() -> Dict:
    """
    Get basic statistics about the collected data.

    Returns:
        A dict with the keys:
            total_observations: row count of ``delay_observations``
            first_observation: ``str()`` of the earliest ``observed_at``,
                or None when there is none
            last_observation: ``str()`` of the latest ``observed_at``,
                or None when there is none
            unique_stops: number of distinct ``stop_code`` values
            unique_lines: number of distinct ``line`` values

    Errors raised while connecting or querying propagate to the caller; the
    connection is closed in every case.
    """
    conn = get_connection()
    try:
        # The cursor is created inside the try block so the connection is
        # still closed if cursor creation itself fails.
        with conn.cursor() as cursor:
            # One statement computes every aggregate, so all figures describe
            # the same state of the table and only one round trip is needed.
            cursor.execute(
                """
                SELECT
                    COUNT(*) as total,
                    MIN(observed_at) as first_observation,
                    MAX(observed_at) as last_observation,
                    COUNT(DISTINCT stop_code) as unique_stops,
                    COUNT(DISTINCT line) as unique_lines
                FROM delay_observations
                """
            )
            # Defensive fallback in case no row comes back.
            row = cursor.fetchone() or (0, None, None, 0, 0)

        total, first_seen, last_seen, unique_stops, unique_lines = row
        return {
            "total_observations": total,
            "first_observation": str(first_seen) if first_seen is not None else None,
            "last_observation": str(last_seen) if last_seen is not None else None,
            "unique_stops": unique_stops,
            "unique_lines": unique_lines,
        }
    finally:
        conn.close()