Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends, HTTPException, Request | |
| from sqlalchemy.orm import Session | |
| from sqlalchemy import and_, or_ | |
| from pydantic import BaseModel | |
| from typing import Optional, List | |
| import json | |
| from database import get_db | |
| from models import User, Conversation, ConversationMember, Message, Group, GroupMember, GroupMessage | |
| from auth import get_current_user_from_header, user_to_dict | |
| router = APIRouter() | |
| # ─── Users ────────────────────────────────────────────────────────────────── | |
def search_users(q: str, request: Request, db: Session = Depends(get_db)):
    """Return up to 20 users, other than the caller, whose username contains ``q``.

    The match is case-insensitive. An empty ``q`` applies no username filter,
    so the first 20 other users are returned.

    Args:
        q: Substring to look for in usernames; matched literally.
        request: Incoming request, passed to ``get_current_user_from_header``
            to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        A list of ``user_to_dict`` results, at most 20 entries long.
    """
    me = get_current_user_from_header(request, db)
    query = db.query(User).filter(User.id != me.id)
    if q:
        # Escape LIKE metacharacters so that "%" and "_" typed by the user are
        # matched literally instead of acting as wildcards. The escape
        # character itself is escaped first so it is not doubled twice.
        literal = q.replace("/", "//").replace("%", "/%").replace("_", "/_")
        query = query.filter(User.username.ilike(f"%{literal}%", escape="/"))
    return [user_to_dict(u) for u in query.limit(20).all()]
def get_public_key(user_id: int, request: Request, db: Session = Depends(get_db)):
    """Return the stored public key of the user with id ``user_id``.

    Args:
        user_id: Primary key of the user whose key is requested.
        request: Incoming request, passed to ``get_current_user_from_header``.
        db: SQLAlchemy session.

    Returns:
        ``{"public_key": <user.public_key>}``.

    Raises:
        HTTPException: 404 when no user has that id.
    """
    # The resolved caller is not used here; the call is kept for whatever
    # checks the helper performs (presumably authentication -- confirm in auth).
    get_current_user_from_header(request, db)

    matching_users = db.query(User).filter(User.id == user_id)
    user = matching_users.first()
    if user:
        return {"public_key": user.public_key}
    raise HTTPException(status_code=404, detail="User not found")
| # ─── DM Conversations ──────────────────────────────────────────────────────── | |
def get_or_create_conversation(target_user_id: int, request: Request, db: Session = Depends(get_db)):
    """Return the DM conversation between the caller and ``target_user_id``.

    If the two users already share a conversation, it is reused. Otherwise a
    new ``Conversation`` is created with one ``ConversationMember`` row per
    user.

    Args:
        target_user_id: Id of the user to talk to.
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        A dict with the conversation ``id``, the ``other_user`` (via
        ``user_to_dict``) and the last 50 entries of ``conv.messages``
        rendered with ``msg_to_dict`` from the caller's point of view.

    Raises:
        HTTPException: 400 when the target is the caller, 404 when the target
            user does not exist.
    """
    me = get_current_user_from_header(request, db)

    # With target == caller the intersection below would match every
    # conversation the caller belongs to and return an arbitrary one.
    if target_user_id == me.id:
        raise HTTPException(status_code=400, detail="Cannot start a conversation with yourself")

    # Validate the target before writing anything, so that no membership row
    # is created for a user that does not exist.
    other = db.query(User).filter(User.id == target_user_id).first()
    if not other:
        raise HTTPException(status_code=404, detail="User not found")

    # Find an existing conversation that both users are members of.
    my_convs = db.query(ConversationMember.conversation_id).filter(
        ConversationMember.user_id == me.id
    )
    their_convs = db.query(ConversationMember.conversation_id).filter(
        ConversationMember.user_id == target_user_id
    )
    common = my_convs.intersect(their_convs).all()

    if common:
        conv_id = common[0][0]
        conv = db.query(Conversation).filter(Conversation.id == conv_id).first()
    else:
        conv = Conversation()
        db.add(conv)
        db.flush()  # assigns conv.id so the membership rows can reference it
        db.add(ConversationMember(conversation_id=conv.id, user_id=me.id))
        db.add(ConversationMember(conversation_id=conv.id, user_id=target_user_id))
        db.commit()
        db.refresh(conv)

    return {
        "id": conv.id,
        "other_user": user_to_dict(other),
        "messages": [msg_to_dict(m, me.id) for m in conv.messages[-50:]],
    }
