"""Authentication routes: username/password accounts and Google OAuth sign-in."""

import os
import random
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import urlencode

import httpx
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.responses import RedirectResponse
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy import or_
from sqlalchemy.orm import Session

from database import get_db
from models import User

# Falls back to a random per-process key when SECRET_KEY is unset *or empty*,
# so tokens are never signed with an empty secret. With the fallback, tokens
# stop validating after a restart.
SECRET_KEY = os.getenv("SECRET_KEY") or secrets.token_hex(32)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24 * 7  # 7 days

GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID", "")
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET", "")
GOOGLE_REDIRECT_URI = os.getenv(
    "GOOGLE_REDIRECT_URI", "http://localhost:7860/api/auth/google/callback"
)

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
router = APIRouter()

AVATAR_COLORS = [
    "#6366f1",
    "#8b5cf6",
    "#ec4899",
    "#f43f5e",
    "#f97316",
    "#eab308",
    "#22c55e",
    "#06b6d4",
    "#3b82f6",
]


class RegisterRequest(BaseModel):
    """Body of POST /register (password flow or Google setup flow)."""

    username: str
    display_name: str
    password: str = ""
    email: Optional[str] = None
    public_key: Optional[str] = None
    # When present, the token is validated instead of requiring a password.
    google_setup_token: Optional[str] = None


class LoginRequest(BaseModel):
    """Body of POST /login."""

    username: str
    password: str


class UpdateProfileRequest(BaseModel):
    """Body of PUT /profile; fields left unset (or empty) are not changed."""

    display_name: Optional[str] = None
    public_key: Optional[str] = None


class SetPasswordRequest(BaseModel):
    """Body of POST /set-password."""

    password: str


def hash_password(password: str) -> str:
    """Return the hash of *password* produced by the module's CryptContext."""
    return pwd_context.hash(password)


def verify_password(plain: str, hashed: str) -> bool:
    """Return True when *plain* matches the stored hash *hashed*."""
    return pwd_context.verify(plain, hashed)


def create_token(user_id: int) -> str:
    """Return a signed JWT whose ``sub`` is *user_id*, valid for ACCESS_TOKEN_EXPIRE_HOURS."""
    expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    return jwt.encode({"sub": str(user_id), "exp": expire}, SECRET_KEY, algorithm=ALGORITHM)


