Spaces:
Sleeping
Sleeping
| from fastapi import WebSocket | |
| from sqlalchemy.orm import Session | |
| from typing import Dict, Optional | |
| import json | |
| from datetime import datetime | |
| from models import Message, Conversation, ConversationMember, GroupMessage, GroupMember, Group, User | |
| from chat import msg_to_dict, group_msg_to_dict | |
class ConnectionManager:
    """Registry of live WebSockets keyed by user id, plus chat/call event routing.

    Each ``handle_*`` coroutine receives an already-parsed client message
    ``msg`` (a dict) and the ``sender`` User. Handlers that persist or look up
    data also take a SQLAlchemy ``Session``. Messages that are malformed or
    not authorized are dropped silently: the handler simply returns.

    Only one socket is tracked per user. A later ``connect`` for the same id
    replaces the earlier entry in ``active`` without closing the old socket.
    """

    def __init__(self):
        # user_id -> the WebSocket most recently registered for that user.
        self.active: Dict[int, WebSocket] = {}

    # ------------------------------------------------------------------
    # Connection registry
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, user_id: int):
        """Accept *websocket*'s handshake and register it as *user_id*'s connection."""
        await websocket.accept()
        self.active[user_id] = websocket

    def disconnect(self, user_id: int, websocket: Optional[WebSocket] = None):
        """Forget *user_id*'s registered connection.

        Args:
            user_id: The user whose connection is being dropped.
            websocket: Optional. When given, the entry is removed only if it
                is still this exact socket. Passing it stops a stale
                connection's cleanup from evicting a newer connection that the
                same user opened in the meantime. When omitted, the entry is
                removed unconditionally.
        """
        if websocket is None or self.active.get(user_id) is websocket:
            self.active.pop(user_id, None)

    async def send_to(self, user_id: int, data: dict):
        """Send *data* as JSON text to *user_id* if that user is connected.

        Delivery is best-effort. Users with no registered socket are skipped.
        A socket whose send fails is removed from the registry instead of
        raising.
        """
        ws = self.active.get(user_id)
        if ws is None:
            return
        # Serialize outside the try block. A non-JSON-serializable payload is
        # a server bug and should surface; it should not be mistaken for a
        # dead socket.
        text = json.dumps(data)
        try:
            await ws.send_text(text)
        except Exception:
            # The socket is unusable. Evict it only if it is still the
            # registered one, because the user may have reconnected while the
            # send was awaiting.
            if self.active.get(user_id) is ws:
                del self.active[user_id]

    # ------------------------------------------------------------------
    # Membership helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_conversation_member(conversation_id, user_id: int, db: Session) -> bool:
        """Return True if a ConversationMember row links *user_id* to *conversation_id*."""
        return db.query(ConversationMember).filter(
            ConversationMember.conversation_id == conversation_id,
            ConversationMember.user_id == user_id,
        ).first() is not None

    @staticmethod
    def _group_member_ids(group_id, db: Session) -> list:
        """Return the user ids of every GroupMember row for *group_id*."""
        return [
            member_id
            for (member_id,) in db.query(GroupMember.user_id)
            .filter(GroupMember.group_id == group_id)
            .all()
        ]

    @staticmethod
    def _find_or_create_dm(sender: User, target_user_id, db: Session) -> Optional[int]:
        """Return the id of a conversation shared by *sender* and the target user.

        If they share no conversation, a new one with both users as members is
        created. New rows are flushed (so ``conv.id`` is assigned) but not
        committed; the caller commits.

        Returns None when the target user does not exist, or when the target
        is the sender. A self-target would intersect with every conversation
        the sender belongs to and pick one arbitrarily.
        """
        target = db.query(User).filter(User.id == target_user_id).first()
        if target is None or target.id == sender.id:
            return None

        def conversation_ids(user_id):
            return {
                cid
                for (cid,) in db.query(ConversationMember.conversation_id)
                .filter(ConversationMember.user_id == user_id)
                .all()
            }

        common = conversation_ids(sender.id) & conversation_ids(target.id)
        if common:
            # Take the lowest id so repeated lookups resolve to the same conversation.
            return min(common)

        conv = Conversation()
        db.add(conv)
        db.flush()  # assign conv.id before creating the member rows
        db.add(ConversationMember(conversation_id=conv.id, user_id=sender.id))
        db.add(ConversationMember(conversation_id=conv.id, user_id=target.id))
        db.flush()
        return conv.id

    # ------------------------------------------------------------------
    # Chat messages
    # ------------------------------------------------------------------

    async def handle_dm(self, msg: dict, sender: User, db: Session):
        """Store an encrypted direct message and push it to every conversation member.

        Reads these keys from *msg*:
            ciphertext_for_recipient: required; the message is dropped if empty.
            ciphertext_for_sender: optional; defaults to the recipient ciphertext.
            conversation_id / target_user_id: when ``conversation_id`` is
                missing but ``target_user_id`` is present, a conversation
                shared with that user is reused or created.

        The message is dropped if the sender is not a member of the resolved
        conversation.
        """
        ciphertext_for_recipient = msg.get("ciphertext_for_recipient")
        if not ciphertext_for_recipient:
            return
        ciphertext_for_sender = msg.get("ciphertext_for_sender")
        conversation_id = msg.get("conversation_id")
        target_user_id = msg.get("target_user_id")

        if not conversation_id and target_user_id:
            conversation_id = self._find_or_create_dm(sender, target_user_id, db)
        if not conversation_id:
            return
        if not self._is_conversation_member(conversation_id, sender.id, db):
            return

        message = Message(
            conversation_id=conversation_id,
            sender_id=sender.id,
            ciphertext_for_recipient=ciphertext_for_recipient,
            ciphertext_for_sender=ciphertext_for_sender or ciphertext_for_recipient,
        )
        db.add(message)
        db.commit()
        db.refresh(message)

        member_ids = [
            member_id
            for (member_id,) in db.query(ConversationMember.user_id)
            .filter(ConversationMember.conversation_id == conversation_id)
            .all()
        ]
        for member_id in member_ids:
            # msg_to_dict is called once per member id. Presumably this picks
            # that viewer's ciphertext; confirm in chat.msg_to_dict.
            await self.send_to(member_id, {
                "type": "dm",
                "conversation_id": conversation_id,
                "message": msg_to_dict(message, member_id),
            })

    async def handle_group(self, msg: dict, sender: User, db: Session):
        """Store an encrypted group message and push it to every group member.

        Requires ``group_id``, ``ciphertext`` and ``iv`` in *msg*.
        ``encrypted_aes_keys`` is stored JSON-encoded; a missing or null value
        is stored as ``{}``. The message is dropped if the sender is not a
        group member.
        """
        group_id = msg.get("group_id")
        ciphertext = msg.get("ciphertext")
        iv = msg.get("iv")
        if not (group_id and ciphertext and iv):
            return

        member_ids = self._group_member_ids(group_id, db)
        if sender.id not in member_ids:
            return

        encrypted_aes_keys = msg.get("encrypted_aes_keys") or {}
        message = GroupMessage(
            group_id=group_id,
            sender_id=sender.id,
            ciphertext=ciphertext,
            iv=iv,
            encrypted_aes_keys=json.dumps(encrypted_aes_keys),
        )
        db.add(message)
        db.commit()
        db.refresh(message)

        for member_id in member_ids:
            await self.send_to(member_id, {
                "type": "group",
                "group_id": group_id,
                "message": group_msg_to_dict(message, member_id),
            })

    async def handle_typing(self, msg: dict, sender: User, db: Optional[Session] = None):
        """Relay a typing indicator.

        If ``target_id`` is set, the indicator goes to that one user. Otherwise,
        if ``group_id`` is set and a *db* session is supplied, it goes to every
        other member of that group, provided the sender is a member. Without
        *db*, group typing is ignored, because membership cannot be looked up.
        """
        target_id = msg.get("target_id")
        group_id = msg.get("group_id")
        is_typing = msg.get("is_typing", False)

        if target_id:
            await self.send_to(target_id, {
                "type": "typing",
                "user_id": sender.id,
                "display_name": sender.display_name,
                "is_typing": is_typing,
                "conversation_id": msg.get("conversation_id"),
            })
        elif group_id and db is not None:
            member_ids = self._group_member_ids(group_id, db)
            if sender.id not in member_ids:
                return
            payload = {
                "type": "typing",
                "user_id": sender.id,
                "display_name": sender.display_name,
                "is_typing": is_typing,
                "group_id": group_id,
            }
            for member_id in member_ids:
                if member_id != sender.id:
                    await self.send_to(member_id, payload)

    # ------------------------------------------------------------------
    # Call signalling: payloads are forwarded to msg["target_id"] unchanged
    # in shape. No relationship between sender and target is checked here.
    # ------------------------------------------------------------------

    async def _relay(self, msg: dict, payload: dict):
        """Forward *payload* to the user in ``msg["target_id"]``; a missing target is a no-op."""
        target_id = msg.get("target_id")
        if target_id is not None:
            await self.send_to(target_id, payload)

    async def handle_call_offer(self, msg: dict, sender: User):
        """Forward a call offer (``call_type`` defaults to "audio") with the caller's profile."""
        await self._relay(msg, {
            "type": "call_offer",
            "from": sender.id,
            "from_display_name": sender.display_name,
            "from_avatar_color": sender.avatar_color,
            "call_type": msg.get("call_type", "audio"),
            "sdp": msg.get("sdp"),
        })

    async def handle_call_answer(self, msg: dict, sender: User):
        """Forward a call answer's ``sdp`` to the caller."""
        await self._relay(msg, {
            "type": "call_answer",
            "from": sender.id,
            "sdp": msg.get("sdp"),
        })

    async def handle_ice_candidate(self, msg: dict, sender: User):
        """Forward an ICE ``candidate`` to the peer."""
        await self._relay(msg, {
            "type": "ice_candidate",
            "from": sender.id,
            "candidate": msg.get("candidate"),
        })

    async def handle_call_end(self, msg: dict, sender: User):
        """Tell the peer that the sender ended the call."""
        await self._relay(msg, {
            "type": "call_end",
            "from": sender.id,
        })

    async def handle_video_toggle(self, msg: dict, sender: User):
        """Forward the sender's video on/off state (``enabled`` defaults to True)."""
        await self._relay(msg, {
            "type": "video_toggle",
            "from": sender.id,
            "enabled": msg.get("enabled", True),
            "display_name": sender.display_name,
            "avatar_color": sender.avatar_color,
        })

    # ------------------------------------------------------------------
    # Read receipts
    # ------------------------------------------------------------------

    async def handle_read(self, msg: dict, sender: User, db: Session):
        """Mark the other members' unread messages in a conversation as read.

        Only members of ``msg["conversation_id"]`` may do this. Each distinct
        author of a newly-read message receives a ``read_receipt`` event.
        """
        conversation_id = msg.get("conversation_id")
        if not conversation_id:
            return
        if not self._is_conversation_member(conversation_id, sender.id, db):
            return

        unread = db.query(Message).filter(
            Message.conversation_id == conversation_id,
            Message.sender_id != sender.id,
            Message.is_read == False,  # noqa: E712 - SQL expression, not a Python comparison
        ).all()
        if not unread:
            return

        # Collect authors before commit: after commit the ORM objects may be
        # expired, and touching m.sender_id would reload each row separately.
        author_ids = {m.sender_id for m in unread}
        for m in unread:
            m.is_read = True
        db.commit()

        for author_id in author_ids:
            await self.send_to(author_id, {
                "type": "read_receipt",
                "conversation_id": conversation_id,
                "read_by": sender.id,
            })
# Single module-level ConnectionManager. Presumably this is imported by the
# WebSocket route(s) so all connections share one registry; confirm at call sites.
manager: ConnectionManager = ConnectionManager()