| |
| import pytest |
| from unittest.mock import patch, MagicMock, ANY, AsyncMock |
| from openai import OpenAIError |
| import json |
| import tenacity |
| import asyncio |
| from openai.types.chat import ChatCompletion |
| from openai.types.chat.chat_completion import Choice as ChatCompletionChoice |
| from openai.types.chat.chat_completion_message import ChatCompletionMessage |
| from openai import APIConnectionError, APIError, AsyncOpenAI |
|
|
| |
| from ankigen_core.llm_interface import ( |
| OpenAIClientManager, |
| structured_output_completion, |
| process_crawled_page, |
| process_crawled_pages, |
| ) |
| from ankigen_core.utils import ( |
| ResponseCache, |
| ) |
| from ankigen_core.models import CrawledPage, AnkiCardData |
|
|
| |
|
|
|
|
@pytest.mark.asyncio
async def test_client_manager_init():
    """A freshly constructed manager holds neither a client nor an API key."""
    fresh_manager = OpenAIClientManager()
    assert fresh_manager._client is None
    assert fresh_manager._api_key is None
|
|
|
|
@pytest.mark.asyncio
async def test_client_manager_initialize_success():
    """A well-formed key makes the manager build exactly one AsyncOpenAI client."""
    mgr = OpenAIClientManager()
    key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"

    # Replace the SDK constructor so no real client is created.
    with patch("ankigen_core.llm_interface.AsyncOpenAI") as fake_ctor:
        await mgr.initialize_client(key)
        fake_ctor.assert_called_once_with(api_key=key)
        assert mgr.get_client() is not None
|
|
|
|
@pytest.mark.asyncio
async def test_client_manager_initialize_invalid_key_format():
    """A malformed key is rejected with ValueError and leaves the manager empty."""
    mgr = OpenAIClientManager()
    bad_key = "invalid-key-format"

    with pytest.raises(ValueError, match="Invalid OpenAI API key format."):
        await mgr.initialize_client(bad_key)

    # Rejection must not leave a half-initialised manager behind.
    assert mgr._client is None
    assert mgr._api_key is None
|
|
|
|
@pytest.mark.asyncio
async def test_client_manager_initialize_openai_error():
    """An OpenAIError raised by the SDK constructor propagates to the caller."""
    mgr = OpenAIClientManager()
    key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
    failure_text = "Test OpenAI Init Error"
    ctor_target = "ankigen_core.llm_interface.AsyncOpenAI"

    with patch(ctor_target, side_effect=OpenAIError(failure_text)) as fake_ctor:
        with pytest.raises(OpenAIError, match=failure_text):
            await mgr.initialize_client(key)
        # The constructor was still attempted with the supplied key.
        fake_ctor.assert_called_once_with(api_key=key)
|
|
|
|
@pytest.mark.asyncio
async def test_client_manager_get_client_success():
    """get_client() returns the very instance the patched constructor produced."""
    mgr = OpenAIClientManager()
    key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
    with patch("ankigen_core.llm_interface.AsyncOpenAI") as fake_ctor:
        expected_client = fake_ctor.return_value
        await mgr.initialize_client(key)
        assert mgr.get_client() == expected_client
|
|
|
|
def test_client_manager_get_client_not_initialized():
    """Requesting a client before initialize_client() raises RuntimeError."""
    uninitialised = OpenAIClientManager()
    with pytest.raises(RuntimeError, match="OpenAI client is not initialized."):
        uninitialised.get_client()
|
|
|
|
| |
|
|
|
|
| |
@pytest.fixture
def mock_openai_client():
    """Provide an AsyncOpenAI stand-in with an awaitable chat.completions.create.

    By default the awaited call returns a ChatCompletion whose message content
    is the JSON list ``[{"data": "mocked success"}]``. Tests override
    ``return_value`` / ``side_effect`` as needed.
    """
    default_completion = create_mock_chat_completion(
        json.dumps([{"data": "mocked success"}])
    )
    fake_client = MagicMock(spec=AsyncOpenAI)
    fake_client.chat = AsyncMock()
    fake_client.chat.completions = AsyncMock()
    fake_client.chat.completions.create = AsyncMock(return_value=default_completion)
    return fake_client
