| | """ |
| | Unit tests for database module |
| | Comprehensive test coverage for database operations |
| | """ |
| |
|
| | import pytest |
| | import sqlite3 |
| | import tempfile |
| | import os |
| | from datetime import datetime |
| | from pathlib import Path |
| |
|
| | |
| | import sys |
| | sys.path.insert(0, str(Path(__file__).parent.parent)) |
| |
|
| | from database import db_manager |
| | from database.migrations import MigrationManager, auto_migrate |
| |
|
| |
|
@pytest.fixture
def temp_db():
    """Yield the path of a throwaway ``.db`` file and remove it afterwards."""
    handle, db_path = tempfile.mkstemp(suffix='.db')
    # mkstemp hands back an open OS-level descriptor; only the path is used.
    os.close(handle)

    yield db_path

    # Teardown: nothing to do if the test already removed the file.
    if not os.path.exists(db_path):
        return
    os.unlink(db_path)
| |
|
| |
|
@pytest.fixture
def db_instance(temp_db):
    """Provide a ``CryptoDatabase`` opened on the temporary database path."""
    # Imported lazily, inside the fixture, exactly as the tests themselves do.
    from database import CryptoDatabase

    return CryptoDatabase(temp_db)
| |
|
| |
|
class TestDatabaseInitialization:
    """Test database initialization and schema creation."""

    @staticmethod
    def _schema_object_names(db_path, object_type):
        """Return the names of all ``sqlite_master`` rows of *object_type*.

        Args:
            db_path: Path of the SQLite file to inspect.
            object_type: Value matched against ``sqlite_master.type``
                (``'table'`` or ``'index'`` in this class).

        Returns:
            A set holding the ``name`` column of every matching row.
        """
        conn = sqlite3.connect(db_path)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type = ?",
                (object_type,),
            ).fetchall()
        finally:
            # Close on the error path too, so a failing query does not leak
            # the connection.
            conn.close()
        return {row[0] for row in rows}

    def test_database_creation(self, temp_db):
        """Test that constructing the database leaves a non-empty file."""
        from database import CryptoDatabase

        # The instance is kept referenced for the duration of the checks.
        db = CryptoDatabase(temp_db)  # noqa: F841

        assert os.path.exists(temp_db)
        assert os.path.getsize(temp_db) > 0

    def test_tables_created(self, db_instance):
        """Test that all required tables are created."""
        tables = self._schema_object_names(db_instance.db_path, 'table')

        required_tables = {'prices', 'news', 'market_analysis', 'user_queries'}
        missing = required_tables - tables
        assert not missing, f"missing tables: {sorted(missing)}"

    def test_indices_created(self, db_instance):
        """Test that at least one index exists in the schema."""
        indices = self._schema_object_names(db_instance.db_path, 'index')

        assert len(indices) > 0, "no indices found in sqlite_master"
| |
|
| |
|
class TestPriceOperations:
    """Exercise saving and reading back price records."""

    def test_save_price(self, db_instance):
        """A fully populated price record is stored and reported as saved."""
        bitcoin = dict(
            symbol='BTC',
            name='Bitcoin',
            price_usd=50000.0,
            volume_24h=1000000000,
            market_cap=950000000000,
            percent_change_1h=0.5,
            percent_change_24h=2.3,
            percent_change_7d=-1.2,
            rank=1,
        )

        saved = db_instance.save_price(bitcoin)
        assert saved is True

    def test_get_latest_prices(self, db_instance):
        """``get_latest_prices`` honours ``limit`` and starts at rank 1."""
        # Seed ten distinct coins with ranks 1..10.
        for idx in range(10):
            db_instance.save_price({
                'symbol': f'TEST{idx}',
                'name': f'Test Coin {idx}',
                'price_usd': 100.0 * (idx + 1),
                'volume_24h': 1000000,
                'market_cap': 10000000,
                'rank': idx + 1,
            })

        latest = db_instance.get_latest_prices(limit=5)

        assert len(latest) == 5
        first_row = latest[0]
        assert first_row['rank'] == 1

    def test_get_historical_prices(self, db_instance):
        """Historical lookup for BTC returns only BTC rows."""
        # Seed five BTC snapshots with increasing prices.
        for step in range(5):
            db_instance.save_price({
                'symbol': 'BTC',
                'name': 'Bitcoin',
                'price_usd': 50000.0 + (step * 100),
                'volume_24h': 1000000000,
                'market_cap': 950000000000,
                'rank': 1,
            })

        history = db_instance.get_historical_prices('BTC', days=7)

        assert len(history) > 0
        symbols_match = [row['symbol'] == 'BTC' for row in history]
        assert all(symbols_match)
