| import pytest |
| import tempfile |
| import time |
| from app.core.usage_tracker import UsageTracker, Tier, UsageRecord |
|
|
|
|
@pytest.fixture
def tracker(tmp_path):
    """Provide a ``UsageTracker`` backed by a fresh, empty database file.

    The file lives inside pytest's ``tmp_path`` directory, which is unique to
    each test, so no state is shared between tests.

    A closed file inside a temporary directory is used instead of an open
    ``tempfile.NamedTemporaryFile``: reopening a named temporary file by name
    while it is still open is platform-dependent (it is not possible on
    Windows), and the tracker opens the path itself.
    """
    db_file = tmp_path / "usage.db"
    # Pre-create an empty file so the tracker sees an existing, empty file,
    # without this fixture holding an open handle on it.
    db_file.touch()
    # str() keeps the argument a plain string path.
    yield UsageTracker(db_path=str(db_file))
|
|
|
|
def test_get_or_create_api_key(tracker):
    """Creating a key reports success, stores its tier, and can be repeated."""
    key = "test_key"

    created = tracker.get_or_create_api_key(key, Tier.FREE)
    assert created is True
    assert tracker.get_tier(key) == Tier.FREE

    # A second call for the same key, without a tier, must also return True.
    fetched_again = tracker.get_or_create_api_key(key)
    assert fetched_again is True
|
|
|
|
def test_update_api_key_tier(tracker):
    """Tier updates return True for a known key and False for an unknown one."""
    key = "test_key"
    tracker.get_or_create_api_key(key, Tier.FREE)

    upgraded = tracker.update_api_key_tier(key, Tier.PRO)
    assert upgraded is True
    assert tracker.get_tier(key) == Tier.PRO

    # A key that was never created must not be updatable.
    unknown_result = tracker.update_api_key_tier("nonexistent", Tier.PRO)
    assert unknown_result is False
|
|
|
|
def test_get_remaining_quota_free(tracker):
    """A FREE key reports 1000 remaining, then 999 after one recorded usage."""
    key, tier = "free_key", Tier.FREE
    tracker.get_or_create_api_key(key, tier)

    # Nothing has been recorded yet, so the full quota is expected.
    assert tracker.get_remaining_quota(key, tier) == 1000

    usage = UsageRecord(
        api_key=key, tier=tier, timestamp=time.time(), endpoint="/test"
    )
    tracker.increment_usage_sync(usage)

    assert tracker.get_remaining_quota(key, tier) == 999
|
|
|
|
def test_get_remaining_quota_enterprise(tracker):
    """An ENTERPRISE key reports ``None`` as its remaining quota."""
    key = "ent_key"
    tracker.get_or_create_api_key(key, Tier.ENTERPRISE)

    assert tracker.get_remaining_quota(key, Tier.ENTERPRISE) is None
|
|
|
|
def test_increment_usage_sync(tracker):
    """Recording one usage returns True and leaves 999 of the FREE quota."""
    key, tier = "test_key", Tier.FREE
    tracker.get_or_create_api_key(key, tier)

    usage = UsageRecord(
        api_key=key, tier=tier, timestamp=time.time(), endpoint="/test"
    )
    outcome = tracker.increment_usage_sync(usage)
    assert outcome is True

    assert tracker.get_remaining_quota(key, tier) == 999
|
|
|
|
def test_get_audit_logs(tracker):
    """One recorded usage yields exactly one audit-log entry with its endpoint."""
    key = "test_key"
    tracker.get_or_create_api_key(key, Tier.FREE)

    # This record carries request and response payloads alongside the endpoint.
    usage = UsageRecord(
        api_key=key,
        tier=Tier.FREE,
        timestamp=time.time(),
        endpoint="/test",
        request_body={"foo": "bar"},
        response={"status": "ok"},
    )
    tracker.increment_usage_sync(usage)

    entries = tracker.get_audit_logs(key, limit=10)
    assert len(entries) == 1
    first_entry = entries[0]
    assert first_entry["endpoint"] == "/test"
|
|