Spaces:
Build error
Build error
"""
pytest configuration and fixtures for ARF API tests.
"""
# ===== STEP 1: Set environment variables BEFORE any app imports =====
# Only the standard library and pytest are imported ahead of the assignments
# below. The app packages are imported afterwards so that this module no
# longer imports them before the variables exist in os.environ.
import os

import pytest

os.environ["ARF_USAGE_TRACKING"] = "false"
# Force the correct database URL for tests
os.environ["DATABASE_URL"] = "postgresql://postgres:postgres@localhost:5432/testdb"
os.environ["TEST_DATABASE_URL"] = "postgresql://postgres:postgres@localhost:5432/testdb"
# Additional PostgreSQL environment variables to prevent fallback to
# system user
os.environ["PGUSER"] = "postgres"
os.environ["PGPASSWORD"] = "postgres"
os.environ["PGHOST"] = "localhost"
os.environ["PGPORT"] = "5432"
os.environ["PGDATABASE"] = "testdb"

# These imports intentionally follow the environment setup; the E402 markers
# keep linters/formatters from hoisting them back above it.
from fastapi.testclient import TestClient  # noqa: E402
from sqlalchemy import create_engine  # noqa: E402
from sqlalchemy.orm import sessionmaker  # noqa: E402

import app.core.usage_tracker  # noqa: E402
from app.core.usage_tracker import enforce_quota, Tier  # noqa: E402
from app.api.deps import get_db  # noqa: E402
from app.database.base import Base  # noqa: E402
from app.main import app as fastapi_app  # noqa: E402
# ===== STEP 2: Define the mock tracker (it is installed below, after the app modules have been imported) =====
class MockTracker:
    """Test double for the usage tracker.

    Every method ignores its arguments and returns a fixed value, so no
    quota is ever exhausted and nothing is recorded.
    """

    def get_tier(self, api_key):
        """Return ``Tier.PRO`` regardless of ``api_key``."""
        # Imported at call time rather than relying on a module-level name.
        from app.core.usage_tracker import Tier

        return Tier.PRO

    def get_remaining_quota(self, api_key, tier) -> int:
        """Return a constant remaining quota of 1000."""
        return 1000

    def consume_quota_and_log(self, record, idempotency_key=None) -> tuple:
        """Return the fixed pair ``(True, None)`` without recording anything."""
        return (True, None)

    def increment_usage_sync(self, record, idempotency_key=None) -> bool:
        """Return ``True`` without recording anything."""
        return True

    def get_or_create_api_key(self, key, tier) -> bool:
        """Return ``True`` without creating or looking up a key."""
        return True

    def update_api_key_tier(self, key, tier) -> bool:
        """Return ``True`` without changing any tier."""
        return True

    def _insert_audit_log(self, record) -> None:
        """Do nothing; audit records are discarded."""
        return None
# Swap the module-level tracker for the mock.
# NOTE(review): the app modules are already imported by the time this runs,
# so any module that bound `tracker` by name at import time would keep the
# real object -- confirm that consumers look it up through the module.
app.core.usage_tracker.tracker = MockTracker()

# ===== STEP 3: Database engine and session factory for the tests =====
# NOTE(review): no model module is imported in this file; Base.metadata only
# contains tables whose models were imported as a side effect of the app
# imports -- confirm that this is sufficient.
TEST_DATABASE_URL = os.getenv(
    "TEST_DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/testdb",
)

# PostgreSQL URLs use the default engine options; any other URL additionally
# gets check_same_thread=False (presumably a SQLite fallback -- confirm).
engine = (
    create_engine(TEST_DATABASE_URL)
    if TEST_DATABASE_URL.startswith("postgresql")
    else create_engine(
        TEST_DATABASE_URL,
        connect_args={"check_same_thread": False},
    )
)

TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
def override_get_db():
    """Yield a session from ``TestingSessionLocal`` and close it afterwards.

    Registered below as the replacement for the app's ``get_db`` dependency.
    """
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Runs whether or not the consumer raised, so the session is closed.
        session.close()


# Route the get_db dependency to the test session factory.
fastapi_app.dependency_overrides[get_db] = override_get_db
# Override enforce_quota dependency
# FastAPI inspects the signature of the override itself. An unannotated
# `request` parameter is interpreted as a required query parameter, so it is
# annotated as Request to have the framework inject the request object.
from fastapi import Request  # noqa: E402


async def mock_enforce_quota(request: Request, api_key=None):
    """Return a fixed PRO-tier quota result without checking anything.

    Args:
        request: The incoming request, injected by FastAPI; unused.
        api_key: Unused; kept so the signature stays call-compatible.

    Returns:
        A dict with the keys ``api_key``, ``tier`` and ``remaining``.
    """
    return {"api_key": "test_key", "tier": Tier.PRO, "remaining": 1000}


fastapi_app.dependency_overrides[enforce_quota] = mock_enforce_quota
# NOTE(review): session scope and autouse are inferred from the docstring
# ("before any tests run"); confirm they match the intended lifetime.
@pytest.fixture(scope="session", autouse=True)
def setup_database():
    """Create tables before any tests run and drop them afterwards.

    Without the fixture decorator pytest never calls this generator, so the
    tables would not be created.
    """
    Base.metadata.create_all(bind=engine)
    try:
        yield
    finally:
        # Drop the schema even if the generator is closed abnormally.
        Base.metadata.drop_all(bind=engine)
@pytest.fixture
def client():
    """Yield a ``TestClient`` bound to the FastAPI app.

    The client is used as a context manager, so it is entered before the
    test receives it and exited when the test finishes.
    """
    with TestClient(fastapi_app) as test_client:
        yield test_client
@pytest.fixture
def db_session():
    """Provide a clean database session for each test.

    Tables are created before the session is handed out. Afterwards the
    session is rolled back and closed, and all tables are dropped.

    NOTE(review): dropping every table here also removes tables created by
    ``setup_database``; confirm that later tests which do not request this
    fixture do not depend on them.
    """
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Nested so that a failing rollback cannot skip close() and the
        # table teardown.
        try:
            session.rollback()
        finally:
            session.close()
            Base.metadata.drop_all(bind=engine)