def decode_token(token: str) -> Optional[int]:
    """Return the user id stored in *token*, or None if the token is unusable.

    None is returned for tokens that fail JWT validation and also for validly
    signed tokens without an integer ``sub`` claim (for example the
    ``google_setup`` tokens issued by this module, which share the same key).
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        return int(payload["sub"])
    except (JWTError, KeyError, TypeError, ValueError):
        return None


def get_current_user(token: str, db: Session) -> Optional[User]:
    """Return the User identified by *token*, or None if the token or user is invalid."""
    user_id = decode_token(token)
    if user_id is None:
        return None
    return db.query(User).filter(User.id == user_id).first()


async def get_current_user_ws(token: str, db: Session) -> Optional[User]:
    """Async wrapper around :func:`get_current_user`."""
    return get_current_user(token, db)


def get_current_user_from_header(request: Request, db: Session = Depends(get_db)) -> User:
    """Return the user authenticated by the request's ``Authorization: Bearer`` header.

    Raises:
        HTTPException: 401 when the header is missing/malformed or the token
            does not resolve to a user.
    """
    prefix = "Bearer "
    auth = request.headers.get("Authorization", "")
    if not auth.startswith(prefix):
        raise HTTPException(status_code=401, detail="Not authenticated")
    user = get_current_user(auth[len(prefix):], db)
    if not user:
        raise HTTPException(status_code=401, detail="Invalid token")
    return user


@router.post("/register")
def register(req: RegisterRequest, db: Session = Depends(get_db)):
    """Create a new account and return ``{"token", "user"}``.

    Two flows are supported: the Google setup flow (a ``google_setup_token``
    is validated instead of a password) and the plain password flow.

    Raises:
        HTTPException: 400 for an invalid/expired setup token, a missing
            password or username, or a conflicting existing account.
    """
    username = req.username.lower().strip()
    if not username:
        raise HTTPException(status_code=400, detail="Username is required")

    if req.google_setup_token:
        # Google setup flow: validate setup token instead of password.
        try:
            payload = jwt.decode(req.google_setup_token, SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError:
            raise HTTPException(status_code=400, detail="Invalid or expired setup token")
        google_id = payload.get("google_id")
        if payload.get("type") != "google_setup" or not google_id:
            raise HTTPException(status_code=400, detail="Invalid setup token")
        google_email = payload.get("email") or None

        # Only compare emails when one is present: comparing the column with
        # None becomes "IS NULL" and would match unrelated email-less accounts.
        conditions = [User.google_id == google_id]
        if google_email:
            conditions.append(User.email == google_email)
        if db.query(User).filter(or_(*conditions)).first():
            raise HTTPException(status_code=400, detail="Account already exists")

        google_id_attr = google_id
        email_attr = google_email
        password_hash = None
    else:
        if not req.password:
            raise HTTPException(status_code=400, detail="Password is required")
        if req.email and db.query(User).filter(User.email == req.email).first():
            raise HTTPException(status_code=400, detail="Email already in use")
        google_id_attr = None
        email_attr = req.email
        password_hash = hash_password(req.password)

    if db.query(User).filter(User.username == username).first():
        raise HTTPException(status_code=400, detail="Username already taken")

    user = User(
        username=username,
        display_name=req.display_name,
        email=email_attr,
        hashed_password=password_hash,
        google_id=google_id_attr,
        public_key=req.public_key,
        avatar_color=random.choice(AVATAR_COLORS),
    )
    db.add(user)
    db.commit()
    db.refresh(user)
    return {"token": create_token(user.id), "user": user_to_dict(user)}


@router.post("/login")
def login(req: LoginRequest, db: Session = Depends(get_db)):
    """Authenticate by username and password and return ``{"token", "user"}``.

    Raises:
        HTTPException: 401 when the user is unknown, has no password set, or
            the password does not match.
    """
    username = req.username.lower().strip()
    user = db.query(User).filter(User.username == username).first()
    if not user or not user.hashed_password or not verify_password(req.password, user.hashed_password):
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return {"token": create_token(user.id), "user": user_to_dict(user)}


def _google_redirect_uri(request: Request) -> str:
    """Return the OAuth callback URL to register with Google for this request.

    The GOOGLE_REDIRECT_URI environment variable wins when set; otherwise the
    URL is derived from the X-Forwarded-Proto / X-Forwarded-Host headers,
    falling back to the request URL's scheme and hostname.
    """
    override = os.getenv("GOOGLE_REDIRECT_URI")
    if override:
        return override
    scheme = request.headers.get("X-Forwarded-Proto", request.url.scheme)
    host = request.headers.get("X-Forwarded-Host", request.url.hostname)
    port = request.url.port
    # NOTE(review): the port always comes from the request URL, even when the
    # host comes from X-Forwarded-Host - confirm this is right behind a proxy.
    netloc = f"{host}:{port}" if port and port not in (80, 443) else host
    return f"{scheme}://{netloc}/api/auth/google/callback"


@router.get("/google")
def google_auth(request: Request):
    """Redirect the browser to Google's OAuth consent screen.

    Raises:
        HTTPException: 400 when GOOGLE_CLIENT_ID is not configured.
    """
    if not GOOGLE_CLIENT_ID:
        raise HTTPException(status_code=400, detail="Google OAuth not configured")
    # urlencode escapes each value, so a redirect URI or client id containing
    # reserved characters cannot corrupt the query string.
    params = urlencode(
        {
            "client_id": GOOGLE_CLIENT_ID,
            "redirect_uri": _google_redirect_uri(request),
            "response_type": "code",
            "scope": "openid email profile",
        }
    )
    return RedirectResponse(f"https://accounts.google.com/o/oauth2/v2/auth?{params}")


def _json_object(response) -> dict:
    """Return the JSON object in *response*, or {} if the body is not a JSON object."""
    try:
        data = response.json()
    except ValueError:
        return {}
    return data if isinstance(data, dict) else {}


@router.get("/google/callback")
async def google_callback(code: str, request: Request, db: Session = Depends(get_db)):
    """Finish Google sign-in: exchange *code*, then log in, link, or start setup.

    An existing user (matched by Google id, then by email) is redirected to
    ``/?token=...``; an unknown Google account is redirected to
    ``/?google_setup=...`` with a 15-minute setup token.

    Raises:
        HTTPException: 400 when OAuth is not configured or Google does not
            return an access token / account id; 502 when Google cannot be
            reached.
    """
    if not GOOGLE_CLIENT_ID:
        raise HTTPException(status_code=400, detail="Google OAuth not configured")
    redirect_uri = _google_redirect_uri(request)

    try:
        async with httpx.AsyncClient() as client:
            token_resp = await client.post(
                "https://oauth2.googleapis.com/token",
                data={
                    "code": code,
                    "client_id": GOOGLE_CLIENT_ID,
                    "client_secret": GOOGLE_CLIENT_SECRET,
                    "redirect_uri": redirect_uri,
                    "grant_type": "authorization_code",
                },
            )
            access_token = _json_object(token_resp).get("access_token")
            if not access_token:
                raise HTTPException(status_code=400, detail="Google authentication failed")
            user_resp = await client.get(
                "https://www.googleapis.com/oauth2/v2/userinfo",
                headers={"Authorization": f"Bearer {access_token}"},
            )
            guser = _json_object(user_resp)
    except httpx.HTTPError:
        raise HTTPException(status_code=502, detail="Could not reach Google")

    google_id = guser.get("id")
    if not google_id:
        # Without this guard the lookup below would compare google_id with
        # None ("IS NULL") and sign in as an arbitrary non-Google account.
        raise HTTPException(status_code=400, detail="Google authentication failed")
    email = guser.get("email")
    name = guser.get("name") or (email.split("@")[0] if email else "User")

    user = db.query(User).filter(User.google_id == google_id).first()
    if not user:
        # Link by email only when Google actually supplied one.
        if email:
            user = db.query(User).filter(User.email == email).first()
        if user:
            user.google_id = google_id
        else:
            # New Google user: send to frontend to pick username/display name
            setup_token = jwt.encode(
                {
                    "type": "google_setup",
                    "google_id": google_id,
                    "email": email or "",
                    "name": name,
                    "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
                },
                SECRET_KEY,
                algorithm=ALGORITHM,
            )
            return RedirectResponse(f"/?google_setup={setup_token}")

    db.commit()
    db.refresh(user)
    token = create_token(user.id)
    return RedirectResponse(f"/?token={token}&new_user={not user.public_key}")


@router.get("/me")
def get_me(request: Request, db: Session = Depends(get_db)):
    """Return the authenticated user's profile as a dict."""
    user = get_current_user_from_header(request, db)
    return user_to_dict(user)