|
|
|
|
| |
@pytest.fixture
def mock_response_cache():
    """Provide a MagicMock restricted to the ResponseCache interface."""
    return MagicMock(spec=ResponseCache)
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_completion_cache_hit(
    mock_openai_client, mock_response_cache
):
    """A cache hit is returned as-is without calling the API or writing the cache."""
    sys_text = "System prompt"
    usr_text = "User prompt"
    model_name = "test-model"
    hit = {"data": "cached result"}
    mock_response_cache.get.return_value = hit

    outcome = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model_name,
        response_format={"type": "json_object"},
        system_prompt=sys_text,
        user_prompt=usr_text,
        cache=mock_response_cache,
    )

    # Lookup key is "<system>:<user>", scoped by model name.
    mock_response_cache.get.assert_called_once_with(f"{sys_text}:{usr_text}", model_name)
    mock_openai_client.chat.completions.create.assert_not_called()
    mock_response_cache.set.assert_not_called()
    assert outcome == hit
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_completion_cache_miss_success(
    mock_openai_client, mock_response_cache
):
    """On a cache miss the API is called once; its parsed JSON is cached and returned."""
    sys_text = "System prompt for success"
    usr_text = "User prompt for success"
    model_name = "test-model-success"
    api_payload = {"data": "successful API result"}
    cache_key = f"{sys_text}:{usr_text}"

    mock_response_cache.get.return_value = None
    # Fake completion shaped as completion.choices[0].message.content.
    mock_openai_client.chat.completions.create.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content=json.dumps(api_payload)))]
    )

    outcome = await structured_output_completion(
        openai_client=mock_openai_client,
        model=model_name,
        response_format={"type": "json_object"},
        system_prompt=sys_text,
        user_prompt=usr_text,
        cache=mock_response_cache,
    )

    mock_response_cache.get.assert_called_once_with(cache_key, model_name)
    expected_messages = [
        # System content is matched with ANY; presumably the implementation may
        # amend the system prompt, so only the role is pinned here.
        {"role": "system", "content": ANY},
        {"role": "user", "content": usr_text},
    ]
    mock_openai_client.chat.completions.create.assert_called_once_with(
        model=model_name,
        messages=expected_messages,
        response_format={"type": "json_object"},
        temperature=0.7,
    )
    mock_response_cache.set.assert_called_once_with(cache_key, model_name, api_payload)
    assert outcome == api_payload
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_completion_api_error(
    mock_openai_client, mock_response_cache
):
    """A persistently failing API call is retried and then surfaces as RetryError."""
    sys_text = "System prompt for error"
    usr_text = "User prompt for error"
    model_name = "test-model-error"
    failure_text = "Test API Error"

    mock_response_cache.get.return_value = None
    create_mock = mock_openai_client.chat.completions.create
    create_mock.side_effect = OpenAIError(failure_text)

    with pytest.raises(tenacity.RetryError):
        await structured_output_completion(
            openai_client=mock_openai_client,
            model=model_name,
            response_format={"type": "json_object"},
            system_prompt=sys_text,
            user_prompt=usr_text,
            cache=mock_response_cache,
        )

    # Every attempt re-reads the cache before calling the API, so both
    # counters equal the number of attempts.
    assert (
        mock_response_cache.get.call_count == 3
    ), f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}"
    assert (
        create_mock.call_count == 3
    ), f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}"
    mock_response_cache.set.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_completion_invalid_json(
    mock_openai_client, mock_response_cache
):
    """Non-JSON message content is retried like an error and ends in RetryError."""
    sys_text = "System prompt for invalid json"
    usr_text = "User prompt for invalid json"
    model_name = "test-model-invalid-json"
    garbage = "this is not json"

    mock_response_cache.get.return_value = None
    create_mock = mock_openai_client.chat.completions.create
    create_mock.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content=garbage))]
    )

    with pytest.raises(tenacity.RetryError):
        await structured_output_completion(
            openai_client=mock_openai_client,
            model=model_name,
            response_format={"type": "json_object"},
            system_prompt=sys_text,
            user_prompt=usr_text,
            cache=mock_response_cache,
        )

    assert (
        mock_response_cache.get.call_count == 3
    ), f"Expected cache.get to be called 3 times due to retries, but was {mock_response_cache.get.call_count}"
    assert (
        create_mock.call_count == 3
    ), f"Expected create to be called 3 times due to retries, but was {mock_openai_client.chat.completions.create.call_count}"
    # Nothing unparsable may ever be cached.
    mock_response_cache.set.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_completion_no_choices(
    mock_openai_client, mock_response_cache
):
    """A completion with an empty ``choices`` list yields None and is not cached."""
    mock_response_cache.get.return_value = None
    mock_openai_client.chat.completions.create.return_value = MagicMock(choices=[])

    outcome = await structured_output_completion(
        openai_client=mock_openai_client,
        model="test-model-no-choices",
        response_format={"type": "json_object"},
        system_prompt="System prompt no choices",
        user_prompt="User prompt no choices",
        cache=mock_response_cache,
    )

    assert outcome is None
    mock_response_cache.set.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_structured_output_completion_no_message_content(
    mock_openai_client, mock_response_cache
):
    """A choice whose message content is None yields None and is not cached."""
    mock_response_cache.get.return_value = None
    mock_openai_client.chat.completions.create.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content=None))]
    )

    outcome = await structured_output_completion(
        openai_client=mock_openai_client,
        model="test-model-no-content",
        response_format={"type": "json_object"},
        system_prompt="System prompt no content",
        user_prompt="User prompt no content",
        cache=mock_response_cache,
    )

    assert outcome is None
    mock_response_cache.set.assert_not_called()
