from datetime import datetime, timezone  # type: ignore
import sys  # type: ignore
from pathlib import Path

# Add parent directory to path to allow imports from api/
sys.path.insert(0, str(Path(__file__).parent.parent))
# Add src directory to path to allow imports from same directory
sys.path.insert(0, str(Path(__file__).parent))

from api.bus_cache import AsyncBusCache  # type: ignore
from api.utils import hms_to_seconds, get_service_day_start_ts, translate_occupancy  # type: ignore
from db_manager import init_db  # type: ignore
from dotenv import load_dotenv  # type: ignore
from fastapi import FastAPI, HTTPException  # type: ignore
from fastapi.middleware.cors import CORSMiddleware  # type: ignore

load_dotenv()

# Shared cache of the real-time TTC feed (20-second TTL)
ttc_cache = AsyncBusCache(ttl=20)

# Initialize database connection globally
db = init_db()

app = FastAPI(title="WheresMyBus v2.0 API")

# Set up CORS for the React frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://wheresmybus.vercel.app"],
    allow_methods=["*"],
    allow_headers=["*"],
)
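
# NOTE: the route decorators are not present in the extracted source; the paths used
# here and below are assumptions added so the handlers are registered with FastAPI.
@app.get("/health")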
async def health_check():
    """Simple health check endpoint"""
    return "backend is running"
async def get_vehicles():
    """Return every vehicle currently reported by the real-time feed."""
    data = await ttc_cache.get_data()
    vehicles = data.get("vehicles", [])
    return {
        "status": "success",
        "count": len(vehicles),
        "vehicles": vehicles
    }
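
# Assumed path; original decorator not shown in the source.
@app.get("/routes")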
async def get_all_routes():
    """
    Returns a complete list of TTC routes with their display names and colors.
    """
    try:
        # Run the query against DuckDB.
        # Missing colors fall back to defaults: red background (FF0000), white text (FFFFFF).
        query = """
            SELECT
                route_id,
                route_short_name,
                route_long_name,
                COALESCE(route_color, 'FF0000') AS route_color,
                COALESCE(route_text_color, 'FFFFFF') AS route_text_color
            FROM routes
            ORDER BY
                CASE
                    WHEN CAST(route_short_name AS VARCHAR) ~ '^[0-9]+$' THEN CAST(route_short_name AS INTEGER)
                    ELSE 999
                END,
                route_short_name;
        """
        results = db.execute(query).fetchall()
        # Convert to a clean list of dictionaries
        route_list = [
            {
                "id": r[0],
                "number": r[1],
                "name": r[2],
                "color": f"#{r[3]}",
                "text_color": f"#{r[4]}"
            }
            for r in results
        ]
        return {
            "status": "success",
            "count": len(route_list),
            "routes": route_list
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def get_route_view(route_id: str):
    """Return live vehicles and all stops for a single route."""
    data = await ttc_cache.get_data()
    all_buses = data.get("vehicles", [])
    route_buses = [v for v in all_buses if v['route'] == route_id]
    # Get all stops for this route
    stops_query = """
        SELECT DISTINCT
            s.stop_id,
            s.stop_code,
            s.stop_name,
            s.stop_lat,
            s.stop_lon
        FROM routes r
        JOIN trips t ON r.route_id = t.route_id
        JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
        JOIN stops s ON CAST(st.stop_id AS VARCHAR) = CAST(s.stop_id AS VARCHAR)
        WHERE CAST(r.route_id AS VARCHAR) = ?
        ORDER BY s.stop_name
    """
    stops_results = db.execute(stops_query, [str(route_id)]).fetchall()
    stops = [
        {
            "stop_id": str(r[0]),
            "stop_code": str(r[1]) if r[1] else None,
            "stop_name": r[2],
            "location": {"lat": r[3], "lon": r[4]}
        }
        for r in stops_results
    ]
    if not route_buses:
        return {"route": route_id, "vehicles": [], "stops": stops}
    trip_ids = [str(v['trip_id']) for v in route_buses]
    placeholders = ','.join(['?'] * len(trip_ids))
    # Schedule lookup for the active trips, including shape_id for map polylines
    query = f"""
        SELECT
            CAST(st.trip_id AS VARCHAR),
            CAST(st.stop_id AS VARCHAR),
            st.arrival_time AS scheduled_time,
            t.trip_headsign,
            t.shape_id
        FROM stop_times st
        JOIN trips t ON CAST(st.trip_id AS VARCHAR) = CAST(t.trip_id AS VARCHAR)
        WHERE CAST(st.trip_id AS VARCHAR) IN ({placeholders})
    """
    db_rows = db.execute(query, trip_ids).fetchall()
    # Map (trip_id, stop_id) to scheduled time
    schedule_map = {(r[0], r[1]): r[2] for r in db_rows}
    # Map trip_id to an object containing headsign and shape_id
    name_map = {
        r[0]: {
            "headsign": r[3],
            "shape_id": r[4]
        } for r in db_rows
    }
    service_day_ts = get_service_day_start_ts()
    enriched = []
    for bus in route_buses:
        raw_delay_mins = 0
        pred_time = bus.get('predicted_time')
        stop_id = bus.get('next_stop_id')
        if pred_time and stop_id:
            sched_hms = schedule_map.get((str(bus['trip_id']), str(stop_id)))
            if sched_hms:
                # GTFS times can exceed 24:00 for late-night trips; roll them into the next day
                h, m, s = map(int, sched_hms.split(':'))
                extra_days = h // 24
                plan_ts = service_day_ts + (extra_days * 86400) + hms_to_seconds(sched_hms)
                raw_delay_mins = round((plan_ts - pred_time) / 60)
        # Pull details from the name_map built above
        trip_info = name_map.get(str(bus['trip_id']), {})
        enriched.append({
            "number": bus['id'],
            "name": trip_info.get("headsign", "Not in Schedule"),
            "location": {"lat": bus['lat'], "lon": bus['lon']},
            "delay_mins": raw_delay_mins,
            "fullness": translate_occupancy(bus['occupancy']),
            "shape_id": trip_info.get("shape_id")  # Lets the frontend fetch the route polyline
        })
    return {
        "route": route_id,
        "count": len(enriched),
        "vehicles": enriched,
        "stops": stops
    }
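
# Assumed path; original decorator not shown in the source.
@app.get("/vehicles/{vehicle_id}")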
async def get_vehicle_view(vehicle_id: str):
    # 1. Pull latest from cache
    data = await ttc_cache.get_data()
    vehicles = data.get("vehicles", [])
    # 2. Find this specific bus in the list
    bus = next((v for v in vehicles if str(v['id']) == vehicle_id), None)
    if not bus:
        raise HTTPException(status_code=404, detail="Vehicle not active or not found")
    trip_id = str(bus['trip_id'])
    next_stop_id = bus.get('next_stop_id')
    predicted_time = bus.get('predicted_time')
    # 3. Handshake with the database; these defaults apply when the trip isn't in the schedule
    destination = "Not in Schedule"
    shape_id = None
    route_color = "FF0000"
    delay_mins = 0
    if next_stop_id:
        # Pull headsign, scheduled arrival, shape_id, and route_color in one query
        query = """
            SELECT
                t.trip_headsign,
                st.arrival_time AS scheduled_time,
                t.shape_id,
                r.route_color
            FROM trips t
            JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
            JOIN routes r ON t.route_id = r.route_id
            WHERE CAST(t.trip_id AS VARCHAR) = ?
              AND CAST(st.stop_id AS VARCHAR) = ?
            LIMIT 1
        """
        row = db.execute(query, [trip_id, str(next_stop_id)]).fetchone()
        if row:
            destination, scheduled_hms, shape_id, r_color = row
            route_color = r_color if r_color else "FF0000"
            if predicted_time:
                service_day_ts = get_service_day_start_ts()
                # GTFS times can exceed 24:00 for late-night trips; roll them into the next day
                h, m, s = map(int, scheduled_hms.split(':'))
                extra_days = h // 24
                plan_ts = service_day_ts + (extra_days * 86400) + hms_to_seconds(scheduled_hms)
                delay_mins = round((plan_ts - predicted_time) / 60)
    else:
        # Fallback query when the feed gives no next_stop_id
        query = """
            SELECT t.trip_headsign, t.shape_id, r.route_color
            FROM trips t
            JOIN routes r ON t.route_id = r.route_id
            WHERE CAST(t.trip_id AS VARCHAR) = ?
            LIMIT 1
        """
        row = db.execute(query, [trip_id]).fetchone()
        if row:
            destination, shape_id, r_color = row
            route_color = r_color if r_color else "FF0000"
    return {
        "vehicle_number": vehicle_id,
        "route_id": bus['route'],
        "route_color": route_color,  # Frontend uses this for styling
        "name": destination,
        "location": {
            "lat": bus['lat'],
            "lon": bus['lon']
        },
        "delay_mins": delay_mins,
        "fullness": translate_occupancy(bus['occupancy']),
        "trip_id": trip_id,
        "shape_id": shape_id  # Frontend uses this to fetch the route polyline
    }
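
# Assumed path; original decorator not shown in the source.
@app.get("/stops/{stop_code}")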
async def get_stop_view(stop_code: str):
    # 1. Translate the pole number (stop_code) to the database stop_id
    stop_info = db.execute(
        "SELECT stop_id, stop_name, stop_lat, stop_lon FROM stops WHERE CAST(stop_code AS VARCHAR) = ? LIMIT 1",
        [str(stop_code)]
    ).fetchone()
    if not stop_info:
        return {"error": "Stop code not found"}
    target_id = str(stop_info[0])
    stop_name = stop_info[1]
    stop_lat = stop_info[2]
    stop_lon = stop_info[3]
    # 2. Get the cache structure (dict with vehicles, predictions, alerts)
    cached_data = await ttc_cache.get_data()
    vehicles_list = cached_data.get("vehicles", [])
    predictions = cached_data.get("predictions", {})
    # Build a trip_id -> vehicle map for quick lookup
    vehicles = {str(v['trip_id']): v for v in vehicles_list}
    now = datetime.now(timezone.utc).timestamp()
    service_day_ts = get_service_day_start_ts()
    two_hours_out = now + 7200
    # 3. Determine today's service_id based on day of week
    # Service ID 1 = Weekdays (Mon-Fri), 2 = Saturday, 3 = Sunday/Holidays
    try:
        from zoneinfo import ZoneInfo
        eastern_tz = ZoneInfo("America/Toronto")
    except ImportError:
        import pytz
        eastern_tz = pytz.timezone("America/Toronto")
    now_eastern = datetime.now(timezone.utc).astimezone(eastern_tz)
    weekday = now_eastern.weekday()  # 0=Monday, 6=Sunday
    if weekday < 5:  # Monday-Friday
        today_service_id = 1
    elif weekday == 5:  # Saturday
        today_service_id = 2
    else:  # Sunday
        today_service_id = 3
    # 4. Build a map of trip_id -> arrival info for merging schedule and predictions
    arrival_map = {}
    # 5. Get scheduled arrivals for this stop, filtered by today's service_id
    schedule_query = """
        SELECT
            CAST(t.trip_id AS VARCHAR) AS trip_id,
            t.trip_headsign,
            COALESCE(st.departure_time, st.arrival_time) AS scheduled_time,
            r.route_short_name
        FROM stop_times st
        JOIN trips t ON CAST(st.trip_id AS VARCHAR) = CAST(t.trip_id AS VARCHAR)
        JOIN routes r ON t.route_id = r.route_id
        WHERE CAST(st.stop_id AS VARCHAR) = ?
          AND CAST(t.service_id AS INTEGER) = ?
    """
    schedule_rows = db.execute(schedule_query, [target_id, today_service_id]).fetchall()
    # Process scheduled arrivals
    for row in schedule_rows:
        trip_id = row[0]
        destination = row[1]
        scheduled_hms = row[2]
        route = row[3]
        # Calculate the scheduled timestamp (GTFS times can exceed 24:00 for late-night trips)
        h, m, s = map(int, scheduled_hms.split(':'))
        extra_days = h // 24
        scheduled_ts = service_day_ts + (extra_days * 86400) + hms_to_seconds(scheduled_hms)
        # Only include arrivals within the next two hours that haven't already passed
        if now <= scheduled_ts <= two_hours_out:
            arrival_map[trip_id] = {
                "trip_id": trip_id,
                "route": route,
                "destination": destination,
                "scheduled_ts": scheduled_ts,
                "pred_time": None,
                "has_prediction": False
            }
    # 6. Add/update with real-time predictions (even if not in the static schedule)
    for trip_id, itinerary in predictions.items():
        if target_id in itinerary:
            pred_time = itinerary[target_id]
            # Only include predictions within 2 hours
            if now <= pred_time <= two_hours_out:
                if trip_id in arrival_map:
                    # Already known from the schedule; attach the prediction
                    arrival_map[trip_id]["pred_time"] = pred_time
                    arrival_map[trip_id]["has_prediction"] = True
                else:
                    # Real-time-only prediction (not in the static schedule):
                    # try to get route/destination from the database
                    query = """
                        SELECT t.trip_headsign, r.route_short_name
                        FROM trips t
                        JOIN routes r ON t.route_id = r.route_id
                        WHERE CAST(t.trip_id AS VARCHAR) = ?
                        LIMIT 1
                    """
                    row = db.execute(query, [trip_id]).fetchone()
                    if row:
                        destination = row[0]
                        route = row[1]
                    else:
                        destination = "Unknown"
                        route = "Unknown"
                    # Try to get a scheduled time if one exists
                    scheduled_query = """
                        SELECT COALESCE(st.departure_time, st.arrival_time) AS scheduled_time
                        FROM stop_times st
                        WHERE CAST(st.trip_id AS VARCHAR) = ? AND CAST(st.stop_id AS VARCHAR) = ?
                        LIMIT 1
                    """
                    sched_row = db.execute(scheduled_query, [trip_id, target_id]).fetchone()
                    if sched_row:
                        scheduled_hms = sched_row[0]
                        h, m, s = map(int, scheduled_hms.split(':'))
                        extra_days = h // 24
                        scheduled_ts = service_day_ts + (extra_days * 86400) + hms_to_seconds(scheduled_hms)
                    else:
                        scheduled_ts = pred_time  # Use the prediction as a fallback
                    arrival_map[trip_id] = {
                        "trip_id": trip_id,
                        "route": route,
                        "destination": destination,
                        "scheduled_ts": scheduled_ts,
                        "pred_time": pred_time,
                        "has_prediction": True
                    }
    # 7. Build the final arrivals list
    arrivals = []
    for trip_id, info in arrival_map.items():
        # Use the prediction if available, otherwise the scheduled time
        if info["has_prediction"] and info["pred_time"]:
            eta_mins = round((info["pred_time"] - now) / 60)
            delay_mins = round((info["scheduled_ts"] - info["pred_time"]) / 60)
        else:
            eta_mins = round((info["scheduled_ts"] - now) / 60)
            delay_mins = 0
        # Skip entries whose route could not be resolved
        if info["route"] == "Unknown":
            continue
        # Find the actual bus for fullness (if it's on the road)
        bus = vehicles.get(trip_id)
        arrivals.append({
            "route": info["route"],
            "destination": info["destination"],
            "eta_mins": eta_mins,
            "delay_mins": delay_mins,
            "fullness": translate_occupancy(bus['occupancy']) if bus else "Unknown",
            "vehicle_id": bus['id'] if bus else None
        })
    arrivals.sort(key=lambda x: x['eta_mins'])
    return {
        "stop_name": stop_name,
        "stop_code": stop_code,
        "location": {
            "lat": stop_lat,
            "lon": stop_lon
        },
        "arrivals": arrivals
    }
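
# Assumed path; original decorator not shown in the source.
@app.get("/alerts")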
async def get_all_alerts():
    """
    Returns every active service alert for the entire TTC network.
    """
    data = await ttc_cache.get_data()
    alerts = data.get("alerts", {})
    return {
        "timestamp": datetime.now(timezone.utc).timestamp(),
        "count": len(alerts),
        "alerts": alerts
    }
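
# Assumed path; original decorator not shown in the source.
@app.get("/alerts/{route_id}")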
async def get_alerts_for_route(route_id: str):
    data = await ttc_cache.get_data()
    alerts = data.get("alerts", {})
    route_alerts = alerts.get(route_id, [])
    if not route_alerts:
        return {
            "route_id": route_id,
            "count": 0,
            "alerts": "No alerts"
        }
    return {
        "route_id": route_id,
        "count": len(route_alerts),
        "alerts": route_alerts
    }
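
# Assumed path; lat/lon arrive as query parameters. Original decorator not shown in the source.
@app.get("/nearby")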
async def get_nearby_context(lat: float, lon: float):
    # Bounding-box approximation for a ~1 km radius (fast, no Haversine needed).
    # 1° of latitude ≈ 111 km everywhere, so 1 km ≈ 0.009°.
    # At Toronto's latitude (~43.6°N), 1° of longitude ≈ 80.5 km, so 1 km ≈ 0.0124°;
    # 0.0144° is kept as a slightly generous margin (~1.15 km).
    lat_range = 0.009
    lon_range = 0.0144
    # 1. Find all stops within the bounding box
    query_stops = """
        SELECT stop_id, stop_code, stop_name, stop_lat, stop_lon
        FROM stops
        WHERE stop_lat BETWEEN ? AND ?
          AND stop_lon BETWEEN ? AND ?
    """
    stops = db.execute(query_stops, [lat - lat_range, lat + lat_range, lon - lon_range, lon + lon_range]).fetchall()
    stop_ids = [str(s[0]) for s in stops]
    if not stop_ids:
        return {"stops": [], "routes": []}
    # 2. Find all unique routes serving these specific stops
    placeholders = ','.join(['?'] * len(stop_ids))
    query_routes = f"""
        SELECT DISTINCT r.route_id, r.route_short_name, r.route_long_name, r.route_color
        FROM routes r
        JOIN trips t ON r.route_id = t.route_id
        JOIN stop_times st ON t.trip_id = st.trip_id
        WHERE CAST(st.stop_id AS VARCHAR) IN ({placeholders})
    """
    routes = db.execute(query_routes, stop_ids).fetchall()
    return {
        "stops": [
            {"id": s[0], "code": s[1], "name": s[2], "lat": s[3], "lon": s[4]}
            for s in stops
        ],
        "routes": [
            {"id": r[0], "short_name": r[1], "long_name": r[2], "color": f"#{r[3]}"}
            for r in routes
        ]
    }
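
# Assumed path; original decorator not shown in the source.
@app.get("/shapes/{shape_id}")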
async def get_route_shape(shape_id: str):
    """
    Returns the ordered list of lat/lon coordinates for a specific route shape.
    """
    try:
        # Order by sequence so the polyline draws correctly
        query = """
            SELECT shape_pt_lat, shape_pt_lon
            FROM shapes
            WHERE CAST(shape_id AS VARCHAR) = ?
            ORDER BY shape_pt_sequence ASC
        """
        results = db.execute(query, [shape_id]).fetchall()
        # Format for Leaflet: [[lat, lon], [lat, lon], ...]
        path = [[r[0], r[1]] for r in results]
        return {
            "status": "success",
            "shape_id": shape_id,
            "coordinates": path
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
| if __name__ == "__main__": | |
| import uvicorn # type: ignore | |
| # Start the server | |
| uvicorn.run(app, host="0.0.0.0", port=7860) |