# petter2025's picture
# Upload folder using huggingface_hub (#3)
# 6d20eab
"""
pytest configuration and fixtures for ARF API tests.
"""
import os

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# ===== STEP 1: Set environment variables BEFORE any app imports =====
# NOTE: these assignments must stay above every ``app.*`` import below.
# Any application module that reads its configuration at import time would
# otherwise see the pre-test environment.
os.environ["ARF_USAGE_TRACKING"] = "false"
# Force the correct database URL for tests
os.environ["DATABASE_URL"] = "postgresql://postgres:postgres@localhost:5432/testdb"
os.environ["TEST_DATABASE_URL"] = "postgresql://postgres:postgres@localhost:5432/testdb"
# Additional PostgreSQL environment variables to prevent fallback to
# system user
os.environ["PGUSER"] = "postgres"
os.environ["PGPASSWORD"] = "postgres"
os.environ["PGHOST"] = "localhost"
os.environ["PGPORT"] = "5432"
os.environ["PGDATABASE"] = "testdb"

# ===== STEP 2: Mock the tracker module BEFORE importing app =====
# Only the tracker module is imported here; ``app.main`` is imported in
# STEP 3, after the module-level ``tracker`` has been replaced.
import app.core.usage_tracker  # noqa: E402
from app.core.usage_tracker import enforce_quota, Tier  # noqa: E402


class MockTracker:
    """In-memory stand-in installed as ``app.core.usage_tracker.tracker``.

    Every method returns a fixed "success" value and touches no storage.
    The method names presumably mirror the real tracker's interface --
    verify against ``app.core.usage_tracker`` when that interface changes.
    """

    def get_tier(self, api_key):
        """Return ``Tier.PRO`` for every API key."""
        return Tier.PRO

    def get_remaining_quota(self, api_key, tier):
        """Return a fixed remaining quota of 1000."""
        return 1000

    def consume_quota_and_log(self, record, idempotency_key=None):
        """Return ``(True, None)`` without recording anything."""
        return (True, None)

    def increment_usage_sync(self, record, idempotency_key=None):
        """Return ``True`` without recording anything."""
        return True

    def get_or_create_api_key(self, key, tier):
        """Return ``True`` without creating anything."""
        return True

    def update_api_key_tier(self, key, tier):
        """Return ``True`` without updating anything."""
        return True

    def _insert_audit_log(self, record):
        """Do nothing."""
        pass


# Replace the tracker at the module level
app.core.usage_tracker.tracker = MockTracker()

# ===== STEP 3: Import app and database modules =====
# Force model registration (prevents "no such table" errors)
from app.api.deps import get_db  # noqa: E402
from app.database.base import Base  # noqa: E402
from app.main import app as fastapi_app  # noqa: E402

# Use the environment variable for the database URL (already set)
TEST_DATABASE_URL = os.getenv(
    "TEST_DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/testdb",
)

# PostgreSQL URLs use the default engine options; any other URL is given
# ``check_same_thread=False`` as a connect argument (presumably a SQLite
# fallback -- confirm before relying on it).
engine = (
    create_engine(TEST_DATABASE_URL)
    if TEST_DATABASE_URL.startswith("postgresql")
    else create_engine(
        TEST_DATABASE_URL,
        connect_args={"check_same_thread": False},
    )
)

# Session factory bound to the test engine; sessions neither autocommit
# nor autoflush.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
def override_get_db():
    """Yield a session from ``TestingSessionLocal`` and always close it.

    Registered below as the replacement for the application's ``get_db``
    dependency.
    """
    test_session = TestingSessionLocal()
    try:
        yield test_session
    finally:
        # Runs whether or not the consumer of the session raised.
        test_session.close()


fastapi_app.dependency_overrides[get_db] = override_get_db
# Override enforce_quota dependency
from fastapi import Request  # noqa: E402


async def mock_enforce_quota(request: Request, api_key=None):
    """Dependency override for ``enforce_quota`` that always grants access.

    ``request`` must be annotated as ``Request``: FastAPI resolves an
    override's parameters from its own signature, and an unannotated
    parameter without a default is treated as a required query parameter,
    which would make every route using this dependency answer 422.

    Args:
        request: The incoming request, injected by FastAPI. Unused.
        api_key: Unused; kept so the signature stays call-compatible
            (presumably with the real ``enforce_quota`` -- verify).

    Returns:
        A dict with the fixed keys ``api_key``, ``tier`` and ``remaining``.
    """
    return {"api_key": "test_key", "tier": Tier.PRO, "remaining": 1000}


fastapi_app.dependency_overrides[enforce_quota] = mock_enforce_quota
@pytest.fixture(scope="session", autouse=True)
def setup_database():
    """Create every table on ``Base.metadata`` once per test session.

    The tables are dropped again when the session finishes.
    """
    schema = Base.metadata
    schema.create_all(bind=engine)
    yield
    # Teardown: remove everything that was created above.
    schema.drop_all(bind=engine)
@pytest.fixture(scope="session")
def client():
    """Yield one ``TestClient`` wrapping the FastAPI app for the session.

    The client is used as a context manager, so it is entered before the
    first test that requests it and exited when the session ends.
    """
    with TestClient(fastapi_app) as api_client:
        yield api_client
@pytest.fixture(scope="function")
def db_session():
    """Provide a clean database session for each test.

    Tables are created (if missing) before the test. Afterwards the session
    is rolled back and closed, and the schema is dropped and recreated.
    Recreating matters because the session-scoped ``setup_database``
    fixture creates the tables only once: leaving them dropped would break
    every later test that does not request ``db_session`` itself.

    Yields:
        A session produced by ``TestingSessionLocal``.
    """
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        try:
            session.rollback()
        finally:
            # Close even if the rollback itself fails.
            session.close()
        # Wipe all data, then restore the empty schema for later tests.
        Base.metadata.drop_all(bind=engine)
        Base.metadata.create_all(bind=engine)