import os
import sys
from pathlib import Path

import duckdb  # type: ignore
from dotenv import load_dotenv  # type: ignore

# Add the parent directory to the path to allow imports from api/
sys.path.insert(0, str(Path(__file__).parent.parent))

from api.update_static import GTFSSyncManager
from api.bus_cache import AsyncBusCache  # type: ignore

# Configuration - always save the DB in the src/ directory
DB_PATH = str(Path(__file__).parent / "ttc_gtfs.duckdb")
STATIC_DIR = str(Path(__file__).parent.parent / "static")
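
# NOTE (assumption): init_db() below connects via sync_mgr.DB_PATH rather than
# this module-level DB_PATH; the two are assumed to resolve to the same file.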


def init_db():
    sync_mgr = GTFSSyncManager()
    remote_data = sync_mgr.get_remote_metadata()
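    # Assumed shape of remote_data, inferred from how it is used below:
    # {"updated_at": "<timestamp string>", "url": "<GTFS zip URL>"}, or None
    # if the remote metadata could not be fetched.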

    # Connect to the existing DB (or create a new one)
    con = duckdb.connect(sync_mgr.DB_PATH)

    # 1. Set up metadata tracking (a simple key/value store)
    con.execute("CREATE TABLE IF NOT EXISTS sync_metadata (key VARCHAR PRIMARY KEY, value VARCHAR)")
    local_update = con.execute("SELECT value FROM sync_metadata WHERE key = 'last_modified'").fetchone()

    # 2. Decide whether to sync based on the API metadata
    should_sync = bool(
        not local_update
        or (remote_data and remote_data["updated_at"] > local_update[0])
    )

    if should_sync and remote_data:
        print(f"--- Data stale. Remote: {remote_data['updated_at']} | Local: {local_update[0] if local_update else 'None'} ---")
        con.close()  # Close the connection so the sync can delete the DB file
        sync_mgr.perform_full_sync(remote_data["url"])

        # Reconnect and record the new metadata
        con = duckdb.connect(sync_mgr.DB_PATH)
        con.execute("CREATE TABLE IF NOT EXISTS sync_metadata (key VARCHAR PRIMARY KEY, value VARCHAR)")
        con.execute("INSERT OR REPLACE INTO sync_metadata VALUES ('last_modified', ?)", [remote_data["updated_at"]])

    # 3. Standard import loop (runs if the DB was rebuilt or is missing tables)
    tables = [t[0] for t in con.execute("SHOW TABLES").fetchall()]
    if all(t in tables for t in ["routes", "trips", "stops", "stop_times", "shapes"]):
        return con

    print("--- Initializing/Updating DuckDB: Importing CSVs ---")
    files = ["routes.txt", "trips.txt", "stops.txt", "stop_times.txt", "shapes.txt"]
    for f in files:
        file_path = Path(STATIC_DIR) / f
        table_name = f.replace(".txt", "")
        if file_path.exists():
            print(f"Importing {f} into table '{table_name}'...")
            abs_file_path = str(file_path.resolve())
            # 'CREATE OR REPLACE' overwrites existing tables without crashing;
            # read_csv_auto infers column types and delimiters from the file
            con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM read_csv_auto('{abs_file_path}')")
        else:
            print(f"Error: {file_path} not found!")

    # 4. Create indexes for query performance
    print("--- Creating Indexes for Performance ---")
    # Speeds up the /api/shapes/{shape_id} endpoint significantly
    con.execute("CREATE INDEX IF NOT EXISTS idx_shape_id ON shapes (shape_id)")
    # Indexing trip_id in stop_times speeds up the arrival logic
    con.execute("CREATE INDEX IF NOT EXISTS idx_stop_times_trip_id ON stop_times (trip_id)")
| print("--- Database Import Complete ---") | |
| return con | |
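
# Usage sketch (assumption): other modules are expected to call init_db() once
# at startup and reuse the returned connection, e.g.:
#
#     con = init_db()
#     route_count = con.execute("SELECT COUNT(*) FROM routes").fetchone()[0]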


async def test_data_integrity(con):
    """
    Run a test join to confirm that a trip ID can be linked to a route name
    and a stop list. Uses AsyncBusCache to get a real trip_id from the live API.
    """
    print("--- Running Integrity Test ---")
    try:
        # Get a trip_id from the live API via AsyncBusCache
        cache = AsyncBusCache(ttl=20)
        vehicles = await cache.get_data()
        if not vehicles:
            print("No vehicles available from API, falling back to database trip_id")
            sample_trip = con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]
        else:
            # Extract a trip_id from the first vehicle. The raw GTFS-Realtime
            # feed is fetched again here because the cached vehicle data does
            # not expose trip_id.
            import httpx  # type: ignore
            from google.transit import gtfs_realtime_pb2  # type: ignore

            load_dotenv()
            gtfs_rt_url = os.getenv("GTFS_RT_URL")
            if not gtfs_rt_url:
                raise ValueError("GTFS_RT_URL is not set")
            async with httpx.AsyncClient() as client:
                response = await client.get(gtfs_rt_url, timeout=10)
                response.raise_for_status()
            feed = gtfs_realtime_pb2.FeedMessage()
            feed.ParseFromString(response.content)
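            # entity.vehicle and entity.vehicle.trip.trip_id follow the standard
            # GTFS-Realtime protobuf schema (VehiclePosition -> TripDescriptor).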

            # Take the trip_id from the first vehicle entity that has one
            sample_trip = None
            for entity in feed.entity:
                if entity.HasField('vehicle') and entity.vehicle.trip.trip_id:
                    sample_trip = entity.vehicle.trip.trip_id
                    break
            if not sample_trip:
                print("No trip_id found in API response, falling back to database")
                sample_trip = con.execute("SELECT trip_id FROM trips LIMIT 1").fetchone()[0]
            else:
                print(f"Using trip_id from live API: {sample_trip}")

        # First, get the total stop count for the trip. The trip_id comes from
        # an external feed, so it is passed as a bound parameter rather than
        # interpolated into the SQL string.
        count_query = """
            SELECT COUNT(*)
            FROM trips t
            JOIN stop_times st ON t.trip_id = st.trip_id
            WHERE t.trip_id = ?
        """
        total_count = con.execute(count_query, [sample_trip]).fetchone()[0]

        # Show all stops if there are 20 or fewer, otherwise the first 20
        sample_size = min(20, total_count)

        # sample_size is a trusted integer, so interpolating it into LIMIT is safe
        query = f"""
            SELECT
                r.route_short_name,
                t.trip_headsign,
                st.stop_sequence,
                s.stop_name
            FROM trips t
            JOIN routes r ON t.route_id = r.route_id
            JOIN stop_times st ON t.trip_id = st.trip_id
            JOIN stops s ON st.stop_id = s.stop_id
            WHERE t.trip_id = ?
            ORDER BY st.stop_sequence
            LIMIT {sample_size};
        """
        results = con.execute(query, [sample_trip]).fetchall()
| print(f"\nSuccessfully joined data for Trip ID: {sample_trip}") | |
| print(f"Total stops in trip: {total_count}") | |
| if total_count > sample_size: | |
| print(f"Showing first {sample_size} stops (sample):\n") | |
| else: | |
| print(f"Showing all {total_count} stops:\n") | |
| print(f"{'Route':<8} {'Headsign':<30} {'Stop #':<8} {'Stop Name':<50}") | |
| print("-" * 100) | |
| for res in results: | |
| route = res[0] or "N/A" | |
| headsign = (res[1] or "N/A")[:28] # Truncate if too long | |
| stop_seq = res[2] | |
| stop_name = (res[3] or "N/A")[:48] # Truncate if too long | |
| print(f"{route:<8} {headsign:<30} {stop_seq:<8} {stop_name:<50}") | |
| except Exception as e: | |
| print(f"Integrity test failed: {e}") | |


if __name__ == "__main__":
    import asyncio

    db_con = init_db()
    asyncio.run(test_data_integrity(db_con))
    db_con.close()  # Release the DuckDB file handle when done