| | |
| | """ |
| | Phase 1 Validation Test Script |
| | Tests that HF API inference has been removed and local models work correctly |
| | """ |
| |
|
| | import sys |
| | import os |
| | import asyncio |
| | import logging |
| |
|
| | |
# Root logging for the validation run: INFO and above, timestamped, with logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
def test_imports():
    """Check that the router, model-config and local-loader modules import cleanly.

    Returns:
        True when all three ``src`` imports succeed; False (after logging the
        exception message) when any of them raises.
    """
    logger.info("Testing imports...")
    try:
        # Imported only to prove the modules resolve; the names are intentionally unused.
        from src.llm_router import LLMRouter  # noqa: F401
        from src.models_config import LLM_CONFIG  # noqa: F401
        from src.local_model_loader import LocalModelLoader  # noqa: F401
    except Exception as exc:
        logger.error(f"❌ Import failed: {exc}")
        return False
    else:
        logger.info("✅ All imports successful")
        return True
| |
|
def test_models_config():
    """Check that ``src.models_config.LLM_CONFIG`` is configured for local-only inference.

    Verifies that:

    * ``LLM_CONFIG["primary_provider"]`` is ``"local"``;
    * ``LLM_CONFIG["models"]["reasoning_primary"]["model_id"]`` has no
      ``:cerebras`` API suffix and equals ``"Qwen/Qwen2.5-7B-Instruct"``;
    * the string form of ``LLM_CONFIG["routing_logic"]["fallback_chain"]``
      does not contain ``"API"``.

    Failed checks raise ``AssertionError`` explicitly rather than using
    ``assert`` statements, so they still run when Python is started with ``-O``.

    Returns:
        True if every check passed; False (after logging the reason) otherwise.
        A missing config key (``KeyError``) is also reported as a failure.
    """
    logger.info("Testing models_config...")
    try:
        from src.models_config import LLM_CONFIG

        if LLM_CONFIG["primary_provider"] != "local":
            raise AssertionError("Primary provider should be 'local'")
        logger.info("✅ Primary provider is 'local'")

        reasoning_model = LLM_CONFIG["models"]["reasoning_primary"]["model_id"]
        if ":cerebras" in reasoning_model:
            raise AssertionError("Model ID should not have API suffix")
        if reasoning_model != "Qwen/Qwen2.5-7B-Instruct":
            raise AssertionError("Should use Qwen model")
        logger.info(f"✅ Reasoning model: {reasoning_model}")

        # Stringify so the check does not depend on how the chain is stored.
        fallback_chain = LLM_CONFIG["routing_logic"]["fallback_chain"]
        if "API" in str(fallback_chain):
            raise AssertionError("No API in fallback chain")
        logger.info("✅ Routing logic updated")

        return True
    except Exception as e:
        logger.error(f"❌ Models config test failed: {e}")
        return False
| |
|
def test_llm_router_init():
    """Check how ``LLMRouter`` is constructed now that the HF API path is removed.

    Steps:

    1. ``LLMRouter(hf_token=None, use_local_models=False)`` must raise
       ``ValueError``. Returning normally fails the check; any other
       exception type is reported as a failure by the outer handler.
    2. The former HF-API helpers (``_call_hf_endpoint``, ``_is_model_healthy``,
       ``_get_fallback_model``) must not exist on the ``LLMRouter`` class.
       This check is on the class, so it runs even when step 3 cannot build a router.
    3. ``LLMRouter(hf_token=None, use_local_models=True)`` is constructed, and
       the same helpers must also be absent from the instance. A
       ``RuntimeError`` from this constructor is logged as "local models not
       available" and does not fail the check.

    Failed checks raise ``AssertionError`` explicitly (not ``assert``), so they
    still run under ``python -O``.

    Returns:
        True if all checks passed or local models are unavailable; False
        otherwise, including when an unexpected exception is raised.
    """
    logger.info("Testing LLM router initialization...")
    removed_api_methods = ("_call_hf_endpoint", "_is_model_healthy", "_get_fallback_model")
    try:
        from src.llm_router import LLMRouter

        # API-only mode must be rejected by the constructor.
        try:
            LLMRouter(hf_token=None, use_local_models=False)
        except ValueError:
            logger.info("✅ Correctly raises error for use_local_models=False")
        else:
            logger.error("❌ Should have raised ValueError for use_local_models=False")
            return False

        # Check the class first so this is not skipped when local models cannot load.
        for method_name in removed_api_methods:
            if hasattr(LLMRouter, method_name):
                raise AssertionError(f"Should not have {method_name} method")

        # Only the constructor is guarded: a RuntimeError here means the local backend is unavailable.
        try:
            router = LLMRouter(hf_token=None, use_local_models=True)
        except RuntimeError as e:
            logger.warning(f"⚠️ Local models not available: {e}")
            logger.warning("This is expected if transformers/torch not installed")
            return True
        logger.info("✅ LLM router initialized (local models)")

        # Instance-level check also covers attributes an instance may carry beyond its class.
        for method_name in removed_api_methods:
            if hasattr(router, method_name):
                raise AssertionError(f"Should not have {method_name} method")
        logger.info("✅ HF API methods removed")

        return True
    except Exception as e:
        logger.error(f"❌ LLM router test failed: {e}")
        return False