def list_conversations(request: Request, db: Session = Depends(get_db)):
    """List the caller's DM conversations, most recently active first.

    Conversations in which no member other than the caller is found are
    skipped.

    Args:
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        A list of dicts with ``id``, ``other_user``, ``last_message_at``
        (string, or ``None`` when the conversation has no messages) and
        ``unread_count`` (messages not sent by the caller whose ``is_read``
        is falsy), sorted by ``last_message_at`` descending.
    """
    me = get_current_user_from_header(request, db)
    my_memberships = (
        db.query(ConversationMember)
        .filter(ConversationMember.user_id == me.id)
        .all()
    )

    summaries = []
    for membership in my_memberships:
        conv = membership.conversation

        # Pick the first member of the conversation that is not the caller.
        counterpart = None
        for candidate in conv.members:
            if candidate.user_id != me.id:
                counterpart = candidate
                break
        if not counterpart:
            continue
        other = counterpart.user

        # The last element of conv.messages is treated as the latest message.
        last_msg = conv.messages[-1] if conv.messages else None
        if not last_msg:
            last_at = None
        elif hasattr(last_msg.created_at, 'isoformat'):
            last_at = last_msg.created_at.isoformat()
        else:
            last_at = str(last_msg.created_at)

        unread = [
            msg for msg in conv.messages
            if not msg.is_read and msg.sender_id != me.id
        ]

        summaries.append({
            "id": conv.id,
            "other_user": user_to_dict(other),
            "last_message_at": last_at,
            "unread_count": len(unread),
        })

    # Conversations without messages sort as the empty string, i.e. last.
    return sorted(summaries, key=lambda item: item["last_message_at"] or "", reverse=True)
def delete_conversation(conversation_id: int, request: Request, db: Session = Depends(get_db)):
    """Remove the caller's membership row for a conversation.

    Only the caller's ``ConversationMember`` row is deleted; the conversation
    itself and any other membership rows are not touched by this function.

    Args:
        conversation_id: Id of the conversation to leave.
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        ``{"ok": True}`` on success.

    Raises:
        HTTPException: 404 when the caller has no membership in that
            conversation.
    """
    me = get_current_user_from_header(request, db)

    own_rows = db.query(ConversationMember).filter(
        ConversationMember.conversation_id == conversation_id,
        ConversationMember.user_id == me.id,
    )
    membership = own_rows.first()

    if membership:
        db.delete(membership)
        db.commit()
        return {"ok": True}
    raise HTTPException(status_code=404, detail="Conversation not found")
| # ─── Groups ───────────────────────────────────────────────────────────────── | |
class CreateGroupRequest(BaseModel):
    """Request body accepted by ``create_group``."""

    # Name given to the new group.
    name: str
    # Optional free-text description; omitted means None.
    description: Optional[str] = None
    # Ids of users to add besides the creator; defaults to no extra members.
    member_ids: List[int] = []
class AddMemberRequest(BaseModel):
    """Request body accepted by ``add_member``."""

    # Id of the user to add to the group.
    user_id: int
def create_group(req: CreateGroupRequest, request: Request, db: Session = Depends(get_db)):
    """Create a group with the caller as admin plus the requested members.

    Requested member ids are de-duplicated and the caller's own id is dropped
    (the caller is always added with role ``"admin"``). All remaining ids must
    refer to existing users; this is checked before anything is written.

    Args:
        req: Group name, optional description and initial member ids.
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        The new group rendered with ``group_to_dict``.

    Raises:
        HTTPException: 404 when any requested member id does not exist.
    """
    me = get_current_user_from_header(request, db)

    # dict.fromkeys de-duplicates while keeping the order from the request, so
    # a repeated id cannot produce two GroupMember rows for the same user.
    member_ids = [uid for uid in dict.fromkeys(req.member_ids) if uid != me.id]
    if member_ids:
        known_ids = {
            row[0]
            for row in db.query(User.id).filter(User.id.in_(member_ids)).all()
        }
        if any(uid not in known_ids for uid in member_ids):
            raise HTTPException(status_code=404, detail="User not found")

    group = Group(name=req.name, description=req.description, created_by=me.id)
    db.add(group)
    db.flush()  # assigns group.id so the membership rows can reference it

    # Add creator as admin
    db.add(GroupMember(group_id=group.id, user_id=me.id, role="admin"))
    # Add other members
    for uid in member_ids:
        db.add(GroupMember(group_id=group.id, user_id=uid))

    db.commit()
    db.refresh(group)
    return group_to_dict(group, me.id)
