# WMB2Backened/src/app.py
# Uploaded by 42Cummer ("Upload app.py", commit 25ed060, verified).
from datetime import datetime # type: ignore
import sys # type: ignore
from pathlib import Path
# Add parent directory to path to allow imports from api/
sys.path.insert(0, str(Path(__file__).parent.parent))
# Add src directory to path to allow imports from same directory
sys.path.insert(0, str(Path(__file__).parent))
from api.bus_cache import AsyncBusCache # type: ignore
from api.utils import hms_to_seconds, get_service_day_start_ts, translate_occupancy # type: ignore
from db_manager import init_db # type: ignore
from dotenv import load_dotenv # type: ignore
from fastapi import FastAPI, HTTPException # type: ignore
from fastapi.middleware.cors import CORSMiddleware # type: ignore
# Load variables from a local .env file into the process environment first,
# in case the cache or database setup below reads any of them.
load_dotenv()
# Shared cache of the live TTC feed used by every endpoint below; ttl=20 is
# passed straight to AsyncBusCache (presumably seconds between refreshes —
# confirm in api/bus_cache.py).
ttc_cache = AsyncBusCache(ttl=20)
# Initialize database connection globally
# The single connection returned by init_db() is shared by all endpoints
# (the queries below are written for DuckDB).
db = init_db()
app = FastAPI(title="WheresMyBus v2.0 API")
# Setup CORS for your React frontend
# Only the deployed Vercel origin is listed; requests from any other origin
# (e.g. a local dev server) are not allowed by this configuration.
app.add_middleware(
CORSMiddleware,
allow_origins=["https://wheresmybus.vercel.app"],
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def health_check():
    """Liveness probe: answer with a fixed string whenever the app is serving."""
    status_message = "backend is running"
    return status_message
@app.get("/api/vehicles")
async def get_vehicles():
    """Return every vehicle currently held in the live-feed cache.

    The payload carries a ``status`` flag, the vehicle count, and the vehicle
    records exactly as the cache stores them (empty list if none).
    """
    snapshot = await ttc_cache.get_data()
    fleet = snapshot.get("vehicles", [])
    payload = {"status": "success"}
    payload["count"] = len(fleet)
    payload["vehicles"] = fleet
    return payload
@app.get("/api/routes")
async def get_all_routes():
    """
    List every TTC route with its number, long name and display colours.

    Rows come from the ``routes`` table. Route numbers made only of digits are
    ordered numerically; any other value gets the sort key 999, with
    ``route_short_name`` as the tie-breaker. NULL colours are replaced by
    ``FF0000`` (background) and ``FFFFFF`` (text), and both colours are
    returned with a leading ``#``.

    Any exception is reported as ``{"status": "error", "message": ...}``
    instead of propagating.
    """
    # Defaults for missing colours: TTC red background, white text.
    routes_sql = """
SELECT
route_id,
route_short_name,
route_long_name,
COALESCE(route_color, 'FF0000') as route_color,
COALESCE(route_text_color, 'FFFFFF') as route_text_color
FROM routes
ORDER BY
CASE
WHEN CAST(route_short_name AS VARCHAR) ~ '^[0-9]+$' THEN CAST(route_short_name AS INTEGER)
ELSE 999
END,
route_short_name;
"""
    try:
        rows = db.execute(routes_sql).fetchall()
        route_list = []
        for route_id, short_name, long_name, color, text_color in rows:
            route_list.append(
                {
                    "id": route_id,
                    "number": short_name,
                    "name": long_name,
                    "color": f"#{color}",
                    "text_color": f"#{text_color}",
                }
            )
        return {"status": "success", "count": len(route_list), "routes": route_list}
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
@app.get("/api/routes/{route_id}")
async def get_route_view(route_id: str):
    """
    Return the live vehicles and the full stop list for one route.

    Args:
        route_id: route identifier from the URL. It is matched as a string
            against each cached vehicle's ``route`` and against
            ``routes.route_id``.

    Returns:
        ``{"route", "count", "vehicles", "stops"}``.

        ``stops`` holds every distinct stop visited by any trip of the route,
        ordered by name.

        Each vehicle is enriched with:
        - its trip headsign (``"Not in Schedule"`` when the trip is absent
          from the static data);
        - its shape id;
        - an occupancy label;
        - ``delay_mins``: scheduled minus predicted time at its next stop,
          in minutes, or 0 when that cannot be computed.
    """
    data = await ttc_cache.get_data()
    all_buses = data.get("vehicles", [])
    # Compare as strings (as get_vehicle_view does for vehicle ids) so a
    # non-string route value in the cache still matches the path parameter.
    route_buses = [v for v in all_buses if str(v['route']) == route_id]

    # Get all stops for this route
    stops_query = """
SELECT DISTINCT
s.stop_id,
s.stop_code,
s.stop_name,
s.stop_lat,
s.stop_lon
FROM routes r
JOIN trips t ON r.route_id = t.route_id
JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
JOIN stops s ON CAST(st.stop_id AS VARCHAR) = CAST(s.stop_id AS VARCHAR)
WHERE CAST(r.route_id AS VARCHAR) = ?
ORDER BY s.stop_name
"""
    stops_results = db.execute(stops_query, [str(route_id)]).fetchall()
    stops = [
        {
            "stop_id": str(r[0]),
            "stop_code": str(r[1]) if r[1] else None,
            "stop_name": r[2],
            "location": {"lat": r[3], "lon": r[4]},
        }
        for r in stops_results
    ]

    if not route_buses:
        # Same keys as the populated response below, so clients can rely on "count".
        return {"route": route_id, "count": 0, "vehicles": [], "stops": stops}

    trip_ids = [str(v['trip_id']) for v in route_buses]
    placeholders = ','.join(['?'] * len(trip_ids))
    # One query covers every active trip. It returns the scheduled time per
    # stop plus the trip's headsign and shape_id. Only "?" markers are
    # interpolated into the SQL; the trip ids are bound as parameters.
    query = f"""
SELECT
CAST(st.trip_id AS VARCHAR),
CAST(st.stop_id AS VARCHAR),
st.arrival_time as scheduled_time,
t.trip_headsign,
t.shape_id
FROM stop_times st
JOIN trips t ON CAST(st.trip_id AS VARCHAR) = CAST(t.trip_id AS VARCHAR)
WHERE CAST(st.trip_id AS VARCHAR) IN ({placeholders})
"""
    db_rows = db.execute(query, trip_ids).fetchall()
    # (trip_id, stop_id) -> scheduled "HH:MM:SS" string
    schedule_map = {(r[0], r[1]): r[2] for r in db_rows}
    # trip_id -> headsign and shape_id (identical on every row of a trip)
    name_map = {r[0]: {"headsign": r[3], "shape_id": r[4]} for r in db_rows}

    service_day_ts = get_service_day_start_ts()
    enriched = []
    for bus in route_buses:
        raw_delay_mins = 0
        pred_time = bus.get('predicted_time')
        stop_id = bus.get('next_stop_id')
        if pred_time and stop_id:
            sched_hms = schedule_map.get((str(bus['trip_id']), str(stop_id)))
            if sched_hms:
                # GTFS hours can exceed 24 for trips running past midnight, so
                # whole extra days are added on top of the service-day start.
                # NOTE(review): this assumes hms_to_seconds() does not itself
                # count hours past 24; confirm in api/utils.py.
                extra_days = int(sched_hms.split(':')[0]) // 24
                plan_ts = service_day_ts + (extra_days * 86400) + hms_to_seconds(sched_hms)
                raw_delay_mins = round((plan_ts - pred_time) / 60)
        trip_info = name_map.get(str(bus['trip_id']), {})
        enriched.append({
            "number": bus['id'],
            "name": trip_info.get("headsign", "Not in Schedule"),
            "location": {"lat": bus['lat'], "lon": bus['lon']},
            "delay_mins": raw_delay_mins,
            "fullness": translate_occupancy(bus['occupancy']),
            "shape_id": trip_info.get("shape_id"),
        })

    return {
        "route": route_id,
        "count": len(enriched),
        "vehicles": enriched,
        "stops": stops,
    }
@app.get("/api/vehicles/{vehicle_id}")
async def get_vehicle_view(vehicle_id: str):
    """
    Return the live position and schedule context of one vehicle.

    The vehicle is looked up in the live cache. Its trip is then joined
    against the static schedule to get the headsign, shape id and route
    colour. When the vehicle reports a next stop that its trip visits, the
    delay at that stop is ``scheduled - predicted`` in minutes; otherwise it
    is 0.

    Args:
        vehicle_id: vehicle number, compared against ``str(v['id'])``.

    Returns:
        Dict with these keys:
        - ``vehicle_number``, ``route_id``;
        - ``route_color``: hex without ``#``, default ``FF0000``;
        - ``name``: headsign, or ``"Not in Schedule"``;
        - ``location``, ``delay_mins``, ``fullness``, ``trip_id``, ``shape_id``.

    Raises:
        HTTPException: 404 if the vehicle is not in the live cache.
    """
    # 1. Pull latest from cache and find this specific vehicle
    data = await ttc_cache.get_data()
    vehicles = data.get("vehicles", [])
    bus = next((v for v in vehicles if str(v['id']) == vehicle_id), None)
    if not bus:
        raise HTTPException(status_code=404, detail="Vehicle not active or not found")

    trip_id = str(bus['trip_id'])
    next_stop_id = bus.get('next_stop_id')
    predicted_time = bus.get('predicted_time')

    # 2. Defaults used when the trip is not in the static schedule
    destination = "Not in Schedule"
    shape_id = None
    route_color = "FF0000"
    delay_mins = 0

    row = None
    if next_stop_id:
        # Trip details plus the scheduled time at the vehicle's next stop
        query = """
SELECT
t.trip_headsign,
st.arrival_time as scheduled_time,
t.shape_id,
r.route_color
FROM trips t
JOIN stop_times st ON CAST(t.trip_id AS VARCHAR) = CAST(st.trip_id AS VARCHAR)
JOIN routes r ON t.route_id = r.route_id
WHERE CAST(t.trip_id AS VARCHAR) = ?
AND CAST(st.stop_id AS VARCHAR) = ?
LIMIT 1
"""
        row = db.execute(query, [trip_id, str(next_stop_id)]).fetchone()
        if row:
            destination, scheduled_hms, shape_id, r_color = row
            route_color = r_color if r_color else "FF0000"
            # A NULL scheduled time would crash .split(); leave delay at 0 then.
            if predicted_time and scheduled_hms:
                service_day_ts = get_service_day_start_ts()
                # GTFS hours can exceed 24 for after-midnight trips.
                # NOTE(review): assumes hms_to_seconds() does not itself count
                # hours past 24; confirm in api/utils.py.
                extra_days = int(scheduled_hms.split(':')[0]) // 24
                plan_ts = service_day_ts + (extra_days * 86400) + hms_to_seconds(scheduled_hms)
                delay_mins = round((plan_ts - predicted_time) / 60)

    if not row:
        # Either no next stop is known, or the static schedule has no row for
        # this (trip, stop) pair. Still resolve headsign, shape and colour from
        # the trip alone.
        query = """
SELECT t.trip_headsign, t.shape_id, r.route_color
FROM trips t
JOIN routes r ON t.route_id = r.route_id
WHERE CAST(t.trip_id AS VARCHAR) = ?
LIMIT 1
"""
        row = db.execute(query, [trip_id]).fetchone()
        if row:
            destination, shape_id, r_color = row
            route_color = r_color if r_color else "FF0000"

    return {
        "vehicle_number": vehicle_id,
        "route_id": bus['route'],
        "route_color": route_color,
        "name": destination,
        "location": {
            "lat": bus['lat'],
            "lon": bus['lon'],
        },
        "delay_mins": delay_mins,
        "fullness": translate_occupancy(bus['occupancy']),
        "trip_id": trip_id,
        "shape_id": shape_id,  # lets the frontend fetch /api/shapes/{shape_id}
    }
@app.get("/api/stop/{stop_code}")
async def get_stop_view(stop_code: str):
    """
    Return a stop's details and its arrivals over the next two hours.

    Today's scheduled trips at the stop are merged, by trip id, with the
    real-time predictions held in the cache.

    Each arrival carries:
    - an ETA, from the prediction when one exists, otherwise from the
      schedule;
    - a delay in minutes (``scheduled - predicted``; 0 without a prediction);
    - the occupancy and vehicle id of the bus on that trip, if it is in the
      cache.

    Predicted trips whose route cannot be resolved from the database are
    dropped.

    Args:
        stop_code: the stop's public code (matched against ``stops.stop_code``).

    Returns:
        ``{"stop_name", "stop_code", "location", "arrivals"}``, with arrivals
        sorted by ETA, or ``{"error": "Stop code not found"}`` for an unknown
        code.
    """
    from datetime import timezone

    # 1. Translate Pole Number (public stop code) to the database stop_id
    stop_info = db.execute(
        "SELECT stop_id, stop_name, stop_lat, stop_lon FROM stops WHERE CAST(stop_code AS VARCHAR) = ? LIMIT 1",
        [str(stop_code)],
    ).fetchone()
    if not stop_info:
        return {"error": "Stop code not found"}
    target_id = str(stop_info[0])
    stop_name, stop_lat, stop_lon = stop_info[1], stop_info[2], stop_info[3]

    # 2. Live data: vehicles indexed by trip id, plus per-trip stop predictions
    cached_data = await ttc_cache.get_data()
    vehicles_list = cached_data.get("vehicles", [])
    predictions = cached_data.get("predictions", {})
    vehicles = {str(v['trip_id']): v for v in vehicles_list}

    now = datetime.now(timezone.utc).timestamp()
    service_day_ts = get_service_day_start_ts()
    two_hours_out = now + 7200

    def to_service_ts(hms):
        """Convert a GTFS ``HH:MM:SS`` time to a Unix timestamp, or None if missing."""
        if not hms:
            return None
        # GTFS hours can exceed 24 for after-midnight trips; add whole days.
        # NOTE(review): assumes hms_to_seconds() does not itself count hours
        # past 24; confirm in api/utils.py.
        extra_days = int(hms.split(':')[0]) // 24
        return service_day_ts + (extra_days * 86400) + hms_to_seconds(hms)

    # 3. Pick today's service_id from the Toronto-local weekday.
    # Service ID 1 = Weekdays (Mon-Fri), 2 = Saturday, 3 = Sunday.
    # NOTE(review): holidays are not detected (a holiday weekday still maps to 1),
    # and the GTFS calendar tables are not consulted.
    try:
        from zoneinfo import ZoneInfo
        eastern_tz = ZoneInfo("America/Toronto")
    except ImportError:
        import pytz
        eastern_tz = pytz.timezone("America/Toronto")
    weekday = datetime.now(timezone.utc).astimezone(eastern_tz).weekday()  # 0=Mon, 6=Sun
    if weekday < 5:
        today_service_id = 1
    elif weekday == 5:
        today_service_id = 2
    else:
        today_service_id = 3

    # 4. Scheduled arrivals at this stop that fall inside [now, now + 2h]
    arrival_map = {}
    schedule_query = """
SELECT
CAST(t.trip_id AS VARCHAR) as trip_id,
t.trip_headsign,
COALESCE(st.departure_time, st.arrival_time) as scheduled_time,
r.route_short_name
FROM stop_times st
JOIN trips t ON CAST(st.trip_id AS VARCHAR) = CAST(t.trip_id AS VARCHAR)
JOIN routes r ON t.route_id = r.route_id
WHERE CAST(st.stop_id AS VARCHAR) = ?
AND CAST(t.service_id AS INTEGER) = ?
"""
    schedule_rows = db.execute(schedule_query, [target_id, today_service_id]).fetchall()
    for trip_id, destination, scheduled_hms, route in schedule_rows:
        scheduled_ts = to_service_ts(scheduled_hms)
        # Rows without a time (allowed in GTFS for non-timepoints) are skipped.
        if scheduled_ts is not None and now <= scheduled_ts <= two_hours_out:
            arrival_map[trip_id] = {
                "trip_id": trip_id,
                "route": route,
                "destination": destination,
                "scheduled_ts": scheduled_ts,
                "pred_time": None,
                "has_prediction": False,
            }

    # 5. Merge real-time predictions, including trips absent from today's schedule
    trip_query = """
SELECT t.trip_headsign, r.route_short_name
FROM trips t
JOIN routes r ON t.route_id = r.route_id
WHERE CAST(t.trip_id AS VARCHAR) = ?
LIMIT 1
"""
    scheduled_query = """
SELECT COALESCE(st.departure_time, st.arrival_time) as scheduled_time
FROM stop_times st
WHERE CAST(st.trip_id AS VARCHAR) = ? AND CAST(st.stop_id AS VARCHAR) = ?
LIMIT 1
"""
    for raw_trip_id, itinerary in predictions.items():
        if target_id not in itinerary:
            continue
        pred_time = itinerary[target_id]
        if pred_time is None or not (now <= pred_time <= two_hours_out):
            continue
        # arrival_map and vehicles are keyed by str trip ids; normalise to match.
        trip_id = str(raw_trip_id)
        if trip_id in arrival_map:
            arrival_map[trip_id]["pred_time"] = pred_time
            arrival_map[trip_id]["has_prediction"] = True
            continue

        # Real-time-only prediction: resolve route/destination from the database.
        row = db.execute(trip_query, [trip_id]).fetchone()
        if not row:
            # Route unknown: such arrivals are not shown, so skip further lookups.
            continue
        destination, route = row[0], row[1]

        sched_row = db.execute(scheduled_query, [trip_id, target_id]).fetchone()
        scheduled_ts = to_service_ts(sched_row[0]) if sched_row else None
        if scheduled_ts is None:
            scheduled_ts = pred_time  # Use prediction as fallback
        arrival_map[trip_id] = {
            "trip_id": trip_id,
            "route": route,
            "destination": destination,
            "scheduled_ts": scheduled_ts,
            "pred_time": pred_time,
            "has_prediction": True,
        }

    # 6. Build the final arrivals list, preferring predictions over schedule
    arrivals = []
    for trip_id, info in arrival_map.items():
        if info["has_prediction"] and info["pred_time"]:
            eta_mins = round((info["pred_time"] - now) / 60)
            delay_mins = round((info["scheduled_ts"] - info["pred_time"]) / 60)
        else:
            eta_mins = round((info["scheduled_ts"] - now) / 60)
            delay_mins = 0
        # The bus serving this trip, if it is currently in the live cache
        bus = vehicles.get(trip_id)
        arrivals.append({
            "route": info["route"],
            "destination": info["destination"],
            "eta_mins": eta_mins,
            "delay_mins": delay_mins,
            "fullness": translate_occupancy(bus['occupancy']) if bus else "Unknown",
            "vehicle_id": bus['id'] if bus else None,
        })
    arrivals.sort(key=lambda x: x['eta_mins'])

    return {
        "stop_name": stop_name,
        "stop_code": stop_code,
        "location": {
            "lat": stop_lat,
            "lon": stop_lon,
        },
        "arrivals": arrivals,
    }
@app.get("/api/alerts")
async def get_all_alerts():
    """
    Returns every active service alert for the entire TTC network.

    The alerts structure is returned as stored in the cache, together with
    the current UTC Unix timestamp. ``count`` is ``len()`` of that structure.
    """
    from datetime import timezone
    data = await ttc_cache.get_data()
    # Use .get() like get_alerts_for_route does, so a cache payload without an
    # "alerts" key gives an empty result instead of a KeyError (HTTP 500).
    alerts = data.get("alerts", {})
    # NOTE(review): get_alerts_for_route treats alerts as a mapping of
    # route_id -> list. If so, this counts routes with alerts rather than
    # individual alerts; confirm the intended meaning of "count".
    return {
        "timestamp": datetime.now(timezone.utc).timestamp(),
        "count": len(alerts),
        "alerts": alerts,
    }
@app.get("/api/alerts/{route_id}")
async def get_alerts_for_route(route_id: str):
    """Return the cached service alerts for a single route.

    When the route has no alerts, ``count`` is 0 and ``alerts`` is the
    string ``"No alerts"`` rather than an empty list.
    """
    cache_payload = await ttc_cache.get_data()
    matching = cache_payload.get("alerts", {}).get(route_id, [])
    response = {"route_id": route_id}
    if matching:
        response["count"] = len(matching)
        response["alerts"] = matching
    else:
        response["count"] = 0
        response["alerts"] = "No alerts"
    return response
@app.get("/api/nearby")
async def get_nearby_context(lat: float, lon: float):
    """
    Return stops in a ~1 km box around (lat, lon) and the routes serving them.

    The search area is a bounding box, not a circle, so stops near its
    corners can be up to ~1.4 km away. This avoids computing great-circle
    distances.

    Args:
        lat: latitude in degrees.
        lon: longitude in degrees.

    Returns:
        ``{"stops": [...], "routes": [...]}``. Both lists are empty when no
        stop falls inside the box.
    """
    import math

    # One degree of latitude is ~111 km everywhere, so ~1 km is 0.009 degrees.
    lat_range = 0.009
    # One degree of longitude shrinks with cos(latitude). At Toronto (~43.6°N)
    # it is ~80.6 km, not 55.4 km, so the old fixed 0.0144° spanned ~1.16 km.
    # Derive the span from the query latitude instead. The floor keeps the
    # box finite near the poles.
    lon_range = lat_range / max(math.cos(math.radians(lat)), 0.01)

    # 1. Find all stops within the bounding box
    query_stops = """
SELECT stop_id, stop_code, stop_name, stop_lat, stop_lon
FROM stops
WHERE stop_lat BETWEEN ? AND ?
AND stop_lon BETWEEN ? AND ?
"""
    stops = db.execute(
        query_stops,
        [lat - lat_range, lat + lat_range, lon - lon_range, lon + lon_range],
    ).fetchall()
    stop_ids = [str(s[0]) for s in stops]
    if not stop_ids:
        return {"stops": [], "routes": []}

    # 2. Find all unique routes serving these stops. Only "?" markers are
    # interpolated into the SQL; the stop ids are bound as parameters.
    placeholders = ','.join(['?'] * len(stop_ids))
    query_routes = f"""
SELECT DISTINCT r.route_id, r.route_short_name, r.route_long_name, r.route_color
FROM routes r
JOIN trips t ON r.route_id = t.route_id
JOIN stop_times st ON t.trip_id = st.trip_id
WHERE CAST(st.stop_id AS VARCHAR) IN ({placeholders})
"""
    routes = db.execute(query_routes, stop_ids).fetchall()

    return {
        "stops": [
            {"id": s[0], "code": s[1], "name": s[2], "lat": s[3], "lon": s[4]}
            for s in stops
        ],
        "routes": [
            {"id": r[0], "short_name": r[1], "long_name": r[2], "color": f"#{r[3]}"}
            for r in routes
        ],
    }
@app.get("/api/shapes/{shape_id}")
async def get_route_shape(shape_id: str):
    """
    Return the ordered points of one GTFS shape as ``[[lat, lon], ...]``.

    Points are ordered by ``shape_pt_sequence`` so the polyline draws in
    order; the nested-pair format is what Leaflet consumes. Any failure is
    reported as a ``{"status": "error", ...}`` payload rather than raised.
    """
    shape_sql = """
SELECT shape_pt_lat, shape_pt_lon
FROM shapes
WHERE CAST(shape_id AS VARCHAR) = ?
ORDER BY shape_pt_sequence ASC
"""
    try:
        points = db.execute(shape_sql, [shape_id]).fetchall()
        coordinates = [[pt_lat, pt_lon] for pt_lat, pt_lon in points]
        return {"status": "success", "shape_id": shape_id, "coordinates": coordinates}
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces,
    # port 7860.
    import uvicorn  # type: ignore

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)