Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends, HTTPException, status, Request | |
| from fastapi.responses import RedirectResponse | |
| from sqlalchemy.orm import Session | |
| from sqlalchemy import or_ | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| from passlib.context import CryptContext | |
| from jose import JWTError, jwt | |
| from datetime import datetime, timedelta | |
| import os | |
| import httpx | |
| import secrets | |
| import random | |
| from database import get_db | |
| from models import User | |
# JWT signing configuration.
# NOTE(review): when SECRET_KEY is not configured a random key is generated for
# this process, so issued tokens stop validating after a restart and are not
# accepted by other processes. Configure SECRET_KEY explicitly in production.
# An empty SECRET_KEY is treated exactly like a missing one: signing HS256 with
# "" would make every token forgeable by anyone who guesses the empty key.
SECRET_KEY = os.getenv("SECRET_KEY") or secrets.token_hex(32)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24 * 7  # 7 days

# Google OAuth client credentials; an empty GOOGLE_CLIENT_ID disables the
# Google endpoints (they answer 400 "Google OAuth not configured").
GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID", "")
GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET", "")
# NOTE(review): _google_redirect_uri() reads the GOOGLE_REDIRECT_URI env var
# directly and does not use this constant; it is kept in case other modules
# import it.
GOOGLE_REDIRECT_URI = os.getenv("GOOGLE_REDIRECT_URI", "http://localhost:7860/api/auth/google/callback")
# Password hashing context shared by hash_password() / verify_password() (bcrypt only).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Router for the auth endpoints defined in this module.
# NOTE(review): the route decorators appear to be missing from this copy of the
# file; confirm the handlers are registered on this router.
router = APIRouter()

# Palette from which register() picks a random avatar colour for each new user.
AVATAR_COLORS = [
    "#6366f1",
    "#8b5cf6",
    "#ec4899",
    "#f43f5e",
    "#f97316",
    "#eab308",
    "#22c55e",
    "#06b6d4",
    "#3b82f6",
]
class RegisterRequest(BaseModel):
    """Request body for account registration.

    Two flows share this model: a password sign-up (``password`` required) and
    completion of a Google sign-in (``google_setup_token`` supplied, password
    ignored). register() rejects a request that has neither.
    """

    username: str  # lowercased and stripped by register() before lookup/storage
    display_name: str
    password: str = ""  # unused when google_setup_token is supplied
    email: Optional[str] = None  # unused in the Google flow; the email comes from the setup token
    public_key: Optional[str] = None  # stored on the User row as given
    google_setup_token: Optional[str] = None  # 15-minute JWT issued by google_callback()
class LoginRequest(BaseModel):
    """Request body for username/password login."""

    username: str  # login() lowercases and strips this before looking the user up
    password: str
class UpdateProfileRequest(BaseModel):
    """Partial profile update; fields left as None (or empty) are not changed."""

    display_name: Optional[str] = None
    public_key: Optional[str] = None
class SetPasswordRequest(BaseModel):
    """Request body for adding a password to an account that does not have one."""

    password: str  # set_password() requires at least 8 characters
def hash_password(password: str) -> str:
    """Hash *password* with the module's passlib context (bcrypt) for storage."""
    digest = pwd_context.hash(password)
    return digest
def verify_password(plain: str, hashed: str) -> bool:
    """Return True when *plain* matches *hashed*, a value produced by hash_password()."""
    matches = pwd_context.verify(plain, hashed)
    return matches
def create_token(user_id: int) -> str:
    """Issue a signed access JWT for *user_id*.

    The token carries the id as a string ``"sub"`` claim (read back by
    decode_token()) and expires ACCESS_TOKEN_EXPIRE_HOURS from now.
    """
    from datetime import timezone

    # Timezone-aware UTC instead of datetime.utcnow(), which is deprecated
    # since Python 3.12; the encoded "exp" timestamp is the same.
    expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    return jwt.encode({"sub": str(user_id), "exp": expire}, SECRET_KEY, algorithm=ALGORITHM)