|
|
|
|
| |
| |
| |
|
|
| |
|
|
|
|
@pytest.fixture
def client_manager():
    """Provide a brand-new, uninitialised OpenAIClientManager."""
    manager = OpenAIClientManager()
    return manager
|
|
|
|
@pytest.fixture
def sample_crawled_page():
    """Provide a small depth-0 CrawledPage with title, meta data and body text."""
    page_fields = {
        "url": "http://example.com",
        "html_content": "<html><body>This is some test content for the page.</body></html>",
        "text_content": "This is some test content for the page.",
        "title": "Test Page",
        "meta_description": "A test page.",
        "meta_keywords": ["test", "page"],
        "crawl_depth": 0,
    }
    return CrawledPage(**page_fields)
|
|
|
|
| |
|
|
|
|
def create_mock_chat_completion(content: str) -> ChatCompletion:
    """Build a real ``ChatCompletion`` with one assistant choice carrying *content*.

    Uses the SDK's own ChatCompletion types rather than a bare MagicMock. The
    message text is reachable as ``result.choices[0].message.content``.
    """
    assistant_message = ChatCompletionMessage(content=content, role="assistant")
    sole_choice = ChatCompletionChoice(
        finish_reason="stop",
        index=0,
        message=assistant_message,
        logprobs=None,
    )
    return ChatCompletion(
        id="chatcmpl-test123",
        choices=[sole_choice],
        created=1677652288,
        model="gpt-4o",
        object="chat.completion",
        system_fingerprint="fp_test",
        usage=None,
    )
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_success(mock_openai_client, sample_crawled_page):
    """A JSON list of card dicts becomes cards tagged with the page's URL."""
    card_payload = [
        {"front": "Q1", "back": "A1", "tags": ["tag1"]},
        {"front": "Q2", "back": "A2", "tags": ["tag2", "python"]},
    ]
    create_mock = mock_openai_client.chat.completions.create
    create_mock.return_value = create_mock_chat_completion(json.dumps(card_payload))

    cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(cards) == 2
    first_card, second_card = cards
    assert first_card.front == "Q1"
    assert first_card.source_url == sample_crawled_page.url
    assert second_card.back == "A2"
    create_mock.assert_awaited_once()
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_empty_llm_response_content(
    mock_openai_client, sample_crawled_page
):
    """An empty-string completion produces no cards."""
    empty_reply = create_mock_chat_completion("")
    mock_openai_client.chat.completions.create.return_value = empty_reply

    cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(cards) == 0
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_llm_returns_not_a_list(
    mock_openai_client, sample_crawled_page
):
    """A JSON object with no card list inside it produces no cards."""
    unexpected_shape = {"error": "not a list as expected"}
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(json.dumps(unexpected_shape))
    )

    cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(cards) == 0
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_llm_returns_dict_with_cards_key(
    mock_openai_client, sample_crawled_page
):
    """A JSON object wrapping the card list under ``"cards"`` is unwrapped."""
    wrapped = {"cards": [{"front": "Q1", "back": "A1", "tags": []}]}
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(json.dumps(wrapped))
    )

    cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(cards) == 1
    only_card = cards[0]
    assert only_card.front == "Q1"
    assert only_card.back == "A1"
    assert only_card.source_url == sample_crawled_page.url
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_json_decode_error(
    mock_openai_client, sample_crawled_page
):
    """Unparsable completion text produces no cards instead of raising."""
    bad_reply = create_mock_chat_completion("this is not valid json")
    mock_openai_client.chat.completions.create.return_value = bad_reply

    cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(cards) == 0
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_empty_text_content(mock_openai_client):
    """A page with no text is skipped without any API call."""
    blank_page = CrawledPage(
        url="http://example.com/empty",
        html_content="",
        text_content="",
        title="Empty",
    )

    cards = await process_crawled_page(mock_openai_client, blank_page)

    assert len(cards) == 0
    mock_openai_client.chat.completions.create.assert_not_awaited()
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_openai_api_error_retry(
    mock_openai_client, sample_crawled_page, caplog
):
    """A single successful completion is turned into cards with one API call.

    Despite its name, this test only covers the success path.
    ``test_process_crawled_page_openai_persistent_api_error`` shows that an
    APIConnectionError is not retried by ``process_crawled_page``
    (``await_count == 1``), so there is no retry behaviour to exercise here.
    """
    # NOTE(review): ``caplog`` is requested but unused; kept so the signature
    # stays unchanged.
    cards_json = json.dumps([{"front": "Q1", "back": "A1", "tags": []}])
    create_mock = mock_openai_client.chat.completions.create
    create_mock.return_value = create_mock_chat_completion(cards_json)

    # Call the real function directly; patching the module attribute with a
    # pass-through wrapper would have no effect on this direct call.
    result_cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(result_cards) == 1
    assert result_cards[0].front == "Q1"
    assert result_cards[0].back == "A1"
    assert create_mock.call_count == 1
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_openai_persistent_api_error(
    mock_openai_client, sample_crawled_page, caplog
):
    """A connection error is logged and swallowed: no cards and a single attempt."""
    create_mock = mock_openai_client.chat.completions.create
    create_mock.side_effect = APIConnectionError(request=MagicMock())

    cards = await process_crawled_page(mock_openai_client, sample_crawled_page)

    assert len(cards) == 0
    assert create_mock.await_count == 1
    assert "OpenAI API error while processing page" in caplog.text
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_tiktoken_truncation(
    mock_openai_client, sample_crawled_page, monkeypatch
):
    """Over-long page text still yields cards via the cl100k_base fallback encoding.

    ``tiktoken.encoding_for_model`` is forced to raise KeyError, and the test
    asserts that ``tiktoken.get_encoding("cl100k_base")`` is used instead.
    """
    sample_crawled_page.text_content = "word " * 8000

    canned_cards = json.dumps(
        [{"front": "TruncatedQ", "back": "TruncatedA", "tags": []}]
    )
    mock_openai_client.chat.completions.create.return_value = (
        create_mock_chat_completion(canned_cards)
    )

    fake_encoding = MagicMock()
    # Successive encode() results: 1000 tokens, then 10000, then 10000.
    # NOTE(review): which text each call encodes depends on the implementation's
    # call order (confirm). A fourth encode() call would exhaust this iterable.
    fake_encoding.encode.side_effect = [
        list(range(1000)),
        list(range(10000)),
        list(range(10000)),
    ]

    decoded_lengths = []  # token count handed to each decode() call (not asserted)

    def fake_decode(tokens):
        decoded_lengths.append(len(tokens))
        return "Truncated content"

    fake_encoding.decode = fake_decode
    get_encoding_mock = MagicMock(return_value=fake_encoding)

    with patch("tiktoken.get_encoding", get_encoding_mock), patch(
        "tiktoken.encoding_for_model", side_effect=KeyError("test")
    ):
        cards = await process_crawled_page(
            mock_openai_client, sample_crawled_page, max_prompt_content_tokens=6000
        )

    assert len(cards) == 1
    assert cards[0].front == "TruncatedQ"
    assert cards[0].back == "TruncatedA"

    get_encoding_mock.assert_called_with("cl100k_base")
    assert fake_encoding.encode.call_count >= 2
