""" Integration tests for WebSocket communication layer. Tests the complete WebSocket message flow, error handling, and real-time communication features of the multi-language chat agent. """ import pytest import json import time from datetime import datetime from unittest.mock import Mock, patch, MagicMock from uuid import uuid4 from flask import Flask from flask_socketio import SocketIO, SocketIOTestClient import redis from chat_agent.websocket import ( ChatWebSocketHandler, MessageValidator, ConnectionManager, create_chat_websocket_handler, create_message_validator, create_connection_manager, register_websocket_events ) from chat_agent.services.chat_agent import ChatAgent, ChatAgentError from chat_agent.services.session_manager import SessionManager, SessionNotFoundError from chat_agent.models.chat_session import ChatSession class TestWebSocketIntegration: """Integration tests for WebSocket communication.""" @pytest.fixture def app(self): """Create Flask app for testing.""" app = Flask(__name__) app.config['TESTING'] = True app.config['SECRET_KEY'] = 'test-secret-key' return app @pytest.fixture def socketio(self, app): """Create SocketIO instance for testing.""" return SocketIO(app, cors_allowed_origins="*") @pytest.fixture def mock_redis(self): """Create mock Redis client.""" mock_redis = Mock(spec=redis.Redis) mock_redis.setex = Mock() mock_redis.get = Mock(return_value=None) mock_redis.delete = Mock() mock_redis.sadd = Mock() mock_redis.srem = Mock() mock_redis.smembers = Mock(return_value=set()) mock_redis.expire = Mock() return mock_redis @pytest.fixture def mock_chat_agent(self): """Create mock chat agent.""" mock_agent = Mock(spec=ChatAgent) mock_agent.stream_response = Mock() mock_agent.switch_language = Mock() mock_agent.get_session_info = Mock() return mock_agent @pytest.fixture def mock_session_manager(self): """Create mock session manager.""" mock_manager = Mock(spec=SessionManager) mock_manager.get_session = Mock() mock_manager.update_session_activity = Mock() return mock_manager @pytest.fixture def connection_manager(self, mock_redis): """Create connection manager.""" return create_connection_manager(mock_redis) @pytest.fixture def message_validator(self): """Create message validator.""" return create_message_validator() @pytest.fixture def websocket_handler(self, mock_chat_agent, mock_session_manager, connection_manager): """Create WebSocket handler.""" return create_chat_websocket_handler( mock_chat_agent, mock_session_manager, connection_manager ) @pytest.fixture def client(self, app, socketio, websocket_handler): """Create SocketIO test client.""" register_websocket_events(socketio, websocket_handler) return socketio.test_client(app) @pytest.fixture def sample_session(self): """Create sample chat session.""" session = Mock(spec=ChatSession) session.id = str(uuid4()) session.user_id = str(uuid4()) session.language = 'python' session.message_count = 5 session.is_active = True return session def test_successful_connection(self, client, mock_session_manager, sample_session): """Test successful WebSocket connection.""" # Setup mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } # Connect assert client.is_connected() == False client.connect(auth=auth_data) assert client.is_connected() == True # Verify session validation was called mock_session_manager.get_session.assert_called_once_with(sample_session.id) mock_session_manager.update_session_activity.assert_called_once_with(sample_session.id) # Check connection status event received = client.get_received() assert len(received) == 1 assert received[0]['name'] == 'connection_status' assert received[0]['args'][0]['status'] == 'connected' assert received[0]['args'][0]['session_id'] == sample_session.id def test_connection_rejected_missing_auth(self, client): """Test connection rejection due to missing auth data.""" # Try to connect without auth client.connect() assert client.is_connected() == False def test_connection_rejected_invalid_session(self, client, mock_session_manager): """Test connection rejection due to invalid session.""" # Setup mock_session_manager.get_session.side_effect = SessionNotFoundError("Session not found") auth_data = { 'session_id': 'invalid-session', 'user_id': 'test-user' } # Try to connect client.connect(auth=auth_data) assert client.is_connected() == False def test_connection_rejected_user_mismatch(self, client, mock_session_manager, sample_session): """Test connection rejection due to user mismatch.""" # Setup mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': 'different-user' # Different from session.user_id } # Try to connect client.connect(auth=auth_data) assert client.is_connected() == False def test_message_processing_success(self, client, mock_session_manager, mock_chat_agent, sample_session): """Test successful message processing with streaming response.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Setup streaming response mock_response_chunks = [ {'type': 'start', 'session_id': sample_session.id, 'language': 'python', 'timestamp': datetime.utcnow().isoformat()}, {'type': 'chunk', 'content': 'Hello', 'timestamp': datetime.utcnow().isoformat()}, {'type': 'chunk', 'content': ' world!', 'timestamp': datetime.utcnow().isoformat()}, {'type': 'complete', 'message_id': str(uuid4()), 'total_chunks': 2, 'processing_time': 0.5, 'timestamp': datetime.utcnow().isoformat()} ] mock_chat_agent.stream_response.return_value = iter(mock_response_chunks) # Send message message_data = { 'content': 'Hello, how are you?', 'session_id': sample_session.id } client.emit('message', message_data) # Verify chat agent was called mock_chat_agent.stream_response.assert_called_once_with( sample_session.id, 'Hello, how are you?', None ) # Check received events received = client.get_received() # Should have: connection_status, message_received, processing_status, response_start, 2x response_chunk, response_complete assert len(received) >= 6 # Find specific events message_received = next((r for r in received if r['name'] == 'message_received'), None) assert message_received is not None processing_status = next((r for r in received if r['name'] == 'processing_status'), None) assert processing_status is not None assert processing_status['args'][0]['status'] == 'processing' response_start = next((r for r in received if r['name'] == 'response_start'), None) assert response_start is not None response_chunks = [r for r in received if r['name'] == 'response_chunk'] assert len(response_chunks) == 2 response_complete = next((r for r in received if r['name'] == 'response_complete'), None) assert response_complete is not None def test_message_validation_failure(self, client, mock_session_manager, sample_session): """Test message validation failure.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Send invalid message (missing content) invalid_message = { 'session_id': sample_session.id # Missing 'content' field } client.emit('message', invalid_message) # Check for error event received = client.get_received() error_event = next((r for r in received if r['name'] == 'error'), None) assert error_event is not None assert error_event['args'][0]['code'] == 'INVALID_MESSAGE' def test_language_switch_success(self, client, mock_session_manager, mock_chat_agent, sample_session): """Test successful language switching.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Setup language switch response switch_result = { 'success': True, 'previous_language': 'python', 'new_language': 'javascript', 'message': 'Language switched to JavaScript', 'timestamp': datetime.utcnow().isoformat() } mock_chat_agent.switch_language.return_value = switch_result # Send language switch request switch_data = { 'language': 'javascript', 'session_id': sample_session.id } client.emit('language_switch', switch_data) # Verify chat agent was called mock_chat_agent.switch_language.assert_called_once_with(sample_session.id, 'javascript') # Check for language_switched event received = client.get_received() language_switched = next((r for r in received if r['name'] == 'language_switched'), None) assert language_switched is not None assert language_switched['args'][0]['new_language'] == 'javascript' assert language_switched['args'][0]['previous_language'] == 'python' def test_language_switch_invalid_language(self, client, mock_session_manager, sample_session): """Test language switch with invalid language.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Send invalid language switch request switch_data = { 'language': 'invalid-language', 'session_id': sample_session.id } client.emit('language_switch', switch_data) # Check for error event received = client.get_received() error_event = next((r for r in received if r['name'] == 'error'), None) assert error_event is not None assert error_event['args'][0]['code'] == 'INVALID_LANGUAGE_SWITCH' def test_typing_indicators(self, client, mock_session_manager, sample_session): """Test typing indicator functionality.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Send typing start client.emit('typing_start', {}) # Send typing stop client.emit('typing_stop', {}) # Note: typing events are broadcast to room excluding sender, # so we won't see them in our own client's received events # This test mainly verifies no errors occur received = client.get_received() error_events = [r for r in received if r['name'] == 'error'] assert len(error_events) == 0 def test_ping_pong(self, client, mock_session_manager, sample_session): """Test ping/pong for connection health checks.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Send ping ping_timestamp = datetime.utcnow().isoformat() client.emit('ping', {'timestamp': ping_timestamp}) # Check for pong response received = client.get_received() pong_event = next((r for r in received if r['name'] == 'pong'), None) assert pong_event is not None assert pong_event['args'][0]['client_timestamp'] == ping_timestamp def test_session_info_request(self, client, mock_session_manager, mock_chat_agent, sample_session): """Test session info request.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Setup session info response session_info = { 'session': { 'id': sample_session.id, 'user_id': sample_session.user_id, 'language': 'python', 'message_count': 5 }, 'language_context': {'current_language': 'python'}, 'statistics': {'total_messages': 5}, 'supported_languages': ['python', 'javascript', 'java'] } mock_chat_agent.get_session_info.return_value = session_info # Request session info client.emit('get_session_info', {}) # Verify chat agent was called mock_chat_agent.get_session_info.assert_called_once_with(sample_session.id) # Check for session_info event received = client.get_received() session_info_event = next((r for r in received if r['name'] == 'session_info'), None) assert session_info_event is not None assert session_info_event['args'][0]['session']['id'] == sample_session.id def test_disconnect_cleanup(self, client, mock_session_manager, sample_session, connection_manager): """Test proper cleanup on disconnect.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Verify connection was added connections = connection_manager.get_all_connections() assert len(connections) > 0 # Disconnect client.disconnect() # Verify cleanup mock_session_manager.update_session_activity.assert_called() def test_error_handling_chat_agent_error(self, client, mock_session_manager, mock_chat_agent, sample_session): """Test error handling when chat agent fails.""" # Setup connection mock_session_manager.get_session.return_value = sample_session auth_data = { 'session_id': sample_session.id, 'user_id': sample_session.user_id } client.connect(auth=auth_data) # Setup chat agent to raise error mock_chat_agent.stream_response.side_effect = ChatAgentError("Processing failed") # Send message message_data = { 'content': 'Hello', 'session_id': sample_session.id } client.emit('message', message_data) # Check for error event received = client.get_received() error_event = next((r for r in received if r['name'] == 'error'), None) assert error_event is not None assert error_event['args'][0]['code'] == 'CHAT_AGENT_ERROR' class TestMessageValidator: """Tests for message validation functionality.""" @pytest.fixture def validator(self): """Create message validator.""" return create_message_validator() def test_valid_message(self, validator): """Test validation of valid message.""" message_data = { 'content': 'Hello, how can I help you with Python?', 'session_id': 'test-session-123' } result = validator.validate_message(message_data) assert result['valid'] == True assert result['errors'] == [] assert result['sanitized_content'] == 'Hello, how can I help you with Python?' def test_message_missing_content(self, validator): """Test validation failure for missing content.""" message_data = { 'session_id': 'test-session-123' } result = validator.validate_message(message_data) assert result['valid'] == False assert 'Message content is required' in result['errors'] def test_message_too_long(self, validator): """Test validation failure for message too long.""" long_content = 'x' * (validator.MAX_MESSAGE_LENGTH + 1) message_data = { 'content': long_content, 'session_id': 'test-session-123' } result = validator.validate_message(message_data) assert result['valid'] == False assert any('too long' in error for error in result['errors']) def test_message_sanitization(self, validator): """Test message content sanitization.""" malicious_content = 'Hello world' message_data = { 'content': malicious_content, 'session_id': 'test-session-123' } result = validator.validate_message(message_data) # Should be rejected due to malicious content assert result['valid'] == False def test_valid_language_switch(self, validator): """Test validation of valid language switch.""" switch_data = { 'language': 'javascript', 'session_id': 'test-session-123' } result = validator.validate_language_switch(switch_data) assert result['valid'] == True assert result['errors'] == [] assert result['language'] == 'javascript' def test_invalid_language_switch(self, validator): """Test validation failure for invalid language.""" switch_data = { 'language': 'invalid-language', 'session_id': 'test-session-123' } result = validator.validate_language_switch(switch_data) assert result['valid'] == False assert any('Unsupported language' in error for error in result['errors']) def test_rate_limiting(self, validator): """Test rate limiting functionality.""" session_id = 'test-session-rate-limit' # Send messages up to the limit for i in range(validator.MAX_MESSAGES_PER_MINUTE): message_data = { 'content': f'Message {i}', 'session_id': session_id } result = validator.validate_message(message_data) assert result['valid'] == True # Next message should be rate limited message_data = { 'content': 'Rate limited message', 'session_id': session_id } result = validator.validate_message(message_data) assert result['valid'] == False assert any('Rate limit exceeded' in error for error in result['errors']) class TestConnectionManager: """Tests for connection management functionality.""" @pytest.fixture def mock_redis(self): """Create mock Redis client.""" mock_redis = Mock(spec=redis.Redis) mock_redis.setex = Mock() mock_redis.get = Mock(return_value=None) mock_redis.delete = Mock() mock_redis.sadd = Mock() mock_redis.srem = Mock() mock_redis.smembers = Mock(return_value=set()) mock_redis.expire = Mock() return mock_redis @pytest.fixture def connection_manager(self, mock_redis): """Create connection manager.""" return create_connection_manager(mock_redis) def test_add_connection(self, connection_manager, mock_redis): """Test adding a connection.""" client_id = 'test-client-123' connection_info = { 'client_id': client_id, 'session_id': 'test-session-123', 'user_id': 'test-user-123', 'connected_at': datetime.utcnow().isoformat(), 'language': 'python' } connection_manager.add_connection(client_id, connection_info) # Verify connection was added to memory retrieved_info = connection_manager.get_connection(client_id) assert retrieved_info is not None assert retrieved_info['session_id'] == 'test-session-123' # Verify Redis calls mock_redis.setex.assert_called() mock_redis.sadd.assert_called() def test_remove_connection(self, connection_manager): """Test removing a connection.""" client_id = 'test-client-123' connection_info = { 'client_id': client_id, 'session_id': 'test-session-123', 'user_id': 'test-user-123', 'connected_at': datetime.utcnow().isoformat(), 'language': 'python' } # Add connection connection_manager.add_connection(client_id, connection_info) # Remove connection removed_info = connection_manager.remove_connection(client_id) assert removed_info is not None assert removed_info['session_id'] == 'test-session-123' # Verify connection is gone retrieved_info = connection_manager.get_connection(client_id) assert retrieved_info is None def test_update_connection_activity(self, connection_manager): """Test updating connection activity.""" client_id = 'test-client-123' connection_info = { 'client_id': client_id, 'session_id': 'test-session-123', 'user_id': 'test-user-123', 'connected_at': datetime.utcnow().isoformat(), 'language': 'python' } # Add connection connection_manager.add_connection(client_id, connection_info) # Update activity success = connection_manager.update_connection_activity(client_id) assert success == True # Verify last_activity was added updated_info = connection_manager.get_connection(client_id) assert 'last_activity' in updated_info def test_get_connection_stats(self, connection_manager): """Test getting connection statistics.""" # Add multiple connections for i in range(3): client_id = f'client-{i}' connection_info = { 'client_id': client_id, 'session_id': f'session-{i}', 'user_id': f'user-{i % 2}', # 2 unique users 'connected_at': datetime.utcnow().isoformat(), 'language': 'python' if i % 2 == 0 else 'javascript' } connection_manager.add_connection(client_id, connection_info) stats = connection_manager.get_connection_stats() assert stats['total_connections'] == 3 assert stats['unique_sessions'] == 3 assert stats['unique_users'] == 2 assert 'python' in stats['languages'] assert 'javascript' in stats['languages'] if __name__ == '__main__': pytest.main([__file__])