# Source: scratch_chat/tests/integration/test_websocket_integration.py
# (uploaded by WebashalarForML in commit 330b6e4, "Upload 178 files")
"""
Integration tests for WebSocket communication layer.
Tests the complete WebSocket message flow, error handling, and real-time communication
features of the multi-language chat agent.
"""
import pytest
import json
import time
from datetime import datetime
from unittest.mock import Mock, patch, MagicMock
from uuid import uuid4
from flask import Flask
from flask_socketio import SocketIO, SocketIOTestClient
import redis
from chat_agent.websocket import (
ChatWebSocketHandler, MessageValidator, ConnectionManager,
create_chat_websocket_handler, create_message_validator,
create_connection_manager, register_websocket_events
)
from chat_agent.services.chat_agent import ChatAgent, ChatAgentError
from chat_agent.services.session_manager import SessionManager, SessionNotFoundError
from chat_agent.models.chat_session import ChatSession
class TestWebSocketIntegration:
    """Integration tests for WebSocket communication.

    Each test builds a fresh Flask app and SocketIO server, registers the
    chat WebSocket events against mocked ``ChatAgent`` / ``SessionManager``
    collaborators and a ``ConnectionManager`` backed by a mocked Redis
    client, then drives the server through a SocketIO test client.
    """

    @pytest.fixture
    def app(self):
        """Create a Flask app configured for testing."""
        app = Flask(__name__)
        app.config['TESTING'] = True
        app.config['SECRET_KEY'] = 'test-secret-key'
        return app

    @pytest.fixture
    def socketio(self, app):
        """Create a SocketIO server bound to the test app."""
        return SocketIO(app, cors_allowed_origins="*")

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis client whose reads return nothing."""
        mock_redis = Mock(spec=redis.Redis)
        mock_redis.setex = Mock()
        mock_redis.get = Mock(return_value=None)
        mock_redis.delete = Mock()
        mock_redis.sadd = Mock()
        mock_redis.srem = Mock()
        mock_redis.smembers = Mock(return_value=set())
        mock_redis.expire = Mock()
        return mock_redis

    @pytest.fixture
    def mock_chat_agent(self):
        """Create a mock chat agent; tests configure its return values."""
        mock_agent = Mock(spec=ChatAgent)
        mock_agent.stream_response = Mock()
        mock_agent.switch_language = Mock()
        mock_agent.get_session_info = Mock()
        return mock_agent

    @pytest.fixture
    def mock_session_manager(self):
        """Create a mock session manager; tests configure ``get_session``."""
        mock_manager = Mock(spec=SessionManager)
        mock_manager.get_session = Mock()
        mock_manager.update_session_activity = Mock()
        return mock_manager

    @pytest.fixture
    def connection_manager(self, mock_redis):
        """Create the connection manager shared with the WebSocket handler."""
        return create_connection_manager(mock_redis)

    @pytest.fixture
    def message_validator(self):
        """Create a message validator."""
        return create_message_validator()

    @pytest.fixture
    def websocket_handler(self, mock_chat_agent, mock_session_manager, connection_manager):
        """Create the WebSocket handler wired to the mocked collaborators."""
        return create_chat_websocket_handler(
            mock_chat_agent, mock_session_manager, connection_manager
        )

    @pytest.fixture
    def client(self, app, socketio, websocket_handler):
        """Create a SocketIO test client with the chat events registered.

        NOTE(review): ``socketio.test_client`` appears to make an initial
        connect attempt without auth data. These tests assume the handler
        rejects it (``test_successful_connection`` asserts the client starts
        disconnected) — confirm against the Flask-SocketIO version in use.

        On teardown the client is disconnected if a test left it connected.
        """
        register_websocket_events(socketio, websocket_handler)
        test_client = socketio.test_client(app)
        yield test_client
        # Guarded: disconnecting an unconnected test client is an error, and
        # some tests (rejections, explicit disconnect) end disconnected.
        if test_client.is_connected():
            test_client.disconnect()

    @pytest.fixture
    def sample_session(self):
        """Create a sample active chat session with random ids."""
        session = Mock(spec=ChatSession)
        session.id = str(uuid4())
        session.user_id = str(uuid4())
        session.language = 'python'
        session.message_count = 5
        session.is_active = True
        return session

    @staticmethod
    def _connect(client, mock_session_manager, session):
        """Connect ``client`` as the owner of ``session``.

        Makes ``mock_session_manager.get_session`` return ``session`` and
        connects with matching ``session_id`` / ``user_id`` auth data.
        """
        mock_session_manager.get_session.return_value = session
        client.connect(auth={
            'session_id': session.id,
            'user_id': session.user_id,
        })

    def test_successful_connection(self, client, mock_session_manager, sample_session):
        """Test successful WebSocket connection."""
        assert not client.is_connected()
        self._connect(client, mock_session_manager, sample_session)
        assert client.is_connected()

        # Verify session validation was called
        mock_session_manager.get_session.assert_called_once_with(sample_session.id)
        mock_session_manager.update_session_activity.assert_called_once_with(sample_session.id)

        # Check connection status event
        received = client.get_received()
        assert len(received) == 1
        assert received[0]['name'] == 'connection_status'
        assert received[0]['args'][0]['status'] == 'connected'
        assert received[0]['args'][0]['session_id'] == sample_session.id

    def test_connection_rejected_missing_auth(self, client):
        """Test connection rejection due to missing auth data."""
        client.connect()
        assert not client.is_connected()

    def test_connection_rejected_invalid_session(self, client, mock_session_manager):
        """Test connection rejection due to invalid session."""
        mock_session_manager.get_session.side_effect = SessionNotFoundError("Session not found")
        auth_data = {
            'session_id': 'invalid-session',
            'user_id': 'test-user'
        }
        client.connect(auth=auth_data)
        assert not client.is_connected()

    def test_connection_rejected_user_mismatch(self, client, mock_session_manager, sample_session):
        """Test connection rejection due to user mismatch."""
        mock_session_manager.get_session.return_value = sample_session
        auth_data = {
            'session_id': sample_session.id,
            'user_id': 'different-user'  # Different from session.user_id
        }
        client.connect(auth=auth_data)
        assert not client.is_connected()

    def test_message_processing_success(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test successful message processing with streaming response."""
        self._connect(client, mock_session_manager, sample_session)

        # Setup streaming response
        mock_response_chunks = [
            {'type': 'start', 'session_id': sample_session.id, 'language': 'python', 'timestamp': datetime.utcnow().isoformat()},
            {'type': 'chunk', 'content': 'Hello', 'timestamp': datetime.utcnow().isoformat()},
            {'type': 'chunk', 'content': ' world!', 'timestamp': datetime.utcnow().isoformat()},
            {'type': 'complete', 'message_id': str(uuid4()), 'total_chunks': 2, 'processing_time': 0.5, 'timestamp': datetime.utcnow().isoformat()}
        ]
        mock_chat_agent.stream_response.return_value = iter(mock_response_chunks)

        # Send message
        message_data = {
            'content': 'Hello, how are you?',
            'session_id': sample_session.id
        }
        client.emit('message', message_data)

        # Verify chat agent was called
        mock_chat_agent.stream_response.assert_called_once_with(
            sample_session.id, 'Hello, how are you?', None
        )

        received = client.get_received()
        # Expected: connection_status, message_received, processing_status,
        # response_start, 2x response_chunk, response_complete (7 events).
        assert len(received) >= 7

        message_received = next((r for r in received if r['name'] == 'message_received'), None)
        assert message_received is not None

        processing_status = next((r for r in received if r['name'] == 'processing_status'), None)
        assert processing_status is not None
        assert processing_status['args'][0]['status'] == 'processing'

        response_start = next((r for r in received if r['name'] == 'response_start'), None)
        assert response_start is not None

        response_chunks = [r for r in received if r['name'] == 'response_chunk']
        assert len(response_chunks) == 2

        response_complete = next((r for r in received if r['name'] == 'response_complete'), None)
        assert response_complete is not None

    def test_message_validation_failure(self, client, mock_session_manager, sample_session):
        """Test message validation failure."""
        self._connect(client, mock_session_manager, sample_session)

        # Send invalid message (missing 'content' field)
        invalid_message = {
            'session_id': sample_session.id
        }
        client.emit('message', invalid_message)

        received = client.get_received()
        error_event = next((r for r in received if r['name'] == 'error'), None)
        assert error_event is not None
        assert error_event['args'][0]['code'] == 'INVALID_MESSAGE'

    def test_language_switch_success(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test successful language switching."""
        self._connect(client, mock_session_manager, sample_session)

        switch_result = {
            'success': True,
            'previous_language': 'python',
            'new_language': 'javascript',
            'message': 'Language switched to JavaScript',
            'timestamp': datetime.utcnow().isoformat()
        }
        mock_chat_agent.switch_language.return_value = switch_result

        switch_data = {
            'language': 'javascript',
            'session_id': sample_session.id
        }
        client.emit('language_switch', switch_data)

        mock_chat_agent.switch_language.assert_called_once_with(sample_session.id, 'javascript')

        received = client.get_received()
        language_switched = next((r for r in received if r['name'] == 'language_switched'), None)
        assert language_switched is not None
        assert language_switched['args'][0]['new_language'] == 'javascript'
        assert language_switched['args'][0]['previous_language'] == 'python'

    def test_language_switch_invalid_language(self, client, mock_session_manager, sample_session):
        """Test language switch with invalid language."""
        self._connect(client, mock_session_manager, sample_session)

        switch_data = {
            'language': 'invalid-language',
            'session_id': sample_session.id
        }
        client.emit('language_switch', switch_data)

        received = client.get_received()
        error_event = next((r for r in received if r['name'] == 'error'), None)
        assert error_event is not None
        assert error_event['args'][0]['code'] == 'INVALID_LANGUAGE_SWITCH'

    def test_typing_indicators(self, client, mock_session_manager, sample_session):
        """Test typing indicator functionality."""
        self._connect(client, mock_session_manager, sample_session)

        client.emit('typing_start', {})
        client.emit('typing_stop', {})

        # Typing events are broadcast to the room excluding the sender, so they
        # never reach this client; the test only verifies that no error occurs.
        received = client.get_received()
        error_events = [r for r in received if r['name'] == 'error']
        assert len(error_events) == 0

    def test_ping_pong(self, client, mock_session_manager, sample_session):
        """Test ping/pong for connection health checks."""
        self._connect(client, mock_session_manager, sample_session)

        ping_timestamp = datetime.utcnow().isoformat()
        client.emit('ping', {'timestamp': ping_timestamp})

        received = client.get_received()
        pong_event = next((r for r in received if r['name'] == 'pong'), None)
        assert pong_event is not None
        assert pong_event['args'][0]['client_timestamp'] == ping_timestamp

    def test_session_info_request(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test session info request."""
        self._connect(client, mock_session_manager, sample_session)

        session_info = {
            'session': {
                'id': sample_session.id,
                'user_id': sample_session.user_id,
                'language': 'python',
                'message_count': 5
            },
            'language_context': {'current_language': 'python'},
            'statistics': {'total_messages': 5},
            'supported_languages': ['python', 'javascript', 'java']
        }
        mock_chat_agent.get_session_info.return_value = session_info

        client.emit('get_session_info', {})

        mock_chat_agent.get_session_info.assert_called_once_with(sample_session.id)

        received = client.get_received()
        session_info_event = next((r for r in received if r['name'] == 'session_info'), None)
        assert session_info_event is not None
        assert session_info_event['args'][0]['session']['id'] == sample_session.id

    def test_disconnect_cleanup(self, client, mock_session_manager, sample_session, connection_manager):
        """Test that disconnecting deregisters the client's connection."""
        self._connect(client, mock_session_manager, sample_session)

        # Verify the connection was registered
        connections_before = len(connection_manager.get_all_connections())
        assert connections_before > 0

        client.disconnect()

        assert not client.is_connected()
        # The disconnect handler must remove this client's connection entry.
        assert len(connection_manager.get_all_connections()) == connections_before - 1
        # NOTE: update_session_activity is already called on connect, so this
        # check alone cannot prove disconnect-time cleanup; the count check does.
        mock_session_manager.update_session_activity.assert_called()

    def test_error_handling_chat_agent_error(self, client, mock_session_manager, mock_chat_agent, sample_session):
        """Test error handling when chat agent fails."""
        self._connect(client, mock_session_manager, sample_session)

        mock_chat_agent.stream_response.side_effect = ChatAgentError("Processing failed")

        message_data = {
            'content': 'Hello',
            'session_id': sample_session.id
        }
        client.emit('message', message_data)

        received = client.get_received()
        error_event = next((r for r in received if r['name'] == 'error'), None)
        assert error_event is not None
        assert error_event['args'][0]['code'] == 'CHAT_AGENT_ERROR'
class TestMessageValidator:
    """Tests for ``MessageValidator``: content, language-switch and rate-limit checks."""

    @pytest.fixture
    def validator(self):
        """Create a fresh message validator for each test."""
        return create_message_validator()

    def test_valid_message(self, validator):
        """A well-formed message is valid and its content passes through unchanged."""
        message_data = {
            'content': 'Hello, how can I help you with Python?',
            'session_id': 'test-session-123'
        }
        result = validator.validate_message(message_data)
        assert result['valid']
        assert result['errors'] == []
        assert result['sanitized_content'] == 'Hello, how can I help you with Python?'

    def test_message_missing_content(self, validator):
        """A message without a 'content' field is rejected."""
        message_data = {
            'session_id': 'test-session-123'
        }
        result = validator.validate_message(message_data)
        assert not result['valid']
        assert 'Message content is required' in result['errors']

    def test_message_too_long(self, validator):
        """Content one character over MAX_MESSAGE_LENGTH is rejected."""
        long_content = 'x' * (validator.MAX_MESSAGE_LENGTH + 1)
        message_data = {
            'content': long_content,
            'session_id': 'test-session-123'
        }
        result = validator.validate_message(message_data)
        assert not result['valid']
        assert any('too long' in error for error in result['errors'])

    def test_message_sanitization(self, validator):
        """Content containing a <script> tag is rejected rather than passed through."""
        malicious_content = '<script>alert("xss")</script>Hello world'
        message_data = {
            'content': malicious_content,
            'session_id': 'test-session-123'
        }
        result = validator.validate_message(message_data)
        assert not result['valid']
        # Like the other invalid cases, a rejection should carry a reason.
        assert result['errors']

    def test_valid_language_switch(self, validator):
        """Switching to a supported language is valid."""
        switch_data = {
            'language': 'javascript',
            'session_id': 'test-session-123'
        }
        result = validator.validate_language_switch(switch_data)
        assert result['valid']
        assert result['errors'] == []
        assert result['language'] == 'javascript'

    def test_invalid_language_switch(self, validator):
        """Switching to an unknown language is rejected."""
        switch_data = {
            'language': 'invalid-language',
            'session_id': 'test-session-123'
        }
        result = validator.validate_language_switch(switch_data)
        assert not result['valid']
        assert any('Unsupported language' in error for error in result['errors'])

    def test_rate_limiting(self, validator):
        """Messages beyond MAX_MESSAGES_PER_MINUTE for one session are rejected."""
        session_id = 'test-session-rate-limit'

        # Every message up to the limit is accepted.
        for i in range(validator.MAX_MESSAGES_PER_MINUTE):
            message_data = {
                'content': f'Message {i}',
                'session_id': session_id
            }
            result = validator.validate_message(message_data)
            assert result['valid'], f'message {i} should be within the rate limit'

        # The next message exceeds the limit.
        message_data = {
            'content': 'Rate limited message',
            'session_id': session_id
        }
        result = validator.validate_message(message_data)
        assert not result['valid']
        assert any('Rate limit exceeded' in error for error in result['errors'])
class TestConnectionManager:
    """Tests for ``ConnectionManager`` backed by a mocked Redis client."""

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis client whose reads return nothing."""
        mock_redis = Mock(spec=redis.Redis)
        mock_redis.setex = Mock()
        mock_redis.get = Mock(return_value=None)
        mock_redis.delete = Mock()
        mock_redis.sadd = Mock()
        mock_redis.srem = Mock()
        mock_redis.smembers = Mock(return_value=set())
        mock_redis.expire = Mock()
        return mock_redis

    @pytest.fixture
    def connection_manager(self, mock_redis):
        """Create a connection manager using the mock Redis client."""
        return create_connection_manager(mock_redis)

    @staticmethod
    def _connection_info(client_id, session_id='test-session-123',
                         user_id='test-user-123', language='python'):
        """Build the connection-info dict passed to ``add_connection``."""
        return {
            'client_id': client_id,
            'session_id': session_id,
            'user_id': user_id,
            'connected_at': datetime.utcnow().isoformat(),
            'language': language,
        }

    def test_add_connection(self, connection_manager, mock_redis):
        """An added connection is retrievable and written to Redis."""
        client_id = 'test-client-123'
        connection_manager.add_connection(client_id, self._connection_info(client_id))

        # Verify connection was added to memory
        retrieved_info = connection_manager.get_connection(client_id)
        assert retrieved_info is not None
        assert retrieved_info['session_id'] == 'test-session-123'

        # Verify Redis calls
        mock_redis.setex.assert_called()
        mock_redis.sadd.assert_called()

    def test_remove_connection(self, connection_manager):
        """Removing a connection returns its info and makes it unretrievable."""
        client_id = 'test-client-123'
        connection_manager.add_connection(client_id, self._connection_info(client_id))

        removed_info = connection_manager.remove_connection(client_id)
        assert removed_info is not None
        assert removed_info['session_id'] == 'test-session-123'

        # Verify connection is gone
        assert connection_manager.get_connection(client_id) is None

    def test_update_connection_activity(self, connection_manager):
        """Updating activity succeeds and records a 'last_activity' field."""
        client_id = 'test-client-123'
        connection_manager.add_connection(client_id, self._connection_info(client_id))

        assert connection_manager.update_connection_activity(client_id)

        updated_info = connection_manager.get_connection(client_id)
        # Fail cleanly (not with a TypeError) if the connection vanished.
        assert updated_info is not None
        assert 'last_activity' in updated_info

    def test_get_connection_stats(self, connection_manager):
        """Stats count connections, unique sessions, unique users and languages."""
        for i in range(3):
            client_id = f'client-{i}'
            connection_manager.add_connection(client_id, self._connection_info(
                client_id,
                session_id=f'session-{i}',
                user_id=f'user-{i % 2}',  # 2 unique users
                language='python' if i % 2 == 0 else 'javascript',
            ))

        stats = connection_manager.get_connection_stats()
        assert stats['total_connections'] == 3
        assert stats['unique_sessions'] == 3
        assert stats['unique_users'] == 2
        assert 'python' in stats['languages']
        assert 'javascript' in stats['languages']
if __name__ == '__main__':
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (pytest.main returns the exit code).
    raise SystemExit(pytest.main([__file__]))