| """ | |
| Performance tests for concurrent user scenarios and load testing. | |
| This module tests the chat agent's performance under various load conditions | |
| including multiple concurrent users, high message throughput, and stress testing. | |
| """ | |
import asyncio
import time
import threading
import statistics
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Tuple

import pytest
import requests
from unittest.mock import Mock, patch

from chat_agent.services.chat_agent import ChatAgent
from chat_agent.services.session_manager import SessionManager
from chat_agent.services.chat_history import ChatHistoryManager
from chat_agent.services.cache_service import CacheService
from chat_agent.utils.connection_pool import ConnectionPoolManager

class PerformanceMetrics:
    """Collects and analyzes performance metrics."""

    def __init__(self):
        self.response_times = []
        self.error_count = 0
        self.success_count = 0
        self.start_time = None
        self.end_time = None
        self.concurrent_users = 0
        self.messages_per_second = 0

    def add_response_time(self, response_time: float):
        """Add response time measurement."""
        self.response_times.append(response_time)

    def add_success(self):
        """Record successful operation."""
        self.success_count += 1

    def add_error(self):
        """Record failed operation."""
        self.error_count += 1

    def start_timing(self):
        """Start timing measurement."""
        self.start_time = time.time()

    def end_timing(self):
        """End timing measurement."""
        self.end_time = time.time()

    def get_statistics(self) -> Dict[str, Any]:
        """Get performance statistics."""
        if not self.response_times:
            return {
                'total_requests': 0,
                'success_rate': 0,
                'error_rate': 0
            }

        total_time = self.end_time - self.start_time if self.end_time and self.start_time else 0
        total_requests = self.success_count + self.error_count

        return {
            'total_requests': total_requests,
            'success_count': self.success_count,
            'error_count': self.error_count,
            'success_rate': (self.success_count / total_requests * 100) if total_requests > 0 else 0,
            'error_rate': (self.error_count / total_requests * 100) if total_requests > 0 else 0,
            'avg_response_time': statistics.mean(self.response_times),
            'median_response_time': statistics.median(self.response_times),
            'min_response_time': min(self.response_times),
            'max_response_time': max(self.response_times),
            'p95_response_time': self._percentile(self.response_times, 95),
            'p99_response_time': self._percentile(self.response_times, 99),
            'total_duration': total_time,
            'requests_per_second': total_requests / total_time if total_time > 0 else 0,
            'concurrent_users': self.concurrent_users
        }

    def _percentile(self, data: List[float], percentile: int) -> float:
        """Calculate percentile of response times."""
        if not data:
            return 0
        sorted_data = sorted(data)
        index = int(len(sorted_data) * percentile / 100)
        return sorted_data[min(index, len(sorted_data) - 1)]
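
# Illustrative sketch (not part of the original suite): a minimal, self-contained use of
# PerformanceMetrics. The leading underscore keeps pytest from collecting it as a test,
# and the simulated latencies are arbitrary stand-ins for real request handling.
def _example_metrics_usage() -> Dict[str, Any]:
    metrics = PerformanceMetrics()
    metrics.start_timing()
    for simulated_latency in (0.01, 0.02, 0.03):
        time.sleep(simulated_latency)  # stand-in for actual work
        metrics.add_response_time(simulated_latency)
        metrics.add_success()
    metrics.end_timing()
    # Returns aggregates such as avg/median/p95 response times and requests_per_second.
    return metrics.get_statistics()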

class ConcurrentUserSimulator:
    """Simulates concurrent users for load testing."""

    def __init__(self, base_url: str = "http://localhost:5000"):
        self.base_url = base_url
        self.session = requests.Session()
        self.metrics = PerformanceMetrics()

    def simulate_user_session(self, user_id: str, num_messages: int = 10) -> Dict[str, Any]:
        """
        Simulate a single user chat session.

        Args:
            user_id: Unique user identifier
            num_messages: Number of messages to send

        Returns:
            Session metrics
        """
        session_metrics = {
            'user_id': user_id,
            'messages_sent': 0,
            'messages_failed': 0,
            'response_times': [],
            'session_duration': 0
        }

        session_start = time.time()

        try:
            # Create session
            session_data = {
                'language': 'python',
                'metadata': {'test_user': user_id}
            }

            start_time = time.time()
            response = self.session.post(
                f"{self.base_url}/api/v1/chat/sessions",
                json=session_data,
                headers={'Authorization': f'Bearer test-token-{user_id}'},
                timeout=30
            )
            response_time = time.time() - start_time

            if response.status_code != 201:
                session_metrics['session_creation_failed'] = True
                return session_metrics

            session_id = response.json()['session_id']

            # Send messages
            for i in range(num_messages):
                message_data = {
                    'content': f'Test message {i+1} from user {user_id}',
                    'language': 'python'
                }

                start_time = time.time()
                try:
                    response = self.session.post(
                        f"{self.base_url}/api/v1/chat/sessions/{session_id}/messages",
                        json=message_data,
                        headers={'Authorization': f'Bearer test-token-{user_id}'},
                        timeout=30
                    )
                    response_time = time.time() - start_time

                    if response.status_code == 200:
                        session_metrics['messages_sent'] += 1
                        session_metrics['response_times'].append(response_time)
                        self.metrics.add_response_time(response_time)
                        self.metrics.add_success()
                    else:
                        session_metrics['messages_failed'] += 1
                        self.metrics.add_error()

                except requests.RequestException as e:
                    session_metrics['messages_failed'] += 1
                    self.metrics.add_error()

                # Small delay between messages
                time.sleep(0.1)

            # Clean up session
            self.session.delete(
                f"{self.base_url}/api/v1/chat/sessions/{session_id}",
                headers={'Authorization': f'Bearer test-token-{user_id}'}
            )

        except Exception as e:
            session_metrics['session_error'] = str(e)

        session_metrics['session_duration'] = time.time() - session_start
        return session_metrics

    def run_concurrent_test(self, num_users: int, messages_per_user: int = 10) -> Dict[str, Any]:
        """
        Run concurrent user test.

        Args:
            num_users: Number of concurrent users
            messages_per_user: Messages per user

        Returns:
            Test results and metrics
        """
        self.metrics = PerformanceMetrics()
        self.metrics.concurrent_users = num_users
        self.metrics.start_timing()

        user_sessions = []

        # Use ThreadPoolExecutor for concurrent execution
        with ThreadPoolExecutor(max_workers=min(num_users, 50)) as executor:
            # Submit all user sessions
            futures = []
            for i in range(num_users):
                user_id = f"test_user_{i}"
                future = executor.submit(self.simulate_user_session, user_id, messages_per_user)
                futures.append(future)

            # Collect results
            for future in as_completed(futures):
                try:
                    session_result = future.result(timeout=120)  # 2 minute timeout per session
                    user_sessions.append(session_result)
                except Exception as e:
                    user_sessions.append({'error': str(e)})
                    self.metrics.add_error()

        self.metrics.end_timing()

        # Analyze results
        successful_sessions = [s for s in user_sessions if 'error' not in s and not s.get('session_creation_failed')]
        failed_sessions = len(user_sessions) - len(successful_sessions)

        total_messages_sent = sum(s.get('messages_sent', 0) for s in successful_sessions)
        total_messages_failed = sum(s.get('messages_failed', 0) for s in successful_sessions)

        return {
            'test_config': {
                'concurrent_users': num_users,
                'messages_per_user': messages_per_user,
                'total_expected_messages': num_users * messages_per_user
            },
            'session_results': {
                'successful_sessions': len(successful_sessions),
                'failed_sessions': failed_sessions,
                'session_success_rate': len(successful_sessions) / num_users * 100
            },
            'message_results': {
                'total_messages_sent': total_messages_sent,
                'total_messages_failed': total_messages_failed,
                'message_success_rate': total_messages_sent / (total_messages_sent + total_messages_failed) * 100 if (total_messages_sent + total_messages_failed) > 0 else 0
            },
            'performance_metrics': self.metrics.get_statistics(),
            'user_sessions': user_sessions
        }
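
# Illustrative usage (assumes a chat-agent instance is already serving the /api/v1/chat
# endpoints exercised by simulate_user_session above; numbers are arbitrary):
#
#   simulator = ConcurrentUserSimulator(base_url="http://localhost:5000")
#   results = simulator.run_concurrent_test(num_users=25, messages_per_user=10)
#   print(results['performance_metrics']['p95_response_time'])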

@pytest.fixture
def performance_metrics():
    """Fixture for performance metrics."""
    return PerformanceMetrics()


@pytest.fixture
def mock_services():
    """Fixture for mocked services."""
    with patch('redis.Redis') as mock_redis:
        mock_redis_client = Mock()
        mock_redis_client.ping.return_value = True
        mock_redis.return_value = mock_redis_client

        session_manager = Mock(spec=SessionManager)
        chat_history_manager = Mock(spec=ChatHistoryManager)
        cache_service = Mock(spec=CacheService)

        yield {
            'redis_client': mock_redis_client,
            'session_manager': session_manager,
            'chat_history_manager': chat_history_manager,
            'cache_service': cache_service
        }

class TestConcurrentUsers:
    """Test concurrent user scenarios."""

    def test_single_user_performance(self, mock_services, performance_metrics):
        """Test single user performance baseline."""
        # Simulate single user with multiple messages
        num_messages = 50

        performance_metrics.start_timing()

        for i in range(num_messages):
            start_time = time.time()
            # Simulate message processing
            time.sleep(0.01)  # Simulate processing time
            response_time = time.time() - start_time

            performance_metrics.add_response_time(response_time)
            performance_metrics.add_success()

        performance_metrics.end_timing()

        stats = performance_metrics.get_statistics()

        # Assertions for single user performance
        assert stats['success_count'] == num_messages
        assert stats['error_count'] == 0
        assert stats['success_rate'] == 100.0
        assert stats['avg_response_time'] < 0.1  # Should be fast for single user

    def test_concurrent_session_creation(self, mock_services):
        """Test concurrent session creation performance."""
        num_concurrent_sessions = 20

        def create_session(user_id: str) -> Tuple[str, float]:
            start_time = time.time()
            # Mock session creation
            session_id = f"session_{user_id}_{int(time.time())}"
            time.sleep(0.05)  # Simulate database operation
            duration = time.time() - start_time
            return session_id, duration

        # Test concurrent session creation
        with ThreadPoolExecutor(max_workers=num_concurrent_sessions) as executor:
            futures = []
            start_time = time.time()

            for i in range(num_concurrent_sessions):
                future = executor.submit(create_session, f"user_{i}")
                futures.append(future)

            results = []
            for future in as_completed(futures):
                session_id, duration = future.result()
                results.append((session_id, duration))

            total_time = time.time() - start_time

        # Analyze results
        assert len(results) == num_concurrent_sessions

        durations = [duration for _, duration in results]
        avg_duration = sum(durations) / len(durations)
        max_duration = max(durations)

        # Performance assertions
        assert avg_duration < 0.2  # Average session creation should be fast
        assert max_duration < 0.5  # Even slowest should be reasonable
        assert total_time < 2.0  # Total time should be much less than sequential

    def test_concurrent_message_processing(self, mock_services):
        """Test concurrent message processing performance."""
        num_concurrent_messages = 30

        def process_message(message_id: str) -> Tuple[str, float, bool]:
            start_time = time.time()
            try:
                # Simulate message processing with some variability
                processing_time = 0.02 + (hash(message_id) % 100) / 10000  # 0.02-0.03 seconds
                time.sleep(processing_time)
                duration = time.time() - start_time
                return message_id, duration, True
            except Exception as e:
                duration = time.time() - start_time
                return message_id, duration, False

        # Test concurrent message processing
        with ThreadPoolExecutor(max_workers=15) as executor:
            futures = []
            start_time = time.time()

            for i in range(num_concurrent_messages):
                future = executor.submit(process_message, f"msg_{i}")
                futures.append(future)

            results = []
            for future in as_completed(futures):
                message_id, duration, success = future.result()
                results.append((message_id, duration, success))

            total_time = time.time() - start_time

        # Analyze results
        successful_results = [r for r in results if r[2]]
        failed_results = [r for r in results if not r[2]]

        assert len(results) == num_concurrent_messages
        assert len(successful_results) == num_concurrent_messages  # All should succeed
        assert len(failed_results) == 0

        durations = [duration for _, duration, _ in successful_results]
        avg_duration = sum(durations) / len(durations)

        # Performance assertions
        assert avg_duration < 0.2  # Average processing should be reasonable
        assert total_time < 5.0  # Total time should be much less than sequential

    def test_memory_usage_under_load(self, mock_services):
        """Test memory usage under concurrent load."""
        import psutil
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        # Simulate high load scenario
        num_sessions = 50
        messages_per_session = 20

        # Create mock data structures to simulate memory usage
        sessions = {}
        messages = {}

        for session_id in range(num_sessions):
            sessions[f"session_{session_id}"] = {
                'user_id': f"user_{session_id}",
                'language': 'python',
                'created_at': time.time(),
                'messages': []
            }

            for msg_id in range(messages_per_session):
                message_key = f"session_{session_id}_msg_{msg_id}"
                messages[message_key] = {
                    'content': f"Test message {msg_id} " * 50,  # Larger message
                    'timestamp': time.time(),
                    'metadata': {'test': True}
                }
                sessions[f"session_{session_id}"]['messages'].append(message_key)

        peak_memory = process.memory_info().rss / 1024 / 1024  # MB

        # Clean up
        del sessions
        del messages

        final_memory = process.memory_info().rss / 1024 / 1024  # MB

        memory_increase = peak_memory - initial_memory
        memory_cleanup = peak_memory - final_memory

        # Memory usage assertions
        assert memory_increase < 100  # Should not use more than 100MB for test data
        assert memory_cleanup > 0  # Memory should be freed after cleanup
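
    # Note: psutil is a third-party dependency. If it may be absent in some environments,
    # the test above could skip itself instead of failing at import time; a minimal sketch
    # (assumption, not part of the original test):
    #
    #   psutil = pytest.importorskip("psutil")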

    def test_database_connection_pool_performance(self, mock_services):
        """Test database connection pool under concurrent load."""
        from chat_agent.utils.connection_pool import DatabaseConnectionPool

        # Mock database URL
        database_url = "sqlite:///:memory:"
        pool = DatabaseConnectionPool(database_url, pool_size=5, max_overflow=10)

        def execute_query(query_id: str) -> Tuple[str, float, bool]:
            start_time = time.time()
            try:
                with pool.get_connection() as conn:
                    # Simulate database query
                    time.sleep(0.01)
                    result = conn.execute("SELECT 1").fetchone()
                    duration = time.time() - start_time
                    return query_id, duration, True
            except Exception as e:
                duration = time.time() - start_time
                return query_id, duration, False

        # Test concurrent database access
        num_concurrent_queries = 25

        with ThreadPoolExecutor(max_workers=15) as executor:
            futures = []
            for i in range(num_concurrent_queries):
                future = executor.submit(execute_query, f"query_{i}")
                futures.append(future)

            results = []
            for future in as_completed(futures):
                query_id, duration, success = future.result()
                results.append((query_id, duration, success))

        # Analyze results
        successful_queries = [r for r in results if r[2]]
        failed_queries = [r for r in results if not r[2]]

        assert len(successful_queries) == num_concurrent_queries
        assert len(failed_queries) == 0

        durations = [duration for _, duration, _ in successful_queries]
        avg_duration = sum(durations) / len(durations)
        max_duration = max(durations)

        # Performance assertions
        assert avg_duration < 0.1  # Database queries should be fast
        assert max_duration < 0.5  # Even with connection pool contention

        # Check pool status
        pool_status = pool.get_pool_status()
        assert pool_status['pool_size'] >= 0
        assert pool_status['checked_out'] >= 0

    def test_redis_connection_pool_performance(self, mock_services):
        """Test Redis connection pool under concurrent load."""
        # This test would require actual Redis connection
        # For now, we'll test the mock behavior
        redis_client = mock_services['redis_client']

        def redis_operation(operation_id: str) -> Tuple[str, float, bool]:
            start_time = time.time()
            try:
                # Simulate Redis operations
                redis_client.set(f"key_{operation_id}", f"value_{operation_id}")
                value = redis_client.get(f"key_{operation_id}")
                redis_client.delete(f"key_{operation_id}")
                duration = time.time() - start_time
                return operation_id, duration, True
            except Exception as e:
                duration = time.time() - start_time
                return operation_id, duration, False

        # Test concurrent Redis operations
        num_concurrent_ops = 30

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for i in range(num_concurrent_ops):
                future = executor.submit(redis_operation, f"op_{i}")
                futures.append(future)

            results = []
            for future in as_completed(futures):
                op_id, duration, success = future.result()
                results.append((op_id, duration, success))

        # Analyze results
        successful_ops = [r for r in results if r[2]]
        assert len(successful_ops) == num_concurrent_ops

        durations = [duration for _, duration, _ in successful_ops]
        avg_duration = sum(durations) / len(durations)

        # Performance assertions (for mocked Redis)
        assert avg_duration < 0.01  # Mocked operations should be very fast

class TestLoadTesting:
    """Load testing scenarios."""

    def test_sustained_load(self, mock_services):
        """Test sustained load over time."""
        duration_seconds = 30  # 30 second test
        requests_per_second = 10

        metrics = PerformanceMetrics()
        metrics.start_timing()

        def generate_load():
            end_time = time.time() + duration_seconds
            request_count = 0

            while time.time() < end_time:
                start_time = time.time()

                # Simulate request processing
                time.sleep(0.01)  # Simulate work

                response_time = time.time() - start_time
                metrics.add_response_time(response_time)
                metrics.add_success()
                request_count += 1

                # Control request rate
                elapsed = time.time() - start_time
                sleep_time = (1.0 / requests_per_second) - elapsed
                if sleep_time > 0:
                    time.sleep(sleep_time)

        # Run load test
        generate_load()
        metrics.end_timing()

        stats = metrics.get_statistics()

        # Assertions for sustained load
        expected_requests = duration_seconds * requests_per_second
        assert stats['total_requests'] >= expected_requests * 0.9  # Allow 10% variance
        assert stats['success_rate'] >= 95.0  # 95% success rate minimum
        assert stats['avg_response_time'] < 0.1  # Average response time

    def test_spike_load(self, mock_services):
        """Test handling of sudden load spikes."""
        normal_load_rps = 5
        spike_load_rps = 50
        spike_duration = 10  # seconds

        metrics = PerformanceMetrics()
        metrics.start_timing()

        def simulate_spike():
            # Normal load for 5 seconds
            end_normal = time.time() + 5
            while time.time() < end_normal:
                start_time = time.time()
                time.sleep(0.01)
                response_time = time.time() - start_time
                metrics.add_response_time(response_time)
                metrics.add_success()
                time.sleep(1.0 / normal_load_rps - 0.01)

            # Spike load for 10 seconds
            end_spike = time.time() + spike_duration
            while time.time() < end_spike:
                start_time = time.time()
                time.sleep(0.01)
                response_time = time.time() - start_time
                metrics.add_response_time(response_time)

                if response_time < 0.5:  # Consider successful if under 500ms
                    metrics.add_success()
                else:
                    metrics.add_error()

                time.sleep(max(0, 1.0 / spike_load_rps - 0.01))

            # Return to normal load for 5 seconds
            end_normal2 = time.time() + 5
            while time.time() < end_normal2:
                start_time = time.time()
                time.sleep(0.01)
                response_time = time.time() - start_time
                metrics.add_response_time(response_time)
                metrics.add_success()
                time.sleep(1.0 / normal_load_rps - 0.01)

        simulate_spike()
        metrics.end_timing()

        stats = metrics.get_statistics()

        # Assertions for spike handling
        assert stats['success_rate'] >= 80.0  # Should handle most requests even during spike
        assert stats['p95_response_time'] < 1.0  # 95th percentile should be reasonable
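
# Note: these load tests are wall-clock driven (test_sustained_load alone runs for roughly
# 30 seconds). In CI they would typically sit behind a marker so they only run when asked
# for explicitly; a minimal sketch, assuming a "performance" marker registered in the
# pytest configuration:
#
#   @pytest.mark.performance
#   class TestLoadTesting:
#       ...
#
#   # selected with: pytest -m performance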

if __name__ == "__main__":
    # Run a simple load test
    simulator = ConcurrentUserSimulator()

    print("Running concurrent user test...")
    results = simulator.run_concurrent_test(num_users=10, messages_per_user=5)

    print("\nTest Results:")
    print(f"Concurrent Users: {results['test_config']['concurrent_users']}")
    print(f"Messages per User: {results['test_config']['messages_per_user']}")
    print(f"Session Success Rate: {results['session_results']['session_success_rate']:.1f}%")
    print(f"Message Success Rate: {results['message_results']['message_success_rate']:.1f}%")
    print(f"Average Response Time: {results['performance_metrics']['avg_response_time']:.3f}s")
    print(f"95th Percentile Response Time: {results['performance_metrics']['p95_response_time']:.3f}s")
    print(f"Requests per Second: {results['performance_metrics']['requests_per_second']:.1f}")