File size: 16,260 Bytes
330b6e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
"""

Load testing for multiple concurrent chat sessions.

Tests system performance under various load conditions.

"""

import pytest
import asyncio
import time
import threading
import concurrent.futures
from unittest.mock import patch, MagicMock
import statistics
from chat_agent.services.chat_agent import ChatAgent
from chat_agent.services.session_manager import SessionManager
from chat_agent.services.language_context import LanguageContextManager
from chat_agent.services.chat_history import ChatHistoryManager


class TestConcurrentChatSessions:
    """Load testing for concurrent chat sessions.

    Each test drives ``SessionManager`` / ``ChatAgent`` from many worker
    threads. The Groq client is replaced by a mock that sleeps 100-200ms per
    call to imitate API latency. Every simulated user is run by
    ``simulate_user_session``; the per-user result dicts are then aggregated
    and checked against load-specific thresholds.
    """

    # Canned prompts sent, in order, by every simulated user. A session sends
    # at most len(USER_MESSAGES) messages.
    USER_MESSAGES = (
        "What is Python?",
        "How do I create a list?",
        "Explain functions",
        "What are loops?",
        "How to handle errors?",
    )

    @pytest.fixture
    def mock_groq_client(self):
        """Patch ``GroqClient`` and yield a mock instance with realistic latency.

        ``generate_response`` sleeps 100-200ms (jitter derived from the wall
        clock) and returns a short string, so no real Groq API call is made
        through this client.
        """
        with patch('chat_agent.services.groq_client.GroqClient') as mock:
            mock_instance = MagicMock()

            def mock_generate_response(*args, **kwargs):
                # Simulate realistic API response time: 100-200ms.
                time.sleep(0.1 + (time.time() % 0.1))
                return f"Test response for concurrent user at {time.time()}"

            mock_instance.generate_response.side_effect = mock_generate_response
            mock.return_value = mock_instance
            yield mock_instance

    @pytest.fixture
    def chat_system(self, mock_groq_client):
        """Create complete chat system for load testing.

        Returns a dict with keys ``'chat_agent'``, ``'session_manager'``,
        ``'language_context_manager'`` and ``'chat_history_manager'``. All
        components are freshly constructed, and the agent is wired to the
        mocked Groq client.
        """
        session_manager = SessionManager()
        language_context_manager = LanguageContextManager()
        chat_history_manager = ChatHistoryManager()
        chat_agent = ChatAgent(
            groq_client=mock_groq_client,
            session_manager=session_manager,
            language_context_manager=language_context_manager,
            chat_history_manager=chat_history_manager
        )

        return {
            'chat_agent': chat_agent,
            'session_manager': session_manager,
            'language_context_manager': language_context_manager,
            'chat_history_manager': chat_history_manager
        }

    def simulate_user_session(self, user_id, chat_system, num_messages=5):
        """Simulate one user: create a session, then send chat messages.

        Args:
            user_id: Identifier passed to ``SessionManager.create_session``.
            chat_system: Dict produced by the ``chat_system`` fixture. Only
                its ``'session_manager'`` and ``'chat_agent'`` entries are used.
            num_messages: Number of prompts to send, capped at
                ``len(USER_MESSAGES)``. Values of zero or below send nothing.

        Returns:
            A dict with these keys:

            - ``user_id``
            - ``session_id``
            - ``messages_sent``: every attempted message, including ones that
              raised.
            - ``responses_received``: non-empty responses.
            - ``errors``
            - ``response_times``: seconds, successful calls only.
            - ``total_time``: seconds.
            - ``success``: True only if the session was created and no
              message raised.

            Exceptions are caught, counted and printed rather than propagated,
            so one failing user cannot abort a whole load run.
        """
        results = {
            'user_id': user_id,
            'session_id': None,
            'messages_sent': 0,
            'responses_received': 0,
            'errors': 0,
            'response_times': [],
            'total_time': 0,
            'success': False
        }

        start_time = time.perf_counter()

        try:
            session = chat_system['session_manager'].create_session(
                user_id, language="python"
            )
            session_id = session['session_id']
            results['session_id'] = session_id
        except Exception as e:
            results['errors'] += 1
            print(f"Error creating session for user {user_id}: {e}")
            results['total_time'] = time.perf_counter() - start_time
            return results

        message_count = max(0, min(num_messages, len(self.USER_MESSAGES)))
        for i, message in enumerate(self.USER_MESSAGES[:message_count]):
            # Count the attempt up front so a message that raises still lowers
            # the response rate instead of silently vanishing from it.
            results['messages_sent'] += 1
            message_start = time.perf_counter()

            try:
                response = chat_system['chat_agent'].process_message(
                    session_id=session_id,
                    message=message,
                    language="python"
                )
                results['response_times'].append(time.perf_counter() - message_start)

                if response and len(response) > 0:
                    results['responses_received'] += 1
            except Exception as e:
                results['errors'] += 1
                print(f"Error in user {user_id} message {i}: {e}")

        results['success'] = results['errors'] == 0
        results['total_time'] = time.perf_counter() - start_time
        return results

    def _run_concurrent_sessions(self, chat_system, user_ids, messages_per_user):
        """Run ``simulate_user_session`` for every user id on its own thread.

        Returns:
            ``(results, elapsed)``: the per-user result dicts in submission
            order, and the seconds taken by the whole batch.
        """
        start = time.perf_counter()
        # ThreadPoolExecutor rejects max_workers=0, so keep at least one worker.
        workers = max(len(user_ids), 1)
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [
                executor.submit(
                    self.simulate_user_session,
                    user_id,
                    chat_system,
                    messages_per_user
                )
                for user_id in user_ids
            ]
            results = [future.result() for future in futures]
        return results, time.perf_counter() - start

    @staticmethod
    def _summarize(results):
        """Aggregate per-user result dicts into load-test totals."""
        return {
            'successful': [r for r in results if r['success']],
            'failed': [r for r in results if not r['success']],
            'total_messages': sum(r['messages_sent'] for r in results),
            'total_responses': sum(r['responses_received'] for r in results),
            'total_errors': sum(r['errors'] for r in results),
            'response_times': [t for r in results for t in r['response_times']],
        }

    def test_concurrent_users_light_load(self, chat_system):
        """Test with 10 concurrent users (light load)."""
        num_users = 10
        messages_per_user = 3
        user_ids = [f"load-test-user-{i}" for i in range(num_users)]

        results, total_time = self._run_concurrent_sessions(
            chat_system, user_ids, messages_per_user
        )
        summary = self._summarize(results)
        successful_sessions = summary['successful']
        failed_sessions = summary['failed']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']
        all_response_times = summary['response_times']

        # Assertions for light load
        assert len(successful_sessions) >= 8, f"Expected at least 8 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 2, f"Expected at most 2 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.8, "Expected at least 80% response rate"

        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 1.0, f"Average response time too high: {avg_response_time}s"

        print("Light Load Test Results:")
        print(f"  Users: {num_users}")
        print(f"  Successful sessions: {len(successful_sessions)}")
        print(f"  Failed sessions: {len(failed_sessions)}")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses}")
        print(f"  Total errors: {total_errors}")
        print(f"  Total time: {total_time:.2f}s")
        if all_response_times:
            print(f"  Avg response time: {avg_response_time:.3f}s")
            print(f"  Max response time: {max(all_response_times):.3f}s")

    def test_concurrent_users_medium_load(self, chat_system):
        """Test with 25 concurrent users (medium load)."""
        num_users = 25
        messages_per_user = 4
        user_ids = [f"medium-load-user-{i}" for i in range(num_users)]

        results, total_time = self._run_concurrent_sessions(
            chat_system, user_ids, messages_per_user
        )
        summary = self._summarize(results)
        successful_sessions = summary['successful']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']
        all_response_times = summary['response_times']

        # Assertions for medium load (more lenient)
        assert len(successful_sessions) >= 20, f"Expected at least 20 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 10, f"Expected at most 10 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.7, "Expected at least 70% response rate"

        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 2.0, f"Average response time too high: {avg_response_time}s"

        print("Medium Load Test Results:")
        print(f"  Users: {num_users}")
        print(f"  Successful sessions: {len(successful_sessions)}")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses}")
        print(f"  Total errors: {total_errors}")
        print(f"  Total time: {total_time:.2f}s")
        if all_response_times:
            print(f"  Avg response time: {avg_response_time:.3f}s")

    def test_concurrent_users_heavy_load(self, chat_system):
        """Test with 50 concurrent users (heavy load)."""
        num_users = 50
        messages_per_user = 3
        user_ids = [f"heavy-load-user-{i}" for i in range(num_users)]

        results, total_time = self._run_concurrent_sessions(
            chat_system, user_ids, messages_per_user
        )
        summary = self._summarize(results)
        successful_sessions = summary['successful']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']
        all_response_times = summary['response_times']

        # Assertions for heavy load (most lenient)
        assert len(successful_sessions) >= 35, f"Expected at least 35 successful sessions, got {len(successful_sessions)}"
        assert total_errors <= 25, f"Expected at most 25 errors, got {total_errors}"
        assert total_responses >= total_messages * 0.6, "Expected at least 60% response rate"

        if all_response_times:
            avg_response_time = statistics.mean(all_response_times)
            assert avg_response_time < 5.0, f"Average response time too high: {avg_response_time}s"

        print("Heavy Load Test Results:")
        print(f"  Users: {num_users}")
        print(f"  Successful sessions: {len(successful_sessions)}")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses}")
        print(f"  Total errors: {total_errors}")
        print(f"  Total time: {total_time:.2f}s")
        if all_response_times:
            print(f"  Avg response time: {avg_response_time:.3f}s")

    def test_sustained_load(self, chat_system):
        """Test sustained load over time.

        Launches a wave of ``users_per_wave`` concurrent users (2 messages
        each) roughly every ``wave_interval`` seconds until
        ``duration_seconds`` have elapsed. It then checks the aggregate
        session success rate and message response rate.
        """
        duration_seconds = 30  # 30 second test
        users_per_wave = 5
        wave_interval = 2  # New wave every 2 seconds

        results = []
        wave_count = 0
        start_time = time.perf_counter()

        while time.perf_counter() - start_time < duration_seconds:
            wave_start = time.perf_counter()
            wave_count += 1

            user_ids = [f"sustained-wave-{wave_count}-user-{i}" for i in range(users_per_wave)]
            wave_results, _ = self._run_concurrent_sessions(chat_system, user_ids, 2)
            results.extend(wave_results)

            # Pace the waves: sleep out whatever remains of this interval.
            elapsed = time.perf_counter() - wave_start
            if elapsed < wave_interval:
                time.sleep(wave_interval - elapsed)

        total_time = time.perf_counter() - start_time

        summary = self._summarize(results)
        successful_sessions = summary['successful']
        total_messages = summary['total_messages']
        total_responses = summary['total_responses']
        total_errors = summary['total_errors']

        success_rate = len(successful_sessions) / len(results) if results else 0
        response_rate = total_responses / total_messages if total_messages > 0 else 0

        assert success_rate >= 0.7, f"Expected at least 70% success rate, got {success_rate:.2%}"
        assert response_rate >= 0.6, f"Expected at least 60% response rate, got {response_rate:.2%}"

        print("Sustained Load Test Results:")
        print(f"  Duration: {total_time:.1f}s")
        print(f"  Waves: {wave_count}")
        print(f"  Total users: {len(results)}")
        print(f"  Successful sessions: {len(successful_sessions)} ({success_rate:.1%})")
        print(f"  Total messages: {total_messages}")
        print(f"  Total responses: {total_responses} ({response_rate:.1%})")
        print(f"  Total errors: {total_errors}")

    def test_memory_usage_under_load(self, chat_system):
        """Test memory usage during concurrent sessions.

        Compares this process's RSS before and after 20 concurrent users.
        The test is skipped when the optional ``psutil`` dependency is not
        installed.
        """
        import os

        psutil = pytest.importorskip("psutil")

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        num_users = 20
        messages_per_user = 5
        user_ids = [f"memory-test-user-{i}" for i in range(num_users)]

        results, _ = self._run_concurrent_sessions(chat_system, user_ids, messages_per_user)

        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory

        # Memory usage assertions
        assert memory_increase < 100, f"Memory increase too high: {memory_increase:.1f}MB"

        successful_sessions = [r for r in results if r['success']]
        assert len(successful_sessions) >= 15, "Expected at least 15 successful sessions"

        print("Memory Usage Test Results:")
        print(f"  Initial memory: {initial_memory:.1f}MB")
        print(f"  Final memory: {final_memory:.1f}MB")
        print(f"  Memory increase: {memory_increase:.1f}MB")
        print(f"  Successful sessions: {len(successful_sessions)}")


if __name__ == '__main__':
    # pytest.main returns an exit code; propagate it so a failing run exits
    # non-zero when this file is executed directly (e.g. from CI scripts).
    raise SystemExit(pytest.main([__file__, '-v', '-s']))