def list_groups(request: Request, db: Session = Depends(get_db)):
    """Return every group the caller is a member of.

    Args:
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        A list with one ``group_to_dict`` result per membership row of the
        caller.
    """
    me = get_current_user_from_header(request, db)
    own_memberships = db.query(GroupMember).filter(GroupMember.user_id == me.id)
    return [
        group_to_dict(membership.group, me.id)
        for membership in own_memberships.all()
    ]
def get_group(group_id: int, request: Request, db: Session = Depends(get_db)):
    """Return a group's details, recent messages and member list.

    Args:
        group_id: Id of the group to fetch.
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        The ``group_to_dict`` fields plus ``messages`` (the last 50 entries of
        ``group.messages`` via ``group_msg_to_dict``) and ``members`` (one
        ``{"user", "role"}`` dict per entry of ``group.members``).

    Raises:
        HTTPException: 404 when the group does not exist, 403 when the caller
            is not a member of it.
    """
    me = get_current_user_from_header(request, db)

    group = db.query(Group).filter(Group.id == group_id).first()
    if not group:
        raise HTTPException(status_code=404, detail="Group not found")

    # Existence is checked before membership, so a non-member gets 404 for a
    # missing group and 403 for an existing one.
    caller_membership = (
        db.query(GroupMember)
        .filter(GroupMember.group_id == group_id, GroupMember.user_id == me.id)
        .first()
    )
    if not caller_membership:
        raise HTTPException(status_code=403, detail="Not a member")

    payload = group_to_dict(group, me.id)
    payload["messages"] = [
        group_msg_to_dict(message, me.id) for message in group.messages[-50:]
    ]
    payload["members"] = [
        {"user": user_to_dict(member.user), "role": member.role}
        for member in group.members
    ]
    return payload
def add_member(group_id: int, req: AddMemberRequest, request: Request, db: Session = Depends(get_db)):
    """Add a user to a group the caller already belongs to.

    Args:
        group_id: Id of the group to extend.
        req: Body carrying the ``user_id`` to add.
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        ``{"ok": True, "user": <user_to_dict of the added user>}``.

    Raises:
        HTTPException: 403 when the caller is not a member of the group,
            404 when the user to add does not exist, 400 when that user is
            already a member.
    """
    me = get_current_user_from_header(request, db)

    membership = db.query(GroupMember).filter(
        GroupMember.group_id == group_id,
        GroupMember.user_id == me.id,
    ).first()
    if not membership:
        raise HTTPException(status_code=403, detail="Not a member")

    # Resolve the user before inserting, so that a membership row is never
    # committed for an id that has no User behind it.
    user = db.query(User).filter(User.id == req.user_id).first()
    if not user:
        raise HTTPException(status_code=404, detail="User not found")

    existing = db.query(GroupMember).filter(
        GroupMember.group_id == group_id,
        GroupMember.user_id == req.user_id,
    ).first()
    if existing:
        raise HTTPException(status_code=400, detail="Already a member")

    db.add(GroupMember(group_id=group_id, user_id=req.user_id))
    db.commit()
    return {"ok": True, "user": user_to_dict(user)}
def remove_member(group_id: int, user_id: int, request: Request, db: Session = Depends(get_db)):
    """Remove a user from a group.

    A member may always remove themselves. Removing anyone else requires the
    caller's membership to have role ``"admin"`` (the role ``create_group``
    gives the group creator). Removing a user who is not in the group is a
    no-op that still reports success.

    Args:
        group_id: Id of the group.
        user_id: Id of the user to remove.
        request: Incoming request, used to resolve the calling user.
        db: SQLAlchemy session.

    Returns:
        ``{"ok": True}``.

    Raises:
        HTTPException: 403 when the caller is not a member of the group, or
            is a non-admin trying to remove somebody else.
    """
    me = get_current_user_from_header(request, db)

    membership = db.query(GroupMember).filter(
        GroupMember.group_id == group_id,
        GroupMember.user_id == me.id,
    ).first()
    if not membership:
        raise HTTPException(status_code=403, detail="Not a member")

    # Without this check any member could remove any other member,
    # including the group's admin.
    if user_id != me.id and membership.role != "admin":
        raise HTTPException(status_code=403, detail="Only admins can remove other members")

    target = db.query(GroupMember).filter(
        GroupMember.group_id == group_id,
        GroupMember.user_id == user_id,
    ).first()
    if target:
        db.delete(target)
        db.commit()
    return {"ok": True}
