| | import unittest |
| | from abc import ABC |
| | from unittest import mock |
| |
|
| | from src.agent import BaseAgent |
| | from src.exceptions.exceptions import InputErrorException |
| | from src.llm import BaseLLM |
| | from src.prompt import PromptTemplate |
| | from src.schemas import AgentType, AgentOutput |
| | from src.tools import BaseTool |
| |
|
| |
|
class SampleBaseAgent(BaseAgent):
    """Minimal ``BaseAgent`` subclass used as a fixture by the tests in this module."""

    def run(self, *args, **kwargs) -> AgentOutput:
        """Accept any arguments, perform no work, and return ``None``."""
        return None
| |
|
| |
|
class SampleBaseTool(BaseTool, ABC):
    """Stub ``BaseTool`` subclass whose sync and async entry points do nothing."""

    def run(self, req):
        """Ignore ``req`` and return ``None``."""
        return None

    async def async_run(self, req):
        """Coroutine counterpart of ``run``; ignores ``req`` and resolves to ``None``."""
        return None
| |
|
| |
|
class TestBaseAgent(unittest.TestCase):
    """Unit tests for ``BaseAgent`` behaviour, exercised through ``SampleBaseAgent``."""

    def setUp(self):
        """Create a fresh agent with one tool registered as ``'test_tool'``."""
        self.mock_llm = mock.create_autospec(BaseLLM)
        self.mock_prompt_template = mock.create_autospec(PromptTemplate)
        self.agent = SampleBaseAgent(
            name='TestAgent',
            type=AgentType.react,
            version='1.0',
            description='Test Description',
            prompt_template=self.mock_prompt_template,
        )
        self.tool = SampleBaseTool("test_tool", "test_tool")
        self.agent.add_plugin('test_tool', self.tool)

    def test_properties(self):
        """Constructor arguments are readable back through the agent's attributes."""
        self.assertEqual(self.agent.name, 'TestAgent')
        self.assertEqual(self.agent.type, AgentType.react)
        self.assertEqual(self.agent.version, '1.0')
        self.assertEqual(self.agent.description, 'Test Description')
        self.assertEqual(self.agent.prompt_template, self.mock_prompt_template)

    def test_llm_setter_happy_path(self):
        """Assigning an autospecced ``BaseLLM`` is accepted and stored."""
        self.agent.llm = self.mock_llm
        self.assertEqual(self.agent.llm, self.mock_llm)

    def test_llm_setter_input_error(self):
        """Assigning a plain string as the LLM raises ``InputErrorException``."""
        with self.assertRaises(InputErrorException):
            self.agent.llm = 'invalid_llm'

    def test_add_plugin_happy_path(self):
        """A plugin added under a new name shows up in ``plugins_map``."""
        # setUp already registers 'test_tool', so re-adding that name would
        # make the membership check pass even if add_plugin did nothing.
        # Use a distinct name and verify it is absent before / present after.
        extra_tool = SampleBaseTool("another_tool", "another_tool")
        self.assertNotIn('another_tool', self.agent.plugins_map)
        self.agent.add_plugin('another_tool', extra_tool)
        self.assertIn('another_tool', self.agent.plugins_map)

    def test_add_plugin_input_error(self):
        """An empty name together with a ``None`` tool raises ``InputErrorException``."""
        with self.assertRaises(InputErrorException):
            self.agent.add_plugin('', None)

    def test_get_prompt_template_dict(self):
        """For a dict input, each value is replaced by ``_parse_prompt_template``'s result."""
        with mock.patch.object(
            BaseAgent, '_parse_prompt_template', return_value=self.mock_prompt_template
        ):
            result = self.agent._get_prompt_template({'test_key': 'dict'})
            self.assertEqual(result, {'test_key': self.mock_prompt_template})

    def test_get_prompt_template_instance(self):
        """A ``PromptTemplate`` instance is returned as-is (compares equal)."""
        prompt_instance = PromptTemplate(input_variables=["foo"], template="Say {foo}")
        result = self.agent._get_prompt_template(prompt_instance)
        self.assertEqual(result, prompt_instance)

    def test_clear(self):
        """Smoke test: ``clear()`` can be called without raising."""
        # NOTE(review): no post-condition is asserted because the effect of
        # clear() is not visible from this module -- add one once confirmed.
        self.agent.clear()

    def test_get_plugin_tool_function(self):
        """The sync function map exposes the registered tool's ``run`` method."""
        function_map = self.agent.get_plugin_tool_function()
        self.assertIn('test_tool', function_map)
        self.assertEqual(function_map['test_tool'], self.tool.run)

    def test_get_plugin_tool_async_function(self):
        """The async function map exposes the registered tool's ``async_run`` method."""
        function_map = self.agent.get_plugin_tool_async_function()
        self.assertIn('test_tool', function_map)
        self.assertEqual(function_map['test_tool'], self.tool.async_run)
|
| |
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
| |
|