@router.put("/profile")
def update_profile(req: UpdateProfileRequest, request: Request, db: Session = Depends(get_db)):
    """Update display name and/or public key of the authenticated user.

    Falsy fields (None or empty string) leave the stored value untouched.
    """
    user = get_current_user_from_header(request, db)
    if req.display_name:
        user.display_name = req.display_name
    if req.public_key:
        user.public_key = req.public_key
    db.commit()
    db.refresh(user)
    return user_to_dict(user)


@router.post("/set-password")
def set_password(req: SetPasswordRequest, request: Request, db: Session = Depends(get_db)):
    """Set a password for an authenticated account that does not have one yet.

    Raises:
        HTTPException: 400 when a password is already set or the new one is
            shorter than 8 characters.
    """
    user = get_current_user_from_header(request, db)
    if user.hashed_password:
        raise HTTPException(status_code=400, detail="Password already set")
    if len(req.password) < 8:
        raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
    user.hashed_password = hash_password(req.password)
    db.commit()
    return {"ok": True}


def user_to_dict(user: User) -> dict:
    """Return the JSON-serialisable view of *user* sent to clients.

    The password hash itself is never included, only a ``has_password`` flag.
    """
    return {
        "id": user.id,
        "username": user.username,
        "display_name": user.display_name,
        "email": user.email,
        "avatar_color": user.avatar_color,
        "public_key": user.public_key,
        "is_online": user.is_online,
        "has_password": bool(user.hashed_password),
        "created_at": str(user.created_at),
    }