Spaces:
Runtime error
Runtime error
"""
Integration tests for the WebSocket communication layer.

Tests the complete WebSocket message flow, error handling, and real-time
communication features of the multi-language chat agent.
"""
| import pytest | |
| import json | |
| import time | |
| from datetime import datetime | |
| from unittest.mock import Mock, patch, MagicMock | |
| from uuid import uuid4 | |
| from flask import Flask | |
| from flask_socketio import SocketIO, SocketIOTestClient | |
| import redis | |
| from chat_agent.websocket import ( | |
| ChatWebSocketHandler, MessageValidator, ConnectionManager, | |
| create_chat_websocket_handler, create_message_validator, | |
| create_connection_manager, register_websocket_events | |
| ) | |
| from chat_agent.services.chat_agent import ChatAgent, ChatAgentError | |
| from chat_agent.services.session_manager import SessionManager, SessionNotFoundError | |
| from chat_agent.models.chat_session import ChatSession | |
class TestWebSocketIntegration:
    """Integration tests for WebSocket communication.

    Each test builds a Flask app and ``SocketIO`` server, registers the
    WebSocket events for a handler whose collaborators (chat agent, session
    manager, Redis client) are mocks, and drives it through the SocketIO
    test client.
    """

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def app(self):
        """Create Flask app for testing."""
        app = Flask(__name__)
        app.config['TESTING'] = True
        app.config['SECRET_KEY'] = 'test-secret-key'
        return app

    @pytest.fixture
    def socketio(self, app):
        """Create SocketIO instance for testing."""
        return SocketIO(app, cors_allowed_origins="*")

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis client with the commands the connection manager may call."""
        mock_redis = Mock(spec=redis.Redis)
        mock_redis.setex = Mock()
        mock_redis.get = Mock(return_value=None)
        mock_redis.delete = Mock()
        mock_redis.sadd = Mock()
        mock_redis.srem = Mock()
        mock_redis.smembers = Mock(return_value=set())
        mock_redis.expire = Mock()
        return mock_redis

    @pytest.fixture
    def mock_chat_agent(self):
        """Create mock chat agent."""
        mock_agent = Mock(spec=ChatAgent)
        mock_agent.stream_response = Mock()
        mock_agent.switch_language = Mock()
        mock_agent.get_session_info = Mock()
        return mock_agent

    @pytest.fixture
    def mock_session_manager(self):
        """Create mock session manager."""
        mock_manager = Mock(spec=SessionManager)
        mock_manager.get_session = Mock()
        mock_manager.update_session_activity = Mock()
        return mock_manager

    @pytest.fixture
    def connection_manager(self, mock_redis):
        """Create connection manager backed by the mock Redis client."""
        return create_connection_manager(mock_redis)

    @pytest.fixture
    def message_validator(self):
        """Create message validator."""
        return create_message_validator()

    @pytest.fixture
    def websocket_handler(self, mock_chat_agent, mock_session_manager, connection_manager):
        """Create WebSocket handler wired to the mocked collaborators."""
        return create_chat_websocket_handler(
            mock_chat_agent, mock_session_manager, connection_manager
        )

    @pytest.fixture
    def client(self, app, socketio, websocket_handler):
        """Create SocketIO test client with the WebSocket events registered.

        NOTE(review): the tests below expect this client to be disconnected
        until they call ``connect`` themselves (see
        ``test_successful_connection``); confirm this holds for the installed
        Flask-SocketIO version's ``test_client`` behavior.
        """
        register_websocket_events(socketio, websocket_handler)
        return socketio.test_client(app)

    @pytest.fixture
    def sample_session(self):
        """Create sample chat session."""
        session = Mock(spec=ChatSession)
        session.id = str(uuid4())
        session.user_id = str(uuid4())
        session.language = 'python'
        session.message_count = 5
        session.is_active = True
        return session

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _connect(client, mock_session_manager, session):
        """Connect ``client`` with auth data matching ``session``'s owner.

        Configures the session-manager mock to return ``session`` first, so
        the handler's session lookup succeeds.
        """
        mock_session_manager.get_session.return_value = session
        client.connect(auth={
            'session_id': session.id,
            'user_id': session.user_id,
        })

    @staticmethod
    def _first_event(received, name):
        """Return the first received event whose ``name`` matches, else ``None``."""
        return next((event for event in received if event['name'] == name), None)

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    def test_successful_connection(self, client, mock_session_manager, sample_session):
        """Test successful WebSocket connection."""
        assert not client.is_connected()
        self._connect(client, mock_session_manager, sample_session)
        assert client.is_connected()

        # Session validation and activity tracking happen on connect.
        mock_session_manager.get_session.assert_called_once_with(sample_session.id)
        mock_session_manager.update_session_activity.assert_called_once_with(sample_session.id)

        # Exactly one connection_status event is emitted.
        received = client.get_received()
        assert len(received) == 1
        assert received[0]['name'] == 'connection_status'
        assert received[0]['args'][0]['status'] == 'connected'
        assert received[0]['args'][0]['session_id'] == sample_session.id

    def test_connection_rejected_missing_auth(self, client):
        """Test connection rejection due to missing auth data."""
        client.connect()
        assert not client.is_connected()

    def test_connection_rejected_invalid_session(self, client, mock_session_manager):
        """Test connection rejection due to invalid session."""
        mock_session_manager.get_session.side_effect = SessionNotFoundError("Session not found")
        client.connect(auth={
            'session_id': 'invalid-session',
            'user_id': 'test-user',
        })
        assert not client.is_connected()

    def test_connection_rejected_user_mismatch(self, client, mock_session_manager, sample_session):
        """Test connection rejection due to user mismatch."""
        mock_session_manager.get_session.return_value = sample_session
        client.connect(auth={
            'session_id': sample_session.id,
            'user_id': 'different-user',  # Different from session.user_id
        })
        assert not client.is_connected()

    # ------------------------------------------------------------------
    # Messaging
    # ------------------------------------------------------------------

    def test_message_processing_success(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test successful message processing with streaming response."""
        self._connect(client, mock_session_manager, sample_session)

        mock_response_chunks = [
            {'type': 'start', 'session_id': sample_session.id, 'language': 'python', 'timestamp': datetime.utcnow().isoformat()},
            {'type': 'chunk', 'content': 'Hello', 'timestamp': datetime.utcnow().isoformat()},
            {'type': 'chunk', 'content': ' world!', 'timestamp': datetime.utcnow().isoformat()},
            {'type': 'complete', 'message_id': str(uuid4()), 'total_chunks': 2, 'processing_time': 0.5, 'timestamp': datetime.utcnow().isoformat()},
        ]
        mock_chat_agent.stream_response.return_value = iter(mock_response_chunks)

        client.emit('message', {
            'content': 'Hello, how are you?',
            'session_id': sample_session.id,
        })

        mock_chat_agent.stream_response.assert_called_once_with(
            sample_session.id, 'Hello, how are you?', None
        )

        received = client.get_received()
        # Expected: connection_status, message_received, processing_status,
        # response_start, 2x response_chunk, response_complete (7 events).
        assert len(received) >= 7

        assert self._first_event(received, 'message_received') is not None

        processing_status = self._first_event(received, 'processing_status')
        assert processing_status is not None
        assert processing_status['args'][0]['status'] == 'processing'

        assert self._first_event(received, 'response_start') is not None

        response_chunks = [event for event in received if event['name'] == 'response_chunk']
        assert len(response_chunks) == 2

        assert self._first_event(received, 'response_complete') is not None

    def test_message_validation_failure(self, client, mock_session_manager, sample_session):
        """Test message validation failure."""
        self._connect(client, mock_session_manager, sample_session)

        # Invalid message: the 'content' field is missing.
        client.emit('message', {'session_id': sample_session.id})

        error_event = self._first_event(client.get_received(), 'error')
        assert error_event is not None
        assert error_event['args'][0]['code'] == 'INVALID_MESSAGE'

    # ------------------------------------------------------------------
    # Language switching
    # ------------------------------------------------------------------

    def test_language_switch_success(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test successful language switching."""
        self._connect(client, mock_session_manager, sample_session)

        mock_chat_agent.switch_language.return_value = {
            'success': True,
            'previous_language': 'python',
            'new_language': 'javascript',
            'message': 'Language switched to JavaScript',
            'timestamp': datetime.utcnow().isoformat(),
        }

        client.emit('language_switch', {
            'language': 'javascript',
            'session_id': sample_session.id,
        })

        mock_chat_agent.switch_language.assert_called_once_with(sample_session.id, 'javascript')

        language_switched = self._first_event(client.get_received(), 'language_switched')
        assert language_switched is not None
        assert language_switched['args'][0]['new_language'] == 'javascript'
        assert language_switched['args'][0]['previous_language'] == 'python'

    def test_language_switch_invalid_language(self, client, mock_session_manager, sample_session):
        """Test language switch with invalid language."""
        self._connect(client, mock_session_manager, sample_session)

        client.emit('language_switch', {
            'language': 'invalid-language',
            'session_id': sample_session.id,
        })

        error_event = self._first_event(client.get_received(), 'error')
        assert error_event is not None
        assert error_event['args'][0]['code'] == 'INVALID_LANGUAGE_SWITCH'

    # ------------------------------------------------------------------
    # Auxiliary events
    # ------------------------------------------------------------------

    def test_typing_indicators(self, client, mock_session_manager, sample_session):
        """Test typing indicator functionality."""
        self._connect(client, mock_session_manager, sample_session)

        client.emit('typing_start', {})
        client.emit('typing_stop', {})

        # Typing events are broadcast to the room excluding the sender, so
        # they never reach this client; the test only verifies no errors occur.
        received = client.get_received()
        error_events = [event for event in received if event['name'] == 'error']
        assert len(error_events) == 0

    def test_ping_pong(self, client, mock_session_manager, sample_session):
        """Test ping/pong for connection health checks."""
        self._connect(client, mock_session_manager, sample_session)

        ping_timestamp = datetime.utcnow().isoformat()
        client.emit('ping', {'timestamp': ping_timestamp})

        pong_event = self._first_event(client.get_received(), 'pong')
        assert pong_event is not None
        assert pong_event['args'][0]['client_timestamp'] == ping_timestamp

    def test_session_info_request(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test session info request."""
        self._connect(client, mock_session_manager, sample_session)

        mock_chat_agent.get_session_info.return_value = {
            'session': {
                'id': sample_session.id,
                'user_id': sample_session.user_id,
                'language': 'python',
                'message_count': 5,
            },
            'language_context': {'current_language': 'python'},
            'statistics': {'total_messages': 5},
            'supported_languages': ['python', 'javascript', 'java'],
        }

        client.emit('get_session_info', {})

        mock_chat_agent.get_session_info.assert_called_once_with(sample_session.id)

        session_info_event = self._first_event(client.get_received(), 'session_info')
        assert session_info_event is not None
        assert session_info_event['args'][0]['session']['id'] == sample_session.id

    def test_disconnect_cleanup(self, client, mock_session_manager, sample_session, connection_manager):
        """Test proper cleanup on disconnect."""
        self._connect(client, mock_session_manager, sample_session)

        # The connection is tracked by the shared connection manager.
        connections_before = len(connection_manager.get_all_connections())
        assert connections_before > 0

        client.disconnect()

        # The original `update_session_activity.assert_called()` alone was
        # vacuous: connect had already called it. Check that the connection
        # was actually released as well.
        assert not client.is_connected()
        assert len(connection_manager.get_all_connections()) < connections_before
        mock_session_manager.update_session_activity.assert_called()

    def test_error_handling_chat_agent_error(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test error handling when chat agent fails."""
        self._connect(client, mock_session_manager, sample_session)

        mock_chat_agent.stream_response.side_effect = ChatAgentError("Processing failed")

        client.emit('message', {
            'content': 'Hello',
            'session_id': sample_session.id,
        })

        error_event = self._first_event(client.get_received(), 'error')
        assert error_event is not None
        assert error_event['args'][0]['code'] == 'CHAT_AGENT_ERROR'
class TestMessageValidator:
    """Tests for message validation functionality."""

    @pytest.fixture
    def validator(self):
        """Create message validator (function-scoped: a new instance per test)."""
        return create_message_validator()

    def test_valid_message(self, validator):
        """Test validation of valid message."""
        result = validator.validate_message({
            'content': 'Hello, how can I help you with Python?',
            'session_id': 'test-session-123',
        })
        assert result['valid']
        assert result['errors'] == []
        assert result['sanitized_content'] == 'Hello, how can I help you with Python?'

    def test_message_missing_content(self, validator):
        """Test validation failure for missing content."""
        result = validator.validate_message({'session_id': 'test-session-123'})
        assert not result['valid']
        assert 'Message content is required' in result['errors']

    def test_message_too_long(self, validator):
        """Test validation failure for message one character over the limit."""
        long_content = 'x' * (validator.MAX_MESSAGE_LENGTH + 1)
        result = validator.validate_message({
            'content': long_content,
            'session_id': 'test-session-123',
        })
        assert not result['valid']
        assert any('too long' in error for error in result['errors'])

    def test_message_sanitization(self, validator):
        """Test that content containing a <script> tag is rejected outright."""
        malicious_content = '<script>alert("xss")</script>Hello world'
        result = validator.validate_message({
            'content': malicious_content,
            'session_id': 'test-session-123',
        })
        # Rejected rather than silently sanitized.
        assert not result['valid']

    def test_valid_language_switch(self, validator):
        """Test validation of valid language switch."""
        result = validator.validate_language_switch({
            'language': 'javascript',
            'session_id': 'test-session-123',
        })
        assert result['valid']
        assert result['errors'] == []
        assert result['language'] == 'javascript'

    def test_invalid_language_switch(self, validator):
        """Test validation failure for invalid language."""
        result = validator.validate_language_switch({
            'language': 'invalid-language',
            'session_id': 'test-session-123',
        })
        assert not result['valid']
        assert any('Unsupported language' in error for error in result['errors'])

    def test_rate_limiting(self, validator):
        """Test that the message after MAX_MESSAGES_PER_MINUTE is rejected."""
        session_id = 'test-session-rate-limit'

        # Every message up to the limit is accepted.
        for i in range(validator.MAX_MESSAGES_PER_MINUTE):
            result = validator.validate_message({
                'content': f'Message {i}',
                'session_id': session_id,
            })
            assert result['valid']

        # The next one from the same session is rate limited.
        result = validator.validate_message({
            'content': 'Rate limited message',
            'session_id': session_id,
        })
        assert not result['valid']
        assert any('Rate limit exceeded' in error for error in result['errors'])
class TestConnectionManager:
    """Tests for connection management functionality."""

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis client with the commands the connection manager may call."""
        mock_redis = Mock(spec=redis.Redis)
        mock_redis.setex = Mock()
        mock_redis.get = Mock(return_value=None)
        mock_redis.delete = Mock()
        mock_redis.sadd = Mock()
        mock_redis.srem = Mock()
        mock_redis.smembers = Mock(return_value=set())
        mock_redis.expire = Mock()
        return mock_redis

    @pytest.fixture
    def connection_manager(self, mock_redis):
        """Create connection manager backed by the mock Redis client."""
        return create_connection_manager(mock_redis)

    @staticmethod
    def _connection_info(client_id, session_id='test-session-123',
                         user_id='test-user-123', language='python'):
        """Build a connection-info dict as passed to ``add_connection``."""
        return {
            'client_id': client_id,
            'session_id': session_id,
            'user_id': user_id,
            'connected_at': datetime.utcnow().isoformat(),
            'language': language,
        }

    def test_add_connection(self, connection_manager, mock_redis):
        """Test adding a connection."""
        client_id = 'test-client-123'
        connection_manager.add_connection(client_id, self._connection_info(client_id))

        # The connection is retrievable from the manager.
        retrieved_info = connection_manager.get_connection(client_id)
        assert retrieved_info is not None
        assert retrieved_info['session_id'] == 'test-session-123'

        # And it was written through to Redis.
        mock_redis.setex.assert_called()
        mock_redis.sadd.assert_called()

    def test_remove_connection(self, connection_manager):
        """Test removing a connection."""
        client_id = 'test-client-123'
        connection_manager.add_connection(client_id, self._connection_info(client_id))

        removed_info = connection_manager.remove_connection(client_id)
        assert removed_info is not None
        assert removed_info['session_id'] == 'test-session-123'

        # The connection is no longer retrievable.
        assert connection_manager.get_connection(client_id) is None

    def test_update_connection_activity(self, connection_manager):
        """Test updating connection activity."""
        client_id = 'test-client-123'
        connection_manager.add_connection(client_id, self._connection_info(client_id))

        assert connection_manager.update_connection_activity(client_id)

        # last_activity is recorded on the stored connection; guard against
        # None first so a missing connection fails cleanly, not with TypeError.
        updated_info = connection_manager.get_connection(client_id)
        assert updated_info is not None
        assert 'last_activity' in updated_info

    def test_get_connection_stats(self, connection_manager):
        """Test getting connection statistics."""
        # Three connections, three sessions, two unique users, two languages.
        for i in range(3):
            client_id = f'client-{i}'
            connection_manager.add_connection(client_id, self._connection_info(
                client_id,
                session_id=f'session-{i}',
                user_id=f'user-{i % 2}',
                language='python' if i % 2 == 0 else 'javascript',
            ))

        stats = connection_manager.get_connection_stats()
        assert stats['total_connections'] == 3
        assert stats['unique_sessions'] == 3
        assert stats['unique_users'] == 2
        assert 'python' in stats['languages']
        assert 'javascript' in stats['languages']
if __name__ == '__main__':
    # pytest.main() returns the exit code instead of raising SystemExit;
    # propagate it so running this file directly fails when tests fail.
    raise SystemExit(pytest.main([__file__]))