| |
|
def test_no_api_references():
    """Check that the ``LLMRouter`` class source no longer references the HF Inference API.

    The class source, obtained with ``inspect.getsource``, must not contain
    ``_call_hf_endpoint`` or the ``router.huggingface.co`` URL. It may mention
    ``HF Inference API`` only if it also contains ``no API fallback``.

    Failed checks raise ``AssertionError`` explicitly (not ``assert``), so they
    still run under ``python -O``.

    Returns:
        True if the source is clean; False (after logging the reason) otherwise,
        including when the import fails or the source cannot be retrieved.
    """
    logger.info("Testing for API references...")
    try:
        import inspect
        from src.llm_router import LLMRouter

        router_source = inspect.getsource(LLMRouter)

        if "_call_hf_endpoint" in router_source:
            raise AssertionError("Should not have _call_hf_endpoint")
        if "router.huggingface.co" in router_source:
            raise AssertionError("Should not have HF API URL")
        # "HF Inference API" is tolerated alongside "no API fallback" -- presumably a
        # comment/docstring stating the API path was removed. NOTE(review): confirm.
        if "HF Inference API" in router_source and "no API fallback" not in router_source:
            raise AssertionError("Should not reference HF API")

        logger.info("✅ No API references found in LLM router")
        return True
    except Exception as e:
        logger.error(f"❌ API reference test failed: {e}")
        return False
| |
|
async def test_inference_flow():
    """Run one local inference through the router, when local models are available.

    Sends a short ``general_reasoning`` prompt via ``route_inference``.

    Returns:
        True when inference yields a truthy result, or when building the router
        or running inference raises ``RuntimeError`` (logged as a warning);
        False when inference yields a falsy result or any other exception occurs.
    """
    logger.info("Testing inference flow...")
    try:
        from src.llm_router import LLMRouter

        local_router = LLMRouter(hf_token=None, use_local_models=True)

        try:
            answer = await local_router.route_inference(
                task_type="general_reasoning",
                prompt="What is 2+2?",
                max_tokens=50,
            )
            if not answer:
                logger.warning("⚠️ Inference returned None")
                return False
            logger.info(f"✅ Inference successful: {answer[:50]}...")
            return True
        except RuntimeError as err:
            # The router exists, but the models could not actually run.
            logger.warning(f"⚠️ Inference failed (expected if models not loaded): {err}")
            return True
    except RuntimeError as err:
        # Constructing the router itself failed.
        logger.warning(f"⚠️ Router not available: {err}")
        return True
    except Exception as err:
        logger.error(f"❌ Inference test failed: {err}")
        return False
| |
|
def main():
    """Run every Phase 1 check, log a pass/fail summary and return an exit status.

    The synchronous checks run first, in order; the async inference check is
    then driven with ``asyncio.run``. A check that raises is recorded as failed.

    Returns:
        0 if every check passed, 1 otherwise.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("PHASE 1 VALIDATION TESTS")
    logger.info(banner)

    sync_checks = (
        ("Imports", test_imports),
        ("Models Config", test_models_config),
        ("LLM Router Init", test_llm_router_init),
        ("No API References", test_no_api_references),
    )

    outcomes = []
    for name, check in sync_checks:
        logger.info(f"\n--- Running {name} Test ---")
        try:
            verdict = check()
        except Exception as exc:
            logger.error(f"Test {name} crashed: {exc}")
            verdict = False
        outcomes.append((name, verdict))

    logger.info("\n--- Running Inference Flow Test ---")
    try:
        verdict = asyncio.run(test_inference_flow())
    except Exception as exc:
        logger.error(f"Inference flow test crashed: {exc}")
        verdict = False
    outcomes.append(("Inference Flow", verdict))

    logger.info("\n" + banner)
    logger.info("TEST SUMMARY")
    logger.info(banner)

    total = len(outcomes)
    passed = sum(bool(verdict) for _, verdict in outcomes)

    for name, verdict in outcomes:
        label = "✅ PASS" if verdict else "❌ FAIL"
        logger.info(f"{label}: {name}")

    logger.info(f"\nTotal: {passed}/{total} tests passed")

    if passed != total:
        logger.warning(f"⚠️ {total - passed} test(s) failed")
        return 1
    logger.info("✅ All tests passed!")
    return 0
| |
|
if __name__ == "__main__":
    # main() returns 0 when every check passed and 1 otherwise; use it as the exit status.
    exit_code = main()
    sys.exit(exit_code)
| |
|
| |
|