| |
|
| |
|
class TestNewsOperations:
    """Exercise saving and reading back news articles."""

    def test_save_news(self, db_instance):
        """An article with sentiment data is stored and reported as saved."""
        article = dict(
            title='Test Article',
            summary='This is a test summary',
            url='https://example.com/test',
            source='Test Source',
            sentiment_score=0.8,
            sentiment_label='positive',
        )

        saved = db_instance.save_news(article)
        assert saved is True

    def test_duplicate_news_url(self, db_instance):
        """Saving the same URL twice succeeds once, then reports ``False``."""
        article = dict(
            title='Test Article',
            summary='Summary',
            url='https://example.com/unique',
            source='Test',
        )

        # First insert goes through.
        first_attempt = db_instance.save_news(article)
        assert first_attempt is True

        # Second insert of the identical record is refused.
        second_attempt = db_instance.save_news(article)
        assert second_attempt is False

    def test_get_latest_news(self, db_instance):
        """``get_latest_news`` honours ``limit`` and rows carry a title."""
        # Seed ten articles with distinct URLs.
        for idx in range(10):
            db_instance.save_news({
                'title': f'Article {idx}',
                'summary': f'Summary {idx}',
                'url': f'https://example.com/article{idx}',
                'source': 'Test Source',
            })

        latest = db_instance.get_latest_news(limit=5)

        assert len(latest) == 5
        has_title = ['title' in item for item in latest]
        assert all(has_title)
| |
|
| |
|
class TestAnalysisOperations:
    """Exercise saving and reading back market analysis records."""

    def test_save_analysis(self, db_instance):
        """A fully populated analysis record is reported as saved."""
        btc_analysis = dict(
            symbol='BTC',
            timeframe='24h',
            trend='bullish',
            support_level=45000.0,
            resistance_level=55000.0,
            prediction='Price likely to increase',
            confidence=0.75,
        )

        saved = db_instance.save_analysis(btc_analysis)
        assert saved is True

    def test_get_latest_analysis(self, db_instance):
        """The analysis read back for BTC matches what was stored."""
        # Store a single minimal record first.
        db_instance.save_analysis(dict(
            symbol='BTC',
            timeframe='24h',
            trend='bullish',
            confidence=0.8,
        ))

        latest = db_instance.get_latest_analysis('BTC')

        assert latest is not None
        assert latest['symbol'] == 'BTC'
        assert latest['trend'] == 'bullish'
| |
|
| |
|
class TestMigrations:
    """Exercise ``MigrationManager`` and ``auto_migrate``."""

    def test_migration_manager_init(self, temp_db):
        """A fresh manager has migrations registered and reports version 0."""
        mgr = MigrationManager(temp_db)

        registered = len(mgr.migrations)
        assert registered > 0
        assert mgr.get_current_version() == 0

    def test_apply_migration(self, temp_db):
        """Applying the first pending migration advances the version to it."""
        mgr = MigrationManager(temp_db)
        pending = mgr.get_pending_migrations()

        assert len(pending) > 0

        # Apply only the first pending migration.
        first = pending[0]
        applied_ok = mgr.apply_migration(first)
        assert applied_ok is True

        # The recorded version is read from the same first pending entry.
        assert mgr.get_current_version() == pending[0].version

    def test_migrate_to_latest(self, temp_db):
        """Migrating to latest succeeds and lands on the highest applied version."""
        mgr = MigrationManager(temp_db)
        outcome = mgr.migrate_to_latest()
        success, applied = outcome

        assert success is True
        assert len(applied) > 0
        assert mgr.get_current_version() == max(applied)

    def test_auto_migrate(self, temp_db):
        """``auto_migrate`` reports success on a fresh database file."""
        migrated = auto_migrate(temp_db)
        assert migrated is True
| |
|
| |
|
class TestDataValidation:
    """Placeholder checks for validation of price records."""

    def test_price_validation(self, db_instance):
        """Build a price record whose ``price_usd`` is negative.

        NOTE(review): the record is never handed to ``db_instance`` and
        nothing is asserted, so this test cannot fail. How invalid prices are
        meant to be handled is not visible from this file -- TODO confirm the
        contract and add the assertion.
        """
        invalid_price = dict(
            symbol='BTC',
            name='Bitcoin',
            price_usd=-100.0,
            rank=1,
        )

    def test_required_fields(self, db_instance):
        """Build a price record that carries only a symbol.

        NOTE(review): as with the negative-price case, the record is built
        but never saved or checked -- TODO confirm which fields are required
        and assert on the outcome.
        """
        incomplete_price = dict(symbol='BTC')
| |
|
| |
|
class TestConcurrency:
    """Exercise the database from several threads at once."""

    def test_concurrent_writes(self, db_instance):
        """Ten threads each save one price; ten rows are read back."""
        import threading

        # One record per worker, numbered 0..9.
        payloads = [
            {
                'symbol': f'TEST{n}',
                'name': f'Test {n}',
                'price_usd': float(n),
                'rank': n,
            }
            for n in range(10)
        ]

        workers = [
            threading.Thread(target=db_instance.save_price, args=(payload,))
            for payload in payloads
        ]

        for worker in workers:
            worker.start()

        for worker in workers:
            worker.join()

        # Every write is expected to be visible once all workers have joined.
        stored = db_instance.get_latest_prices(limit=10)
        assert len(stored) == 10
| |
|
| |
|
if __name__ == '__main__':
    # pytest.main() returns the session exit code. Propagate it so that
    # running this file directly reports test failures through the process
    # exit status instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, '-v']))
| |
|