|
|
|
|
| |
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_pages_success(mock_openai_client, sample_crawled_page):
    """Cards from every page are concatenated; the per-page processor runs once per page."""
    second_page = CrawledPage(
        url="http://example.com/page2",
        html_content="",
        text_content="Content for page 2",
        title="Page 2",
    )
    pages_to_process = [sample_crawled_page, second_page]

    async def mock_single_page_processor(openai_client, page, model="gpt-4o", **kwargs):
        # Stub: one card for the first page, two for the second, none otherwise.
        if page.url == sample_crawled_page.url:
            return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
        if page.url == second_page.url:
            return [
                AnkiCardData(front="P2Q1", back="P2A1", source_url=page.url),
                AnkiCardData(front="P2Q2", back="P2A2", source_url=page.url),
            ]
        return []

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_single_page_processor,
    ) as mock_processor:
        result_cards = await process_crawled_pages(
            mock_openai_client, pages_to_process, max_concurrent_requests=1
        )

    assert len(result_cards) == 3
    assert mock_processor.call_count == 2
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_pages_partial_failure(
    mock_openai_client, sample_crawled_page
):
    """A page whose processing raises is skipped; the other pages' cards survive."""
    failing_page = CrawledPage(
        url="http://example.com/page_fail",
        html_content="",
        text_content="Content for page fail",
        title="Page Fail",
    )
    third_page = CrawledPage(
        url="http://example.com/page3",
        html_content="",
        text_content="Content for page 3",
        title="Page 3",
    )
    pages_to_process = [sample_crawled_page, failing_page, third_page]

    async def mock_single_page_processor_with_failure(
        openai_client, page, model="gpt-4o", **kwargs
    ):
        # Stub: the middle page raises; the others each yield one card.
        if page.url == sample_crawled_page.url:
            return [AnkiCardData(front="P1Q1", back="P1A1", source_url=page.url)]
        if page.url == failing_page.url:
            raise APIConnectionError(request=MagicMock())
        if page.url == third_page.url:
            return [AnkiCardData(front="P3Q1", back="P3A1", source_url=page.url)]
        return []

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_single_page_processor_with_failure,
    ) as mock_processor:
        result_cards = await process_crawled_pages(
            mock_openai_client, pages_to_process, max_concurrent_requests=2
        )

    assert len(result_cards) == 2
    assert mock_processor.call_count == 3
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_pages_progress_callback(
    mock_openai_client, sample_crawled_page
):
    """progress_callback receives cumulative (completed, total) after each page."""
    pages_to_process = [sample_crawled_page] * 3
    progress_log = []

    def callback(completed_count, total_count):
        progress_log.append((completed_count, total_count))

    # Accept any extra positional/keyword arguments that process_crawled_pages
    # forwards (model, token limits, ...). The stub is then not tied to one
    # exact call signature, matching the **kwargs style of the sibling stubs.
    async def mock_simple_processor(client, page, *args, **kwargs):
        await asyncio.sleep(0.01)
        return [AnkiCardData(front=f"{page.url}-Q", back="A", source_url=page.url)]

    with patch(
        "ankigen_core.llm_interface.process_crawled_page",
        side_effect=mock_simple_processor,
    ):
        await process_crawled_pages(
            mock_openai_client,
            pages_to_process,
            progress_callback=callback,
            max_concurrent_requests=1,
        )

    assert progress_log == [(1, 3), (2, 3), (3, 3)]