def decode_token(token: str) -> Optional[int]:
    """Return the user id carried by an access token, or None if it is unusable.

    None is returned when the signature is invalid or the token has expired
    (any JWTError). It is also returned when the payload's ``"sub"`` claim is
    missing or not an integer. One example is the Google setup token: it is
    signed with the same key but has no ``"sub"``, and previously crashed the
    request with a KeyError.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    try:
        return int(payload["sub"])
    except (KeyError, TypeError, ValueError):
        return None
def get_current_user(token: str, db: Session) -> Optional[User]:
    """Resolve *token* to its User row.

    Returns None when the token does not decode to a user id or when no user
    with that id exists.
    """
    user_id = decode_token(token)
    # Compare against None explicitly: an id of 0 is still an id.
    if user_id is None:
        return None
    return db.query(User).filter(User.id == user_id).first()
async def get_current_user_ws(token: str, db: Session) -> Optional[User]:
    """Awaitable variant of get_current_user().

    Presumably used by WebSocket handlers (per the name); the lookup itself is
    synchronous.
    """
    user = get_current_user(token, db)
    return user
def get_current_user_from_header(request: Request, db: Session = Depends(get_db)) -> User:
    """Return the User identified by the ``Authorization: Bearer <token>`` header.

    Raises:
        HTTPException(401): "Not authenticated" when the header is missing, is
            not a Bearer credential, or carries no token. "Invalid token" when
            the token does not resolve to a user. Both include a
            ``WWW-Authenticate: Bearer`` challenge.
    """
    # RFC 6750 §3: a 401 for a bearer-protected resource must carry this challenge.
    challenge = {"WWW-Authenticate": "Bearer"}
    scheme, _, token = request.headers.get("Authorization", "").partition(" ")
    token = token.strip()
    # Authentication schemes are case-insensitive (RFC 7235 §2.1), so "bearer" is accepted too.
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Not authenticated", headers=challenge)
    user = get_current_user(token, db)
    if user is None:
        raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)
    return user
def _decode_google_setup_token(token: str) -> tuple:
    """Validate a Google setup JWT (from google_callback) and return ``(google_id, email)``.

    Raises:
        HTTPException(400): the token is malformed or expired, is not a
            ``google_setup`` token, or lacks a ``google_id`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=400, detail="Invalid or expired setup token") from None
    if payload.get("type") != "google_setup" or not payload.get("google_id"):
        raise HTTPException(status_code=400, detail="Invalid setup token")
    return payload["google_id"], payload.get("email")


def register(req: RegisterRequest, db: Session = Depends(get_db)):
    """Create an account and return ``{"token": <access JWT>, "user": <user dict>}``.

    Two flows are supported:

    * Google setup: ``req.google_setup_token`` supplies the Google id and email;
      no password is stored.
    * Password: ``req.password`` (at least 8 characters, the same rule as
      set_password()) is hashed; ``req.email`` is optional.

    Raises:
        HTTPException(400): empty username; invalid or expired setup token;
            missing or too-short password; or an account, email, or username
            that already exists.
    """
    username = req.username.lower().strip()
    if not username:
        raise HTTPException(status_code=400, detail="Username is required")

    if req.google_setup_token:
        google_id, google_email = _decode_google_setup_token(req.google_setup_token)
        # google_callback() encodes a missing email as "". Only match on email
        # when there is one; otherwise `email == ""` could hit an unrelated row.
        conflict = User.google_id == google_id
        if google_email:
            conflict = or_(conflict, User.email == google_email)
        if db.query(User).filter(conflict).first():
            raise HTTPException(status_code=400, detail="Account already exists")
        google_id_attr = google_id
        email_attr = google_email or None
        password_hash = None
    else:
        if not req.password:
            raise HTTPException(status_code=400, detail="Password is required")
        if len(req.password) < 8:
            raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
        google_id_attr = None
        # Store a blank email as NULL so "" is never persisted as a real address.
        email_attr = req.email or None
        if email_attr and db.query(User).filter(User.email == email_attr).first():
            raise HTTPException(status_code=400, detail="Email already in use")
        password_hash = hash_password(req.password)

    if db.query(User).filter(User.username == username).first():
        raise HTTPException(status_code=400, detail="Username already taken")

    user = User(
        username=username,
        display_name=req.display_name,
        email=email_attr,
        hashed_password=password_hash,
        google_id=google_id_attr,
        public_key=req.public_key,
        avatar_color=random.choice(AVATAR_COLORS),
    )
    db.add(user)
    db.commit()
    db.refresh(user)
    return {"token": create_token(user.id), "user": user_to_dict(user)}
def login(req: LoginRequest, db: Session = Depends(get_db)):
    """Authenticate by username and password.

    Returns ``{"token": <access JWT>, "user": <user dict>}`` on success.

    Raises:
        HTTPException(401): unknown username, an account without a password
            (e.g. created through Google), or a wrong password. All cases
            produce the same "Invalid credentials" response.
    """
    username = req.username.lower().strip()
    user = db.query(User).filter(User.username == username).first()
    if user is None or not user.hashed_password:
        # Spend about one hash verification anyway, so response time does not
        # reveal whether the username exists.
        pwd_context.dummy_verify()
        raise HTTPException(status_code=401, detail="Invalid credentials")
    if not verify_password(req.password, user.hashed_password):
        raise HTTPException(status_code=401, detail="Invalid credentials")
    return {"token": create_token(user.id), "user": user_to_dict(user)}
def _google_redirect_uri(request: Request) -> str:
    """Return the OAuth callback URL to send to Google.

    A non-empty ``GOOGLE_REDIRECT_URI`` environment variable wins. Otherwise the
    URL is rebuilt from the request, preferring ``X-Forwarded-Proto`` and
    ``X-Forwarded-Host`` when a reverse proxy sets them. The same value must be
    used for the authorization request and the code exchange.
    """
    override = os.getenv("GOOGLE_REDIRECT_URI")
    if override:
        return override
    # Proxies may chain values ("https, http"); the first entry is the client-facing hop.
    scheme = request.headers.get("X-Forwarded-Proto", request.url.scheme).split(",")[0].strip()
    forwarded_host = request.headers.get("X-Forwarded-Host", "").split(",")[0].strip()
    if forwarded_host:
        # The forwarded host already includes any public port. The port on
        # request.url comes from the Host header the app received, which
        # behind a proxy is typically the internal one, so it is not appended.
        netloc = forwarded_host
    else:
        host = request.url.hostname
        port = request.url.port
        netloc = f"{host}:{port}" if port and port not in (80, 443) else host
    # NOTE(review): X-Forwarded-* headers are client-controllable unless a
    # trusted proxy overwrites them. Google only accepts pre-registered
    # redirect URIs, which limits the impact.
    return f"{scheme}://{netloc}/api/auth/google/callback"
def google_auth(request: Request):
    """Redirect the browser to Google's OAuth 2.0 consent screen.

    Raises:
        HTTPException(400): GOOGLE_CLIENT_ID is not configured.
    """
    from urllib.parse import quote, urlencode

    if not GOOGLE_CLIENT_ID:
        raise HTTPException(status_code=400, detail="Google OAuth not configured")
    # urlencode percent-escapes the redirect URI (":", "/", any "?"/"&") and
    # the spaces in scope. quote_via=quote emits "%20" rather than "+".
    query = urlencode(
        {
            "client_id": GOOGLE_CLIENT_ID,
            "redirect_uri": _google_redirect_uri(request),
            "response_type": "code",
            "scope": "openid email profile",
        },
        quote_via=quote,
    )
    # NOTE(review): no `state` parameter is sent, and none is verified in the
    # callback, so the flow has no CSRF protection. Consider adding one.
    return RedirectResponse(f"https://accounts.google.com/o/oauth2/v2/auth?{query}")
async def _fetch_google_profile(code: str, redirect_uri: str) -> dict:
    """Exchange an authorization *code* for the signed-in user's Google profile.

    Raises:
        HTTPException(400): Google rejected the code exchange, returned no
            access token, or refused the userinfo request.
    """
    async with httpx.AsyncClient() as client:
        token_resp = await client.post("https://oauth2.googleapis.com/token", data={
            "code": code,
            "client_id": GOOGLE_CLIENT_ID,
            "client_secret": GOOGLE_CLIENT_SECRET,
            "redirect_uri": redirect_uri,
            "grant_type": "authorization_code",
        })
        access_token = token_resp.json().get("access_token") if token_resp.status_code == 200 else None
        if not access_token:
            raise HTTPException(status_code=400, detail="Google authentication failed")
        user_resp = await client.get(
            "https://www.googleapis.com/oauth2/v2/userinfo",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        if user_resp.status_code != 200:
            raise HTTPException(status_code=400, detail="Google authentication failed")
        return user_resp.json()


async def google_callback(code: str, request: Request, db: Session = Depends(get_db)):
    """Finish the Google OAuth flow and redirect back to the frontend.

    A known user redirects to ``/?token=<access JWT>&new_user=<bool>``. A user
    is known if matched by Google id, or by a Google-verified email that an
    existing account already uses (which links that account). Unknown users
    redirect to ``/?google_setup=<15-minute JWT>`` so the frontend can choose a
    username and call register().

    Raises:
        HTTPException(400): OAuth is not configured or Google authentication failed.
    """
    from datetime import timezone

    if not GOOGLE_CLIENT_ID:
        raise HTTPException(status_code=400, detail="Google OAuth not configured")
    guser = await _fetch_google_profile(code, _google_redirect_uri(request))

    google_id = guser.get("id")
    if not google_id:
        # Must never reach the queries below: SQLAlchemy renders `== None` as
        # IS NULL, which would match (and log in as) any password-only account.
        raise HTTPException(status_code=400, detail="Google authentication failed")
    email = guser.get("email")
    # The v2 userinfo endpoint reports "verified_email"; OIDC userinfo uses "email_verified".
    email_verified = bool(guser.get("verified_email") or guser.get("email_verified"))
    name = guser.get("name", email.split("@")[0] if email else "User")

    user = db.query(User).filter(User.google_id == google_id).first()
    if user is None and email and email_verified:
        # Link by email only when Google has verified the address. Otherwise a
        # Google account claiming someone else's email would sign into their
        # existing account.
        user = db.query(User).filter(User.email == email).first()
        if user is not None:
            user.google_id = google_id
            db.commit()
            db.refresh(user)

    if user is None:
        # New Google user: send to frontend to pick username/display name
        setup_token = jwt.encode({
            "type": "google_setup",
            "google_id": google_id,
            "email": email or "",
            "name": name,
            "exp": datetime.now(timezone.utc) + timedelta(minutes=15),
        }, SECRET_KEY, algorithm=ALGORITHM)
        return RedirectResponse(f"/?google_setup={setup_token}")

    # NOTE(review): the access token travels in the URL query string, so it can
    # end up in browser history and logs. A fragment or one-time exchange code
    # would be safer.
    token = create_token(user.id)
    return RedirectResponse(f"/?token={token}&new_user={not user.public_key}")
def get_me(request: Request, db: Session = Depends(get_db)):
    """Return the serialized profile of the caller identified by the Authorization header."""
    return user_to_dict(get_current_user_from_header(request, db))
def update_profile(req: UpdateProfileRequest, request: Request, db: Session = Depends(get_db)):
    """Apply a partial profile update for the authenticated caller.

    Only non-empty values overwrite the stored ones. None or "" leaves a field
    untouched. Returns the refreshed, serialized user.
    """
    current = get_current_user_from_header(request, db)
    for field_name in ("display_name", "public_key"):
        new_value = getattr(req, field_name)
        if new_value:
            setattr(current, field_name, new_value)
    db.commit()
    db.refresh(current)
    return user_to_dict(current)
def set_password(req: SetPasswordRequest, request: Request, db: Session = Depends(get_db)):
    """Attach a password to the caller's account when it has none.

    Such accounts arise, for example, from Google sign-up, which stores no
    hash. Returns ``{"ok": True}``.

    Raises:
        HTTPException(400): a password is already set, or the new one is
            shorter than 8 characters.
    """
    account = get_current_user_from_header(request, db)
    problem = None
    if account.hashed_password:
        problem = "Password already set"
    elif len(req.password) < 8:
        problem = "Password must be at least 8 characters"
    if problem is not None:
        raise HTTPException(status_code=400, detail=problem)
    account.hashed_password = hash_password(req.password)
    db.commit()
    return {"ok": True}
def user_to_dict(user: User) -> dict:
    """Serialize *user* into the dict returned by the auth endpoints.

    The password hash is never exposed; only whether one is set
    (``has_password``). ``created_at`` is rendered with ``str()``.
    """
    return dict(
        id=user.id,
        username=user.username,
        display_name=user.display_name,
        email=user.email,
        avatar_color=user.avatar_color,
        public_key=user.public_key,
        is_online=user.is_online,
        has_password=bool(user.hashed_password),
        created_at=str(user.created_at),
    )