""" Load testing for multiple concurrent chat sessions. Tests system performance under various load conditions. """ import pytest import asyncio import time import threading import concurrent.futures from unittest.mock import patch, MagicMock import statistics from chat_agent.services.chat_agent import ChatAgent from chat_agent.services.session_manager import SessionManager from chat_agent.services.language_context import LanguageContextManager from chat_agent.services.chat_history import ChatHistoryManager class TestConcurrentChatSessions: """Load testing for concurrent chat sessions.""" @pytest.fixture def mock_groq_client(self): """Mock Groq client with realistic response times.""" with patch('chat_agent.services.groq_client.GroqClient') as mock: mock_instance = MagicMock() def mock_generate_response(*args, **kwargs): # Simulate realistic API response time time.sleep(0.1 + (time.time() % 0.1)) # 100-200ms return f"Test response for concurrent user at {time.time()}" mock_instance.generate_response.side_effect = mock_generate_response mock.return_value = mock_instance yield mock_instance @pytest.fixture def chat_system(self, mock_groq_client): """Create complete chat system for load testing.""" session_manager = SessionManager() language_context_manager = LanguageContextManager() chat_history_manager = ChatHistoryManager() chat_agent = ChatAgent( groq_client=mock_groq_client, session_manager=session_manager, language_context_manager=language_context_manager, chat_history_manager=chat_history_manager ) return { 'chat_agent': chat_agent, 'session_manager': session_manager, 'language_context_manager': language_context_manager, 'chat_history_manager': chat_history_manager } def simulate_user_session(self, user_id, chat_system, num_messages=5): """Simulate a complete user session with multiple messages.""" results = { 'user_id': user_id, 'session_id': None, 'messages_sent': 0, 'responses_received': 0, 'errors': 0, 'response_times': [], 'total_time': 0, 'success': False } start_time = time.time() try: # Create session session = chat_system['session_manager'].create_session( user_id, language="python" ) results['session_id'] = session['session_id'] # Send multiple messages messages = [ "What is Python?", "How do I create a list?", "Explain functions", "What are loops?", "How to handle errors?" 
            for i in range(min(num_messages, len(messages))):
                message_start = time.time()
                try:
                    response = chat_system['chat_agent'].process_message(
                        session_id=session['session_id'],
                        message=messages[i],
                        language="python"
                    )
                    message_time = time.time() - message_start
                    results['response_times'].append(message_time)
                    results['messages_sent'] += 1

                    if response and len(response) > 0:
                        results['responses_received'] += 1
                except Exception as e:
                    results['errors'] += 1
                    print(f"Error in user {user_id} message {i}: {e}")

            results['success'] = results['errors'] == 0

        except Exception as e:
            results['errors'] += 1
            print(f"Error creating session for user {user_id}: {e}")

        results['total_time'] = time.time() - start_time
        return results

    def test_concurrent_users_light_load(self, chat_system):
        """Test with 10 concurrent users (light load)."""
        num_users = 10
        messages_per_user = 3

        # Create user IDs
        user_ids = [f"load-test-user-{i}" for i in range(num_users)]

        # Run concurrent sessions
        start_time = time.time()
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor:
            futures = [
                executor.submit(
                    self.simulate_user_session, user_id, chat_system, messages_per_user
                )
                for user_id in user_ids
            ]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]
        total_time = time.time() - start_time

        # Analyze results
        successful_sessions = [r for r in results if r['success']]
        failed_sessions = [r for r in results if not r['success']]
        total_messages = sum(r['messages_sent'] for r in results)
        total_responses = sum(r['responses_received'] for r in results)
        total_errors = sum(r['errors'] for r in results)

        all_response_times = []
        for r in results:
            all_response_times.extend(r['response_times'])

        # Assertions for light load
        assert len(successful_sessions) >= 8, \
            f"Expected at least 8 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 2, f"Expected at most 2 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.8, "Expected at least 80% response rate"

        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 1.0, f"Average response time too high: {avg_response_time}s"

        print("Light Load Test Results:")
        print(f"  Users: {num_users}")
        print(f"  Successful sessions: {len(successful_sessions)}")
        print(f"  Failed sessions: {len(failed_sessions)}")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses}")
        print(f"  Total errors: {total_errors}")
        print(f"  Total time: {total_time:.2f}s")
        if all_response_times:
            print(f"  Avg response time: {statistics.mean(all_response_times):.3f}s")
            print(f"  Max response time: {max(all_response_times):.3f}s")

    def test_concurrent_users_medium_load(self, chat_system):
        """Test with 25 concurrent users (medium load)."""
        num_users = 25
        messages_per_user = 4

        user_ids = [f"medium-load-user-{i}" for i in range(num_users)]

        start_time = time.time()
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor:
            futures = [
                executor.submit(
                    self.simulate_user_session, user_id, chat_system, messages_per_user
                )
                for user_id in user_ids
            ]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]
        total_time = time.time() - start_time

        # Analyze results
        successful_sessions = [r for r in results if r['success']]
        total_messages = sum(r['messages_sent'] for r in results)
        total_responses = sum(r['responses_received'] for r in results)
        total_errors = sum(r['errors'] for r in results)

        all_response_times = []
        for r in results:
            all_response_times.extend(r['response_times'])

        # Assertions for medium load (more lenient)
        assert len(successful_sessions) >= 20, \
            f"Expected at least 20 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 10, f"Expected at most 10 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.7, "Expected at least 70% response rate"

        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 2.0, f"Average response time too high: {avg_response_time}s"

        print("Medium Load Test Results:")
        print(f"  Users: {num_users}")
        print(f"  Successful sessions: {len(successful_sessions)}")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses}")
        print(f"  Total errors: {total_errors}")
        print(f"  Total time: {total_time:.2f}s")
        if all_response_times:
            print(f"  Avg response time: {statistics.mean(all_response_times):.3f}s")

    def test_concurrent_users_heavy_load(self, chat_system):
        """Test with 50 concurrent users (heavy load)."""
        num_users = 50
        messages_per_user = 3

        user_ids = [f"heavy-load-user-{i}" for i in range(num_users)]

        start_time = time.time()
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor:
            futures = [
                executor.submit(
                    self.simulate_user_session, user_id, chat_system, messages_per_user
                )
                for user_id in user_ids
            ]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]
        total_time = time.time() - start_time

        # Analyze results
        successful_sessions = [r for r in results if r['success']]
        total_messages = sum(r['messages_sent'] for r in results)
        total_responses = sum(r['responses_received'] for r in results)
        total_errors = sum(r['errors'] for r in results)

        all_response_times = []
        for r in results:
            all_response_times.extend(r['response_times'])

        # Assertions for heavy load (most lenient)
        assert len(successful_sessions) >= 35, \
            f"Expected at least 35 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 25, f"Expected at most 25 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.6, "Expected at least 60% response rate"

        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 5.0, f"Average response time too high: {avg_response_time}s"

        print("Heavy Load Test Results:")
        print(f"  Users: {num_users}")
        print(f"  Successful sessions: {len(successful_sessions)}")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses}")
        print(f"  Total errors: {total_errors}")
        print(f"  Total time: {total_time:.2f}s")
        if all_response_times:
            print(f"  Avg response time: {statistics.mean(all_response_times):.3f}s")

    def test_sustained_load(self, chat_system):
        """Test sustained load over time."""
        duration_seconds = 30  # 30-second test
        users_per_wave = 5
        wave_interval = 2  # New wave every 2 seconds

        results = []
        start_time = time.time()
        wave_count = 0

        while time.time() - start_time < duration_seconds:
            wave_start = time.time()
            wave_count += 1

            # Create user IDs for this wave
            user_ids = [
                f"sustained-wave-{wave_count}-user-{i}"
                for i in range(users_per_wave)
            ]

            # Launch concurrent sessions for this wave
            with concurrent.futures.ThreadPoolExecutor(max_workers=users_per_wave) as executor:
                futures = [
                    executor.submit(
                        self.simulate_user_session, user_id, chat_system, 2  # 2 messages per user
                    )
                    for user_id in user_ids
                ]
                wave_results = [
                    future.result() for future in concurrent.futures.as_completed(futures)
                ]
            results.extend(wave_results)

            # Wait for next wave
            elapsed = time.time() - wave_start
            if elapsed < wave_interval:
                time.sleep(wave_interval - elapsed)

        total_time = time.time() - start_time

        # Analyze sustained load results
        successful_sessions = [r for r in results if r['success']]
        total_messages = sum(r['messages_sent'] for r in results)
        total_responses = sum(r['responses_received'] for r in results)
        total_errors = sum(r['errors'] for r in results)

        # Assertions for sustained load
        success_rate = len(successful_sessions) / len(results) if results else 0
        response_rate = total_responses / total_messages if total_messages > 0 else 0

        assert success_rate >= 0.7, f"Expected at least 70% success rate, got {success_rate:.2%}"
        assert response_rate >= 0.6, f"Expected at least 60% response rate, got {response_rate:.2%}"

        print("Sustained Load Test Results:")
        print(f"  Duration: {total_time:.1f}s")
        print(f"  Waves: {wave_count}")
        print(f"  Total users: {len(results)}")
        print(f"  Successful sessions: {len(successful_sessions)} ({success_rate:.1%})")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses} ({response_rate:.1%})")
        print(f"  Total errors: {total_errors}")

    def test_memory_usage_under_load(self, chat_system):
        """Test memory usage during concurrent sessions."""
        # psutil is an optional test dependency; skip rather than error if missing.
        psutil = pytest.importorskip("psutil")
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        num_users = 20
        messages_per_user = 5

        user_ids = [f"memory-test-user-{i}" for i in range(num_users)]

        # Run concurrent sessions
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_users) as executor:
            futures = [
                executor.submit(
                    self.simulate_user_session, user_id, chat_system, messages_per_user
                )
                for user_id in user_ids
            ]
            results = [future.result() for future in concurrent.futures.as_completed(futures)]

        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory

        # Memory usage assertions
        assert memory_increase < 100, f"Memory increase too high: {memory_increase:.1f}MB"

        successful_sessions = [r for r in results if r['success']]
        assert len(successful_sessions) >= 15, "Expected at least 15 successful sessions"

        print("Memory Usage Test Results:")
        print(f"  Initial memory: {initial_memory:.1f}MB")
        print(f"  Final memory: {final_memory:.1f}MB")
        print(f"  Memory increase: {memory_increase:.1f}MB")
        print(f"  Successful sessions: {len(successful_sessions)}")


if __name__ == '__main__':
    pytest.main([__file__, '-v', '-s'])
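

# ---------------------------------------------------------------------------
# Optional sketch, not referenced by the tests above: each load-level test
# repeats the same aggregation over the per-user result dicts, and a helper
# like the one below could centralise it. The dictionary keys mirror those
# built in TestConcurrentChatSessions.simulate_user_session; the helper name
# and the p95 cut-off are illustrative assumptions, not part of the suite.
# ---------------------------------------------------------------------------
def summarize_results(results):
    """Aggregate per-user session results into summary statistics."""
    response_times = [t for r in results for t in r['response_times']]
    total_messages = sum(r['messages_sent'] for r in results)
    total_responses = sum(r['responses_received'] for r in results)
    successful = sum(1 for r in results if r['success'])

    return {
        'sessions': len(results),
        'successful_sessions': successful,
        'total_errors': sum(r['errors'] for r in results),
        'success_rate': successful / len(results) if results else 0.0,
        'response_rate': total_responses / total_messages if total_messages else 0.0,
        'avg_response_time': statistics.mean(response_times) if response_times else 0.0,
        'max_response_time': max(response_times) if response_times else 0.0,
        # statistics.quantiles needs a reasonable sample; fall back to None otherwise.
        'p95_response_time': (
            statistics.quantiles(response_times, n=20)[-1]
            if len(response_times) >= 20 else None
        ),
    }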