|
|
|
|
| |
# Fake key shaped like a real OpenAI key ("sk-" prefix), so it passes the
# client manager's key-format check in tests that patch AsyncOpenAI.
TEST_API_KEY: str = "sk-testkey1234567890abcdefghijklmnopqrstuvwxyz"
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_api_error(
    client_manager, mock_openai_client, sample_crawled_page
):
    """An APIError from the completion call produces an empty card list."""
    create_mock = mock_openai_client.chat.completions.create
    create_mock.side_effect = APIError(
        message="Test API Error", request=MagicMock(), body=None
    )

    with patch.object(client_manager, "get_client", return_value=mock_openai_client):
        # reset_mock() clears call history only: side_effect is kept by default,
        # so the APIError configured above is still raised.
        create_mock.reset_mock()

        result_cards = await process_crawled_page(
            mock_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=1000,
        )
        assert len(result_cards) == 0
| |
|
|
|
|
@pytest.mark.asyncio
async def test_process_crawled_page_content_truncation(
    client_manager, mock_openai_client, sample_crawled_page
):
    """A page is skipped, not sent truncated, when the token budget is exhausted.

    ``tiktoken.encoding_for_model`` is forced to fail, so the ``cl100k_base``
    fallback encoding is used. That mocked encoding reports 100 tokens for every
    ``encode()`` call, more than the 20-token ``max_prompt_content_tokens``
    budget. The test asserts that no cards come back and that the
    chat-completion API is never called. Presumably the prompt overhead alone
    exceeds the budget. Despite the test name, no truncated content is sent.
    """
    long_content_piece = "This is a word. "
    repetitions = 10
    sample_crawled_page.text_content = long_content_piece * repetitions

    with (
        # NOTE(review): process_crawled_page receives mock_openai_client directly,
        # so this get_client patch does not appear to be exercised.
        patch.object(client_manager, "get_client", return_value=mock_openai_client),
        patch("tiktoken.encoding_for_model", side_effect=KeyError("test")),
        patch("tiktoken.get_encoding") as mock_get_encoding,
    ):
        mock_encoding = MagicMock()
        # Every encode() call reports 100 tokens, regardless of input.
        system_prompt_tokens = list(range(100))
        mock_encoding.encode.return_value = system_prompt_tokens
        mock_get_encoding.return_value = mock_encoding

        # If the API were called anyway, this canned reply would produce a card
        # and make the ``result == []`` assertion fail as well.
        mock_openai_client.chat.completions.create.return_value = (
            create_mock_chat_completion(
                json.dumps([{"front": "TestQ", "back": "TestA", "tags": []}])
            )
        )

        result = await process_crawled_page(
            mock_openai_client,
            sample_crawled_page,
            "gpt-4o",
            max_prompt_content_tokens=20,
        )

        assert result == []
        mock_get_encoding.assert_called_with("cl100k_base")
        assert mock_encoding.encode.call_count >= 1
        mock_openai_client.chat.completions.create.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_openai_client_manager_get_client(
    client_manager, mock_async_openai_client
):
    """get_client() returns the same AsyncOpenAI instance on every call.

    The constructor is patched, and the test checks it ran exactly once, so
    repeated get_client() calls must not build a second client.
    """
    # Start from a known-empty state regardless of fixture defaults.
    client_manager._client = None
    client_manager._api_key = None

    with patch(
        "ankigen_core.llm_interface.AsyncOpenAI", return_value=mock_async_openai_client
    ) as mock_constructor:
        # Use the shared module-level key instead of repeating the literal.
        await client_manager.initialize_client(TEST_API_KEY)

        client1 = client_manager.get_client()
        client2 = client_manager.get_client()

        assert client1 is mock_async_openai_client
        assert client2 is mock_async_openai_client
        mock_constructor.assert_called_once_with(api_key=TEST_API_KEY)
|
|
|
|
| |
| |
| |
| |
|
|
|
|
@pytest.fixture
def mock_async_openai_client():
    """Provide an AsyncOpenAI stand-in whose awaited create() returns one default card.

    The default completion's content is the JSON list
    ``[{"front": "Q_Default", "back": "A_Default", "tags": []}]``.
    """
    stub = MagicMock(spec=AsyncOpenAI)
    stub.chat = AsyncMock()
    stub.chat.completions = AsyncMock()
    default_cards = [{"front": "Q_Default", "back": "A_Default", "tags": []}]
    stub.chat.completions.create = AsyncMock(
        return_value=create_mock_chat_completion(json.dumps(default_cards))
    )
    return stub
|
|