| # ─── Helpers ──────────────────────────────────────────────────────────────── | |
def msg_to_dict(msg: Message, viewer_id: int) -> dict:
    """Serialize a direct message for the user identified by ``viewer_id``.

    The message carries two ciphertexts; the viewer receives
    ``ciphertext_for_sender`` when they sent the message and
    ``ciphertext_for_recipient`` otherwise.

    Args:
        msg: The message to serialize.
        viewer_id: Id of the user the payload is built for.

    Returns:
        A dict with ``id``, ``sender_id``, ``sender_name`` (``"Unknown"`` when
        the message has no sender object), ``ciphertext``, ``created_at`` (as
        a string) and ``is_read``.
    """
    if msg.sender_id == viewer_id:
        ciphertext = msg.ciphertext_for_sender
    else:
        ciphertext = msg.ciphertext_for_recipient

    # Values without an isoformat() method fall back to plain str().
    stamp = msg.created_at
    created_at = stamp.isoformat() if hasattr(stamp, 'isoformat') else str(stamp)

    author = msg.sender
    sender_name = author.display_name if author else "Unknown"

    return {
        "id": msg.id,
        "sender_id": msg.sender_id,
        "sender_name": sender_name,
        "ciphertext": ciphertext,
        "created_at": created_at,
        "is_read": msg.is_read,
    }
def group_msg_to_dict(msg: GroupMessage, viewer_id: int) -> dict:
    """Serialize a group message for the user identified by ``viewer_id``.

    ``msg.encrypted_aes_keys`` is parsed as JSON and looked up with the
    stringified viewer id; the viewer's entry is returned as
    ``encrypted_aes_key`` (``None`` when absent or when no keys are stored).

    Args:
        msg: The group message to serialize.
        viewer_id: Id of the user the payload is built for.

    Returns:
        A dict with ``id``, ``sender_id``, ``sender_name`` (``"Unknown"`` when
        the message has no sender object), ``ciphertext``, ``iv``,
        ``encrypted_aes_key`` and ``created_at`` (as a string).
    """
    raw_keys = msg.encrypted_aes_keys
    key_by_user = json.loads(raw_keys) if raw_keys else {}

    author = msg.sender
    sender_name = author.display_name if author else "Unknown"

    # Values without an isoformat() method fall back to plain str().
    stamp = msg.created_at
    created_at = stamp.isoformat() if hasattr(stamp, 'isoformat') else str(stamp)

    return {
        "id": msg.id,
        "sender_id": msg.sender_id,
        "sender_name": sender_name,
        "ciphertext": msg.ciphertext,
        "iv": msg.iv,
        # JSON object keys are strings, hence the str() on the viewer id.
        "encrypted_aes_key": key_by_user.get(str(viewer_id)),
        "created_at": created_at,
    }
def group_to_dict(group: Group, viewer_id: int) -> dict:
    """Serialize a group's summary fields.

    Args:
        group: The group to serialize.
        viewer_id: Accepted for signature parity with the other serializers;
            not used by this function.

    Returns:
        A dict with ``id``, ``name``, ``description``, ``created_by``,
        ``avatar_color``, ``member_count`` (length of ``group.members``) and
        ``created_at`` (as a string).
    """
    # Values without an isoformat() method fall back to plain str().
    stamp = group.created_at
    if hasattr(stamp, 'isoformat'):
        created_at = stamp.isoformat()
    else:
        created_at = str(stamp)

    summary = {
        "id": group.id,
        "name": group.name,
        "description": group.description,
        "created_by": group.created_by,
        "avatar_color": group.avatar_color,
        "member_count": len(group.members),
    }
    summary["created_at"] = created_at
    return summary