| """ | |
| Load testing for multiple concurrent chat sessions. | |
| Tests system performance under various load conditions. | |
| """ | |
import asyncio
import concurrent.futures
import statistics
import threading
import time
from unittest.mock import MagicMock, patch

import pytest

from chat_agent.services.chat_agent import ChatAgent
from chat_agent.services.chat_history import ChatHistoryManager
from chat_agent.services.language_context import LanguageContextManager
from chat_agent.services.session_manager import SessionManager
class TestConcurrentChatSessions:
    """Load testing for concurrent chat sessions."""

    @pytest.fixture
    def mock_groq_client(self):
        """Mock Groq client with realistic response times.

        Yields a MagicMock whose ``generate_response`` sleeps 100-200 ms per
        call to approximate real API latency.

        Fix: this yield-style generator is consumed as a pytest fixture by
        ``chat_system`` and by every test method, but the ``@pytest.fixture``
        decorator was missing — without it pytest fails with a
        "fixture not found" error instead of injecting the mock.
        """
        with patch('chat_agent.services.groq_client.GroqClient') as mock:
            mock_instance = MagicMock()

            def mock_generate_response(*args, **kwargs):
                # Simulate realistic API response time: 100-200 ms per call.
                time.sleep(0.1 + (time.time() % 0.1))
                return f"Test response for concurrent user at {time.time()}"

            mock_instance.generate_response.side_effect = mock_generate_response
            mock.return_value = mock_instance
            yield mock_instance
