"""Tests for ``app.core.usage_tracker.UsageTracker``.

Each test receives a fresh tracker whose database file lives in its own
pytest-managed temporary directory, so tests cannot see each other's data.
"""

import tempfile
import time

import pytest

from app.core.usage_tracker import Tier, UsageRecord, UsageTracker


def _make_record(api_key, tier=Tier.FREE, endpoint="/test", **extra):
    """Build a ``UsageRecord`` for *api_key* stamped with the current time.

    Any extra keyword arguments (e.g. ``request_body``, ``response``) are
    forwarded unchanged to ``UsageRecord``.
    """
    return UsageRecord(
        api_key=api_key,
        tier=tier,
        timestamp=time.time(),
        endpoint=endpoint,
        **extra,
    )


@pytest.fixture
def tracker(tmp_path):
    """Return a ``UsageTracker`` backed by a database file in a per-test dir.

    Why ``tmp_path`` rather than ``tempfile.NamedTemporaryFile``: the latter
    keeps the file open for the whole test, and on Windows a still-open named
    temporary file cannot be opened a second time. That would block the
    tracker's own handle on the path (presumably a SQLite connection, given
    the ``.db`` suffix). The latter also pre-creates an empty file rather than
    letting the tracker create the database. pytest removes ``tmp_path`` on
    its own schedule, so no teardown is needed here.
    """
    return UsageTracker(db_path=str(tmp_path / "usage.db"))


def test_get_or_create_api_key(tracker):
    """Creating a key stores its tier; repeating the call still succeeds."""
    assert tracker.get_or_create_api_key("test_key", Tier.FREE) is True
    assert tracker.get_tier("test_key") == Tier.FREE

    # Second call on an existing key should return True without error.
    assert tracker.get_or_create_api_key("test_key") is True


def test_update_api_key_tier(tracker):
    """Updating an existing key changes its tier; unknown keys report False."""
    tracker.get_or_create_api_key("test_key", Tier.FREE)

    assert tracker.update_api_key_tier("test_key", Tier.PRO) is True
    assert tracker.get_tier("test_key") == Tier.PRO

    # Non-existent key
    assert tracker.update_api_key_tier("nonexistent", Tier.PRO) is False


def test_get_remaining_quota_free(tracker):
    """A FREE key starts at 1000 remaining and loses one per recorded use."""
    tracker.get_or_create_api_key("free_key", Tier.FREE)

    # Initially 1000 remaining
    assert tracker.get_remaining_quota("free_key", Tier.FREE) == 1000

    # Simulate usage using the atomic method
    tracker.increment_usage_sync(_make_record("free_key"))

    assert tracker.get_remaining_quota("free_key", Tier.FREE) == 999


def test_get_remaining_quota_enterprise(tracker):
    """ENTERPRISE keys report no quota limit (``None``)."""
    tracker.get_or_create_api_key("ent_key", Tier.ENTERPRISE)

    assert tracker.get_remaining_quota("ent_key", Tier.ENTERPRISE) is None


def test_increment_usage_sync(tracker):
    """Recording usage succeeds and decrements the remaining quota by one."""
    tracker.get_or_create_api_key("test_key", Tier.FREE)

    assert tracker.increment_usage_sync(_make_record("test_key")) is True

    # Check quota decreased
    assert tracker.get_remaining_quota("test_key", Tier.FREE) == 999


def test_get_audit_logs(tracker):
    """A recorded request shows up in the key's audit log with its endpoint."""
    tracker.get_or_create_api_key("test_key", Tier.FREE)
    tracker.increment_usage_sync(
        _make_record(
            "test_key",
            request_body={"foo": "bar"},
            response={"status": "ok"},
        )
    )

    logs = tracker.get_audit_logs("test_key", limit=10)

    assert len(logs) == 1
    assert logs[0]["endpoint"] == "/test"