File size: 3,361 Bytes
afa4de7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
pytest configuration and fixtures for ARF API tests.
"""

from app.core.usage_tracker import enforce_quota, Tier
from app.api.deps import get_db
from app.database.base import Base
from app.main import app as fastapi_app
from sqlalchemy.orm import sessionmaker
from sqlalchemy import create_engine
from fastapi.testclient import TestClient
import app.core.usage_tracker
import os
import pytest

# ===== STEP 1: Environment configuration for the test run =====
# NOTE(review): the `app.*` imports at the top of this file execute before
# these assignments, so anything the app reads from the environment at import
# time will not see them -- confirm the app reads these values lazily, or move
# this block above those imports.
os.environ.update(
    {
        # Set the usage-tracking flag to "false" for tests.
        "ARF_USAGE_TRACKING": "false",
        # Force the correct database URL for tests.
        "DATABASE_URL": "postgresql://postgres:postgres@localhost:5432/testdb",
        "TEST_DATABASE_URL": "postgresql://postgres:postgres@localhost:5432/testdb",
        # Additional PostgreSQL environment variables to prevent fallback to
        # the system user.
        "PGUSER": "postgres",
        "PGPASSWORD": "postgres",
        "PGHOST": "localhost",
        "PGPORT": "5432",
        "PGDATABASE": "testdb",
    }
)


# ===== STEP 2: Mock the tracker module BEFORE importing app =====
class MockTracker:
    """Stand-in for the usage tracker that answers with fixed values.

    Every method ignores its arguments and returns a constant, so no real
    quota bookkeeping happens while the tests run.
    """

    def get_tier(self, api_key):
        """Return ``Tier.PRO`` regardless of ``api_key``."""
        # Imported locally so that defining this class does not itself
        # require the app package to be importable.
        from app.core.usage_tracker import Tier

        return Tier.PRO

    def get_remaining_quota(self, api_key, tier):
        """Return a fixed remaining quota of 1000."""
        return 1000

    def consume_quota_and_log(self, record, idempotency_key=None):
        """Return the fixed pair ``(True, None)``."""
        return (True, None)

    def increment_usage_sync(self, record, idempotency_key=None):
        """Return ``True`` without recording anything."""
        return True

    def get_or_create_api_key(self, key, tier):
        """Return ``True`` without creating anything."""
        return True

    def update_api_key_tier(self, key, tier):
        """Return ``True`` without updating anything."""
        return True

    def _insert_audit_log(self, record):
        """Do nothing and return ``None``."""
        return None


# Swap the module-level tracker object for the mock defined above.
app.core.usage_tracker.tracker = MockTracker()

# ===== STEP 3: Test database engine and session factory =====
# TEST_DATABASE_URL is exported into the environment earlier in this file;
# the fallback value below only applies if it is unset.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql://postgres:postgres@localhost:5432/testdb",
)

if not TEST_DATABASE_URL.startswith("postgresql"):
    # Non-PostgreSQL URLs get `check_same_thread=False` (a sqlite3
    # connection option) passed through to the driver.
    engine = create_engine(
        TEST_DATABASE_URL,
        connect_args={"check_same_thread": False},
    )
else:
    engine = create_engine(TEST_DATABASE_URL)

# Session factory bound to the test engine; no autoflush, no autocommit.
TestingSessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)


def override_get_db():
    """Yield a session from ``TestingSessionLocal`` and close it afterwards.

    Registered below as the replacement for the application's ``get_db``
    dependency.
    """
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even if the consumer raised.
        session.close()


# Route the app's ``get_db`` dependency to the test session factory.
fastapi_app.dependency_overrides[get_db] = override_get_db

# Override enforce_quota dependency


async def mock_enforce_quota(request=None, api_key=None):
    """Return a fixed quota result instead of running the real check.

    Both parameters are accepted for signature compatibility and ignored.
    They default to ``None`` because FastAPI inspects the signature of an
    override: a parameter with no default and no ``Request`` annotation is
    treated as a required query parameter, which would make routes using
    this dependency reject requests that do not supply it.

    Returns:
        dict with keys ``"api_key"``, ``"tier"`` and ``"remaining"``.
    """
    return {"api_key": "test_key", "tier": Tier.PRO, "remaining": 1000}


fastapi_app.dependency_overrides[enforce_quota] = mock_enforce_quota


@pytest.fixture(scope="session", autouse=True)
def setup_database():
    """Create all tables once for the test session and drop them at the end."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    yield
    # Teardown: remove every table registered on the metadata.
    metadata.drop_all(bind=engine)


@pytest.fixture(scope="session")
def client():
    """Yield one ``TestClient`` wrapping the FastAPI app for the whole session."""
    # Used as a context manager so the client is closed when the session ends.
    with TestClient(fastapi_app) as http_client:
        yield http_client


@pytest.fixture(scope="function")
def db_session():
    """Provide a clean database session for each test.

    Ensures the tables exist, yields a session from ``TestingSessionLocal``,
    then rolls back and closes it. Afterwards all tables are dropped and
    recreated, so the next test starts from empty tables while the schema
    stays available to tests that do not request this fixture.
    """
    Base.metadata.create_all(bind=engine)
    session = TestingSessionLocal()
    try:
        yield session
    finally:
        try:
            session.rollback()
        finally:
            # Close even if the rollback itself fails.
            session.close()
        # Drop to discard any committed rows, then recreate: a bare drop
        # would also remove the schema set up once per session, leaving
        # later tests without tables.
        Base.metadata.drop_all(bind=engine)
        Base.metadata.create_all(bind=engine)