| """ |
| pytest configuration and fixtures for ARF API tests. |
| """ |
|
|
import os

import pytest
from fastapi import Request
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from app.core.usage_tracker import enforce_quota, Tier
from app.api.deps import get_db
from app.database.base import Base
from app.main import app as fastapi_app
import app.core.usage_tracker
|
|
| |
# Process-wide environment for the test run, applied in a single call.
# Keys are written in the order listed, overwriting any existing values.
#
# NOTE(review): these assignments execute after the `app.*` imports at the
# top of this file. If any of those modules read these variables at import
# time they will not see the values set here -- confirm.
os.environ.update(
    {
        # Presumably switches usage tracking off for tests -- verify against
        # app.core.usage_tracker.
        "ARF_USAGE_TRACKING": "false",
        # Both URL variables point at the same local PostgreSQL database.
        "DATABASE_URL": "postgresql://postgres:postgres@localhost:5432/testdb",
        "TEST_DATABASE_URL": "postgresql://postgres:postgres@localhost:5432/testdb",
        # Individual PG* settings carrying the same connection details as
        # the URLs above.
        "PGUSER": "postgres",
        "PGPASSWORD": "postgres",
        "PGHOST": "localhost",
        "PGPORT": "5432",
        "PGDATABASE": "testdb",
    }
)
|
|
|
|
| |
class MockTracker:
    """Stand-in usage tracker whose methods all return fixed values.

    Every method ignores its arguments, so nothing is read from or written
    to a real tracking backend.
    """

    # Constant handed back by get_remaining_quota().
    _FIXED_REMAINING = 1000

    def get_tier(self, api_key):
        """Return ``Tier.PRO`` regardless of ``api_key``."""
        # Imported at call time, as in the original implementation.
        from app.core.usage_tracker import Tier

        pro_tier = Tier.PRO
        return pro_tier

    def get_remaining_quota(self, api_key, tier):
        """Return a constant remaining quota of 1000."""
        return self._FIXED_REMAINING

    def consume_quota_and_log(self, record, idempotency_key=None):
        """Return the fixed pair ``(True, None)`` without recording anything."""
        result = (True, None)
        return result

    def increment_usage_sync(self, record, idempotency_key=None):
        """Return ``True`` without recording anything."""
        succeeded = True
        return succeeded

    def get_or_create_api_key(self, key, tier):
        """Return ``True`` without creating or looking up a key."""
        succeeded = True
        return succeeded

    def update_api_key_tier(self, key, tier):
        """Return ``True`` without changing any tier."""
        succeeded = True
        return succeeded

    def _insert_audit_log(self, record):
        """Do nothing and return ``None``."""
        return None
|
|
|
|
| |
# Swap the module-level `tracker` in app.core.usage_tracker for the mock.
# NOTE(review): code that already bound the original object with
# `from app.core.usage_tracker import tracker` would keep its own reference
# -- confirm no imported module does that.
_usage_tracker_module = app.core.usage_tracker
_usage_tracker_module.tracker = MockTracker()
|
|
| |
|
|
| |
|
|
| |
# Connection URL for the test database, defaulting to the local PostgreSQL
# instance when TEST_DATABASE_URL is not set.
# NOTE(review): the environment setup earlier in this file assigns
# TEST_DATABASE_URL unconditionally, so this fallback and the non-PostgreSQL
# branch below look unreachable in practice -- confirm whether that is
# intended.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/testdb",
)

# PostgreSQL URLs need no extra options. Any other URL is given
# check_same_thread=False, a sqlite3 connection option (presumably the
# alternative backend is SQLite).
_engine_options = (
    {}
    if TEST_DATABASE_URL.startswith("postgresql")
    else {"connect_args": {"check_same_thread": False}}
)
engine = create_engine(TEST_DATABASE_URL, **_engine_options)

# Session factory bound to the test engine, with autoflush and autocommit
# both disabled.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
|
|
|
|
def override_get_db():
    """Yield a session created by ``TestingSessionLocal`` and close it after use.

    Yields:
        The newly created session.
    """
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even if the consumer raised.
        session.close()
|
|
|
|
# Make the application resolve `get_db` to the test-session provider.
fastapi_app.dependency_overrides.update({get_db: override_get_db})
|
|
| |
|
|
|
|
async def mock_enforce_quota(request: Request, api_key=None):
    """Stand-in for ``enforce_quota`` that always reports a PRO-tier quota.

    ``request`` is annotated as ``Request`` on purpose. FastAPI derives an
    override's inputs from the override's own signature, and a parameter
    with neither annotation nor default is declared as a required query
    parameter. Without the annotation, calls to routes using this override
    would be rejected with a validation error unless they sent
    ``?request=...``.

    Args:
        request: The incoming request, injected by FastAPI; not used.
        api_key: Ignored; the returned key is always ``"test_key"``.

    Returns:
        A dict with the keys ``api_key``, ``tier`` and ``remaining``.
    """
    return {"api_key": "test_key", "tier": Tier.PRO, "remaining": 1000}
# Make the application resolve `enforce_quota` to the always-allow mock.
fastapi_app.dependency_overrides.update({enforce_quota: mock_enforce_quota})
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
def setup_database():
    """Create every table once for the test session and drop them at the end.

    Session-scoped and autouse, so it runs without being requested.
    """
    schema = Base.metadata
    schema.create_all(bind=engine)
    # Hand control to the test session; nothing is provided to tests.
    yield
    schema.drop_all(bind=engine)
|
|
|
|
@pytest.fixture(scope="session")
def client():
    """Yield one ``TestClient`` for the application, shared by the whole session.

    The client is entered as a context manager and stays open until the
    session-scoped fixture is torn down.
    """
    with TestClient(fastapi_app) as http_client:
        yield http_client
|
|
|
|
@pytest.fixture(scope="function")
def db_session():
    """Provide a clean database session for each test.

    Tables are created before the test if they are missing. After the test,
    the session is rolled back and closed, and the schema is reset to empty
    tables.

    The schema is recreated after being dropped. The tables are otherwise
    created only once per session by the autouse ``setup_database``
    fixture, so leaving them dropped here would break every later test that
    reaches the database without requesting this fixture.

    Yields:
        A session created by ``TestingSessionLocal``.
    """
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    yield session
    # Discard anything the test left uncommitted, then release the
    # connection before touching the schema.
    session.rollback()
    session.close()
    # Drop and recreate so the next test starts from empty tables while the
    # schema stays available to tests that do not use this fixture.
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
|
|