# scratch_chat/tests/performance/test_load_testing.py
# (Repository page header: WebashalarForML - "Upload 178 files" - commit 330b6e4, verified)
"""
Load testing for multiple concurrent chat sessions.
Tests system performance under various load conditions.
"""
import pytest
import asyncio
import time
import threading
import concurrent.futures
from unittest.mock import patch, MagicMock
import statistics
from chat_agent.services.chat_agent import ChatAgent
from chat_agent.services.session_manager import SessionManager
from chat_agent.services.language_context import LanguageContextManager
from chat_agent.services.chat_history import ChatHistoryManager
class TestConcurrentChatSessions:
    """Load testing for concurrent chat sessions.

    Every test drives ``simulate_user_session`` from a thread pool against a
    chat system whose Groq client is replaced by a mock, then asserts on the
    aggregated success, error and latency figures.
    """

    @pytest.fixture
    def mock_groq_client(self):
        """Mock Groq client with realistic response times.

        ``GroqClient`` stays patched for the lifetime of the fixture. The
        yielded mock's ``generate_response`` sleeps 100-200ms per call and
        then returns a canned string.
        """
        with patch('chat_agent.services.groq_client.GroqClient') as mock:
            mock_instance = MagicMock()

            def mock_generate_response(*args, **kwargs):
                # Simulate realistic API response time; the fractional part
                # of the wall clock gives cheap per-call jitter.
                time.sleep(0.1 + (time.time() % 0.1))  # 100-200ms
                return f"Test response for concurrent user at {time.time()}"

            mock_instance.generate_response.side_effect = mock_generate_response
            mock.return_value = mock_instance
            yield mock_instance

    @pytest.fixture
    def chat_system(self, mock_groq_client):
        """Create complete chat system for load testing.

        Returns:
            dict with the keys ``'chat_agent'``, ``'session_manager'``,
            ``'language_context_manager'`` and ``'chat_history_manager'``;
            the agent is wired to the mocked Groq client.
        """
        session_manager = SessionManager()
        language_context_manager = LanguageContextManager()
        chat_history_manager = ChatHistoryManager()
        chat_agent = ChatAgent(
            groq_client=mock_groq_client,
            session_manager=session_manager,
            language_context_manager=language_context_manager,
            chat_history_manager=chat_history_manager
        )
        return {
            'chat_agent': chat_agent,
            'session_manager': session_manager,
            'language_context_manager': language_context_manager,
            'chat_history_manager': chat_history_manager
        }

    def simulate_user_session(self, user_id, chat_system, num_messages=5):
        """Simulate a complete user session with multiple messages.

        Args:
            user_id: Identifier passed to ``create_session``.
            chat_system: Mapping shaped like the ``chat_system`` fixture; the
                ``'session_manager'`` and ``'chat_agent'`` entries are used.
            num_messages: How many of the canned prompts to send. Values
                above the number of prompts (5) are capped; values below 1
                send nothing.

        Returns:
            dict with ``user_id``, ``session_id``, ``messages_sent``,
            ``responses_received``, ``errors``, ``response_times`` (seconds
            per processed message), ``total_time`` (seconds) and ``success``.
            ``messages_sent`` only counts messages whose processing returned
            without raising. ``success`` is True only when the session was
            created and no message raised.
        """
        results = {
            'user_id': user_id,
            'session_id': None,
            'messages_sent': 0,
            'responses_received': 0,
            'errors': 0,
            'response_times': [],
            'total_time': 0,
            'success': False
        }
        # perf_counter is monotonic, so measured durations cannot be skewed
        # by system clock adjustments the way time.time() differences can.
        start_time = time.perf_counter()
        try:
            # Create session
            session = chat_system['session_manager'].create_session(
                user_id, language="python"
            )
            results['session_id'] = session['session_id']
            # Send multiple messages
            messages = [
                "What is Python?",
                "How do I create a list?",
                "Explain functions",
                "What are loops?",
                "How to handle errors?"
            ]
            # max(..., 0) keeps a negative count from turning into a
            # "drop the last k prompts" slice.
            for i, message in enumerate(messages[:max(num_messages, 0)]):
                message_start = time.perf_counter()
                try:
                    response = chat_system['chat_agent'].process_message(
                        session_id=session['session_id'],
                        message=message,
                        language="python"
                    )
                except Exception as e:
                    # A failing message is recorded but must not abort the
                    # remaining messages of this simulated user.
                    results['errors'] += 1
                    print(f"Error in user {user_id} message {i}: {e}")
                else:
                    results['response_times'].append(
                        time.perf_counter() - message_start
                    )
                    results['messages_sent'] += 1
                    if response:
                        results['responses_received'] += 1
            results['success'] = results['errors'] == 0
        except Exception as e:
            results['errors'] += 1
            print(f"Error creating session for user {user_id}: {e}")
        results['total_time'] = time.perf_counter() - start_time
        return results

    def _run_concurrent_sessions(self, chat_system, user_ids, messages_per_user):
        """Run one simulated session per user ID, all concurrently.

        Returns:
            ``(results, elapsed_seconds)`` where ``results`` holds the
            per-session dicts in completion order.
        """
        user_ids = list(user_ids)
        started = time.perf_counter()
        # One worker per user so every session really runs at the same time;
        # ThreadPoolExecutor rejects max_workers=0, hence the floor of 1.
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=max(len(user_ids), 1)
        ) as executor:
            futures = [
                executor.submit(
                    self.simulate_user_session,
                    user_id,
                    chat_system,
                    messages_per_user
                )
                for user_id in user_ids
            ]
            results = [
                future.result()
                for future in concurrent.futures.as_completed(futures)
            ]
        return results, time.perf_counter() - started

    @staticmethod
    def _summarize(results):
        """Aggregate per-session result dicts into totals for assertions."""
        return {
            'successful_sessions': [r for r in results if r['success']],
            'failed_sessions': [r for r in results if not r['success']],
            'total_messages': sum(r['messages_sent'] for r in results),
            'total_responses': sum(r['responses_received'] for r in results),
            'total_errors': sum(r['errors'] for r in results),
            'response_times': [
                t for r in results for t in r['response_times']
            ],
        }

    def test_concurrent_users_light_load(self, chat_system):
        """Test with 10 concurrent users (light load)."""
        num_users = 10
        messages_per_user = 3
        # Create user IDs
        user_ids = [f"load-test-user-{i}" for i in range(num_users)]
        # Run concurrent sessions
        results, total_time = self._run_concurrent_sessions(
            chat_system, user_ids, messages_per_user
        )
        # Analyze results
        summary = self._summarize(results)
        successful_sessions = summary['successful_sessions']
        failed_sessions = summary['failed_sessions']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']
        all_response_times = summary['response_times']
        # Assertions for light load
        assert len(successful_sessions) >= 8, f"Expected at least 8 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 2, f"Expected at most 2 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.8, "Expected at least 80% response rate"
        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 1.0, f"Average response time too high: {avg_response_time}s"
        print("Light Load Test Results:")
        print(f" Users: {num_users}")
        print(f" Successful sessions: {len(successful_sessions)}")
        print(f" Failed sessions: {len(failed_sessions)}")
        print(f" Total messages: {total_messages}")
        print(f" Total responses: {total_responses}")
        print(f" Total errors: {total_errors}")
        print(f" Total time: {total_time:.2f}s")
        if all_response_times:
            print(f" Avg response time: {avg_response_time:.3f}s")
            print(f" Max response time: {max(all_response_times):.3f}s")

    def test_concurrent_users_medium_load(self, chat_system):
        """Test with 25 concurrent users (medium load)."""
        num_users = 25
        messages_per_user = 4
        user_ids = [f"medium-load-user-{i}" for i in range(num_users)]
        results, total_time = self._run_concurrent_sessions(
            chat_system, user_ids, messages_per_user
        )
        # Analyze results
        summary = self._summarize(results)
        successful_sessions = summary['successful_sessions']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']
        all_response_times = summary['response_times']
        # Assertions for medium load (more lenient)
        assert len(successful_sessions) >= 20, f"Expected at least 20 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 10, f"Expected at most 10 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.7, "Expected at least 70% response rate"
        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 2.0, f"Average response time too high: {avg_response_time}s"
        print("Medium Load Test Results:")
        print(f" Users: {num_users}")
        print(f" Successful sessions: {len(successful_sessions)}")
        print(f" Total messages: {total_messages}")
        print(f" Total responses: {total_responses}")
        print(f" Total errors: {total_errors}")
        print(f" Total time: {total_time:.2f}s")
        if all_response_times:
            print(f" Avg response time: {avg_response_time:.3f}s")

    def test_concurrent_users_heavy_load(self, chat_system):
        """Test with 50 concurrent users (heavy load)."""
        num_users = 50
        messages_per_user = 3
        user_ids = [f"heavy-load-user-{i}" for i in range(num_users)]
        results, total_time = self._run_concurrent_sessions(
            chat_system, user_ids, messages_per_user
        )
        # Analyze results
        summary = self._summarize(results)
        successful_sessions = summary['successful_sessions']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']
        all_response_times = summary['response_times']
        # Assertions for heavy load (most lenient)
        assert len(successful_sessions) >= 35, f"Expected at least 35 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 25, f"Expected at most 25 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.6, "Expected at least 60% response rate"
        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 5.0, f"Average response time too high: {avg_response_time}s"
        print("Heavy Load Test Results:")
        print(f" Users: {num_users}")
        print(f" Successful sessions: {len(successful_sessions)}")
        print(f" Total messages: {total_messages}")
        print(f" Total responses: {total_responses}")
        print(f" Total errors: {total_errors}")
        print(f" Total time: {total_time:.2f}s")
        if all_response_times:
            print(f" Avg response time: {avg_response_time:.3f}s")

    def test_sustained_load(self, chat_system):
        """Test sustained load over time.

        Launches a wave of concurrent users every ``wave_interval`` seconds
        until ``duration_seconds`` have elapsed, then checks the overall
        success and response rates.
        """
        duration_seconds = 30  # 30 second test
        users_per_wave = 5
        wave_interval = 2  # New wave every 2 seconds
        results = []
        start_time = time.perf_counter()
        wave_count = 0
        while time.perf_counter() - start_time < duration_seconds:
            wave_start = time.perf_counter()
            wave_count += 1
            # Create user IDs for this wave
            user_ids = [
                f"sustained-wave-{wave_count}-user-{i}"
                for i in range(users_per_wave)
            ]
            # Launch concurrent sessions for this wave (2 messages per user)
            wave_results, _ = self._run_concurrent_sessions(
                chat_system, user_ids, 2
            )
            results.extend(wave_results)
            # Wait for next wave; a wave that overran its slot is followed
            # immediately by the next one.
            elapsed = time.perf_counter() - wave_start
            if elapsed < wave_interval:
                time.sleep(wave_interval - elapsed)
        total_time = time.perf_counter() - start_time
        # Analyze sustained load results
        summary = self._summarize(results)
        successful_sessions = summary['successful_sessions']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']
        # Assertions for sustained load
        success_rate = len(successful_sessions) / len(results) if results else 0
        response_rate = total_responses / total_messages if total_messages > 0 else 0
        assert success_rate >= 0.7, f"Expected at least 70% success rate, got {success_rate:.2%}"
        assert response_rate >= 0.6, f"Expected at least 60% response rate, got {response_rate:.2%}"
        print("Sustained Load Test Results:")
        print(f" Duration: {total_time:.1f}s")
        print(f" Waves: {wave_count}")
        print(f" Total users: {len(results)}")
        print(f" Successful sessions: {len(successful_sessions)} ({success_rate:.1%})")
        print(f" Total messages: {total_messages}")
        print(f" Total responses: {total_responses} ({response_rate:.1%})")
        print(f" Total errors: {total_errors}")

    def test_memory_usage_under_load(self, chat_system):
        """Test memory usage during concurrent sessions.

        Compares the process RSS before and after 20 concurrent sessions.
        Skipped when ``psutil`` is not installed.
        """
        import os

        # Skip cleanly instead of erroring when the optional dependency is
        # missing from the environment.
        psutil = pytest.importorskip("psutil")
        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB
        num_users = 20
        messages_per_user = 5
        user_ids = [f"memory-test-user-{i}" for i in range(num_users)]
        # Run concurrent sessions
        results, _ = self._run_concurrent_sessions(
            chat_system, user_ids, messages_per_user
        )
        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory
        # Memory usage assertions
        assert memory_increase < 100, f"Memory increase too high: {memory_increase:.1f}MB"
        successful_sessions = [r for r in results if r['success']]
        assert len(successful_sessions) >= 15, "Expected at least 15 successful sessions"
        print("Memory Usage Test Results:")
        print(f" Initial memory: {initial_memory:.1f}MB")
        print(f" Final memory: {final_memory:.1f}MB")
        print(f" Memory increase: {memory_increase:.1f}MB")
        print(f" Successful sessions: {len(successful_sessions)}")
if __name__ == '__main__':
    # Propagate pytest's exit status so a failing run is visible to the
    # calling shell or CI job instead of the script always exiting with 0.
    raise SystemExit(pytest.main([__file__, '-v', '-s']))