| def chat_system(self, mock_groq_client): | |
| """Create complete chat system for load testing.""" | |
| session_manager = SessionManager() | |
| language_context_manager = LanguageContextManager() | |
| chat_history_manager = ChatHistoryManager() | |
| chat_agent = ChatAgent( | |
| groq_client=mock_groq_client, | |
| session_manager=session_manager, | |
| language_context_manager=language_context_manager, | |
| chat_history_manager=chat_history_manager | |
| ) | |
| return { | |
| 'chat_agent': chat_agent, | |
| 'session_manager': session_manager, | |
| 'language_context_manager': language_context_manager, | |
| 'chat_history_manager': chat_history_manager | |
| } | |
| def simulate_user_session(self, user_id, chat_system, num_messages=5): | |
| """Simulate a complete user session with multiple messages.""" | |
| results = { | |
| 'user_id': user_id, | |
| 'session_id': None, | |
| 'messages_sent': 0, | |
| 'responses_received': 0, | |
| 'errors': 0, | |
| 'response_times': [], | |
| 'total_time': 0, | |
| 'success': False | |
| } | |
| start_time = time.time() | |
| try: | |
| # Create session | |
| session = chat_system['session_manager'].create_session( | |
| user_id, language="python" | |
| ) | |
| results['session_id'] = session['session_id'] | |
| # Send multiple messages | |
| messages = [ | |
| "What is Python?", | |
| "How do I create a list?", | |
| "Explain functions", | |
| "What are loops?", | |
| "How to handle errors?" | |
| ] | |
| for i in range(min(num_messages, len(messages))): | |
| message_start = time.time() | |
| try: | |
| response = chat_system['chat_agent'].process_message( | |
| session_id=session['session_id'], | |
| message=messages[i], | |
| language="python" | |
| ) | |
| message_time = time.time() - message_start | |
| results['response_times'].append(message_time) | |
| results['messages_sent'] += 1 | |
| if response and len(response) > 0: | |
| results['responses_received'] += 1 | |
| except Exception as e: | |
| results['errors'] += 1 | |
| print(f"Error in user {user_id} message {i}: {e}") | |
| results['success'] = results['errors'] == 0 | |
| except Exception as e: | |
| results['errors'] += 1 | |
| print(f"Error creating session for user {user_id}: {e}") | |
| results['total_time'] = time.time() - start_time | |
| return results | |
| def test_concurrent_users_light_load(self, chat_system): | |
| """Test with 10 concurrent users (light load).""" | |
| num_users = 10 | |
| messages_per_user = 3 | |
| # Create user IDs | |
| user_ids = [f"load-test-user-{i}" for i in range(num_users)] | |
| # Run concurrent sessions | |
| start_time = time.time() | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor: | |
| futures = [ | |
| executor.submit( | |
| self.simulate_user_session, | |
| user_id, | |
| chat_system, | |
| messages_per_user | |
| ) | |
| for user_id in user_ids | |
| ] | |
| results = [future.result() for future in concurrent.futures.as_completed(futures)] | |
| total_time = time.time() - start_time | |
| # Analyze results | |
| successful_sessions = [r for r in results if r['success']] | |
| failed_sessions = [r for r in results if not r['success']] | |
| total_messages = sum(r['messages_sent'] for r in results) | |
| total_responses = sum(r['responses_received'] for r in results) | |
| total_errors = sum(r['errors'] for r in results) | |
| all_response_times = [] | |
| for r in results: | |
| all_response_times.extend(r['response_times']) | |
| # Assertions for light load | |
| assert len(successful_sessions) >= 8, f"Expected at least 8 successful sessions, got {len(successful_sessions)}" | |
| assert total_errors <= 2, f"Expected at most 2 errors, got {total_errors}" | |
| assert total_responses >= total_messages * 0.8, "Expected at least 80% response rate" | |
| if all_response_times: | |
| avg_response_time = statistics.mean(all_response_times) | |
| assert avg_response_time < 1.0, f"Average response time too high: {avg_response_time}s" | |
| print(f"Light Load Test Results:") | |
| print(f" Users: {num_users}") | |
| print(f" Successful sessions: {len(successful_sessions)}") | |
| print(f" Failed sessions: {len(failed_sessions)}") | |
| print(f" Total messages: {total_messages}") | |
| print(f" Total responses: {total_responses}") | |
| print(f" Total errors: {total_errors}") | |
| print(f" Total time: {total_time:.2f}s") | |
| if all_response_times: | |
| print(f" Avg response time: {statistics.mean(all_response_times):.3f}s") | |
| print(f" Max response time: {max(all_response_times):.3f}s") | |
| def test_concurrent_users_medium_load(self, chat_system): | |
| """Test with 25 concurrent users (medium load).""" | |
| num_users = 25 | |
| messages_per_user = 4 | |
| user_ids = [f"medium-load-user-{i}" for i in range(num_users)] | |
| start_time = time.time() | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor: | |
| futures = [ | |
| executor.submit( | |
| self.simulate_user_session, | |
| user_id, | |
| chat_system, | |
| messages_per_user | |
| ) | |
| for user_id in user_ids | |
| ] | |
| results = [future.result() for future in concurrent.futures.as_completed(futures)] | |
| total_time = time.time() - start_time | |
| # Analyze results | |
| successful_sessions = [r for r in results if r['success']] | |
| total_messages = sum(r['messages_sent'] for r in results) | |
| total_responses = sum(r['responses_received'] for r in results) | |
| total_errors = sum(r['errors'] for r in results) | |
| all_response_times = [] | |
| for r in results: | |
| all_response_times.extend(r['response_times']) | |
| # Assertions for medium load (more lenient) | |
| assert len(successful_sessions) >= 20, f"Expected at least 20 successful sessions, got {len(successful_sessions)}" | |
| assert total_errors <= 10, f"Expected at most 10 errors, got {total_errors}" | |
| assert total_responses >= total_messages * 0.7, "Expected at least 70% response rate" | |
| if all_response_times: | |
| avg_response_time = statistics.mean(all_response_times) | |
| assert avg_response_time < 2.0, f"Average response time too high: {avg_response_time}s" | |
| print(f"Medium Load Test Results:") | |
| print(f" Users: {num_users}") | |
| print(f" Successful sessions: {len(successful_sessions)}") | |
| print(f" Total messages: {total_messages}") | |
| print(f" Total responses: {total_responses}") | |
| print(f" Total errors: {total_errors}") | |
| print(f" Total time: {total_time:.2f}s") | |
| if all_response_times: | |
| print(f" Avg response time: {statistics.mean(all_response_times):.3f}s") | |
| def test_concurrent_users_heavy_load(self, chat_system): | |
| """Test with 50 concurrent users (heavy load).""" | |
| num_users = 50 | |
| messages_per_user = 3 | |
| user_ids = [f"heavy-load-user-{i}" for i in range(num_users)] | |
| start_time = time.time() | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor: | |
| futures = [ | |
| executor.submit( | |
| self.simulate_user_session, | |
| user_id, | |
| chat_system, | |
| messages_per_user | |
| ) | |
| for user_id in user_ids | |
| ] | |
| results = [future.result() for future in concurrent.futures.as_completed(futures)] | |
| total_time = time.time() - start_time | |
| # Analyze results | |
| successful_sessions = [r for r in results if r['success']] | |
| total_messages = sum(r['messages_sent'] for r in results) | |
| total_responses = sum(r['responses_received'] for r in results) | |
| total_errors = sum(r['errors'] for r in results) | |
| all_response_times = [] | |
| for r in results: | |
| all_response_times.extend(r['response_times']) | |
| # Assertions for heavy load (most lenient) | |
| assert len(successful_sessions) >= 35, f"Expected at least 35 successful sessions, got {len(successful_sessions)}" | |
| assert total_errors <= 25, f"Expected at most 25 errors, got {total_errors}" | |
| assert total_responses >= total_messages * 0.6, "Expected at least 60% response rate" | |
| if all_response_times: | |
| avg_response_time = statistics.mean(all_response_times) | |
| assert avg_response_time < 5.0, f"Average response time too high: {avg_response_time}s" | |
| print(f"Heavy Load Test Results:") | |
| print(f" Users: {num_users}") | |
| print(f" Successful sessions: {len(successful_sessions)}") | |
| print(f" Total messages: {total_messages}") | |
| print(f" Total responses: {total_responses}") | |
| print(f" Total errors: {total_errors}") | |
| print(f" Total time: {total_time:.2f}s") | |
| if all_response_times: | |
| print(f" Avg response time: {statistics.mean(all_response_times):.3f}s") | |
| def test_sustained_load(self, chat_system): | |
| """Test sustained load over time.""" | |
| duration_seconds = 30 # 30 second test | |
| users_per_wave = 5 | |
| wave_interval = 2 # New wave every 2 seconds | |
| results = [] | |
| start_time = time.time() | |
| wave_count = 0 | |
| while time.time() - start_time < duration_seconds: | |
| wave_start = time.time() | |
| wave_count += 1 | |
| # Create user IDs for this wave | |
| user_ids = [f"sustained-wave-{wave_count}-user-{i}" for i in range(users_per_wave)] | |
| # Launch concurrent sessions for this wave | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=users_per_wave) as executor: | |
| futures = [ | |
| executor.submit( | |
| self.simulate_user_session, | |
| user_id, | |
| chat_system, | |
| 2 # 2 messages per user | |
| ) | |
| for user_id in user_ids | |
| ] | |
| wave_results = [future.result() for future in concurrent.futures.as_completed(futures)] | |
| results.extend(wave_results) | |
| # Wait for next wave | |
| elapsed = time.time() - wave_start | |
| if elapsed < wave_interval: | |
| time.sleep(wave_interval - elapsed) | |
| total_time = time.time() - start_time | |
| # Analyze sustained load results | |
| successful_sessions = [r for r in results if r['success']] | |
| total_messages = sum(r['messages_sent'] for r in results) | |
| total_responses = sum(r['responses_received'] for r in results) | |
| total_errors = sum(r['errors'] for r in results) | |
| # Assertions for sustained load | |
| success_rate = len(successful_sessions) / len(results) if results else 0 | |
| response_rate = total_responses / total_messages if total_messages > 0 else 0 | |
| assert success_rate >= 0.7, f"Expected at least 70% success rate, got {success_rate:.2%}" | |
| assert response_rate >= 0.6, f"Expected at least 60% response rate, got {response_rate:.2%}" | |
| print(f"Sustained Load Test Results:") | |
| print(f" Duration: {total_time:.1f}s") | |
| print(f" Waves: {wave_count}") | |
| print(f" Total users: {len(results)}") | |
| print(f" Successful sessions: {len(successful_sessions)} ({success_rate:.1%})") | |
| print(f" Total messages: {total_messages}") | |
| print(f" Total responses: {total_responses} ({response_rate:.1%})") | |
| print(f" Total errors: {total_errors}") | |
| def test_memory_usage_under_load(self, chat_system): | |
| """Test memory usage during concurrent sessions.""" | |
| import psutil | |
| import os | |
| process = psutil.Process(os.getpid()) | |
| initial_memory = process.memory_info().rss / 1024 / 1024 # MB | |
| num_users = 20 | |
| messages_per_user = 5 | |
| user_ids = [f"memory-test-user-{i}" for i in range(num_users)] | |
| # Run concurrent sessions | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor: | |
| futures = [ | |
| executor.submit( | |
| self.simulate_user_session, | |
| user_id, | |
| chat_system, | |
| messages_per_user | |
| ) | |
| for user_id in user_ids | |
| ] | |
| results = [future.result() for future in concurrent.futures.as_completed(futures)] | |
| final_memory = process.memory_info().rss / 1024 / 1024 # MB | |
| memory_increase = final_memory - initial_memory | |
| # Memory usage assertions | |
| assert memory_increase < 100, f"Memory increase too high: {memory_increase:.1f}MB" | |
| successful_sessions = [r for r in results if r['success']] | |
| assert len(successful_sessions) >= 15, "Expected at least 15 successful sessions" | |
| print(f"Memory Usage Test Results:") | |
| print(f" Initial memory: {initial_memory:.1f}MB") | |
| print(f" Final memory: {final_memory:.1f}MB") | |
| print(f" Memory increase: {memory_increase:.1f}MB") | |
| print(f" Successful sessions: {len(successful_sessions)}") | |
if __name__ == '__main__':
    # Allow running this load-test module directly: verbose output (-v)
    # with print statements shown (-s).
    pytest.main([__file__, '-v', '-s'])