| | """ |
| | Test the parquet module. |
| | |
| | Mostly auto-generated by Cursor + GPT-5. |
| | """ |
| |
|
| | import os |
| | import tempfile |
| | from typing import Any |
| |
|
| | import pandas as pd |
| | import pytest |
| | from sqlalchemy import create_engine, text |
| | from sqlalchemy.engine import Engine |
| | from sqlmodel import Field, Session, SQLModel |
| |
|
| | from parquet import export_to_parquet, import_from_parquet |
| |
|
| |
|
| | |
class DummyUser(SQLModel, table=True):
    """Small user table that exists only to exercise parquet export/import."""

    # Field order is significant: test_export_import_cycle reads the result
    # of ``SELECT *`` by position (id, name, email, age).
    id: int = Field(primary_key=True)
    name: str = Field(max_length=100)
    email: str = Field(max_length=255)
    age: int = Field()
| |
|
| |
|
class DummyProduct(SQLModel, table=True):
    """Small product table that exists only to exercise parquet export/import."""

    # Field order is significant: test_export_import_cycle reads the result
    # of ``SELECT *`` by position (id, name, price, category).
    id: int = Field(primary_key=True)
    name: str = Field(max_length=200)
    price: float = Field()
    category: str = Field(max_length=100)
| |
|
| |
|
@pytest.fixture
def temp_db_engine():
    """Yield an engine bound to a throwaway on-disk SQLite database.

    An empty ``.db`` file is created, every table registered on
    ``SQLModel.metadata`` is created in it, and the engine is yielded to the
    test. Cleanup disposes the engine and deletes the file; it also runs when
    schema creation fails before the test starts, so no temp file is leaked.
    """
    # Only the unique path is needed. The handle is closed immediately so
    # SQLite opens the file itself (some platforms refuse a second open of a
    # file that is still held open).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as handle:
        db_path = handle.name

    engine = None
    try:
        engine = create_engine(f"sqlite:///{db_path}")
        SQLModel.metadata.create_all(engine)
        yield engine
    finally:
        # Release pooled connections before removing the file underneath them.
        if engine is not None:
            engine.dispose()
        os.unlink(db_path)
| |
|
| |
|
@pytest.fixture
def sample_data():
    """Return the reference rows for both tables, keyed by "users" and "products"."""
    user_columns = ("id", "name", "email", "age")
    user_rows = (
        (1, "Alice", "alice@example.com", 30),
        (2, "Bob", "bob@example.com", 25),
        (3, "Charlie", "charlie@example.com", 35),
    )

    product_columns = ("id", "name", "price", "category")
    product_rows = (
        (1, "Laptop", 999.99, "Electronics"),
        (2, "Book", 19.99, "Education"),
        (3, "Coffee", 4.99, "Food"),
    )

    # Each row becomes a plain dict so it can be splatted into a model.
    return {
        "users": [dict(zip(user_columns, row)) for row in user_rows],
        "products": [dict(zip(product_columns, row)) for row in product_rows],
    }
| |
|
| |
|
@pytest.fixture
def populated_db(temp_db_engine: Engine, sample_data: dict[str, list[dict[str, Any]]]):
    """Insert the sample users and products, then return the same engine."""
    user_models = [DummyUser(**fields) for fields in sample_data["users"]]
    product_models = [DummyProduct(**fields) for fields in sample_data["products"]]

    with Session(temp_db_engine) as session:
        # Users are queued before products, then everything is committed once.
        session.add_all(user_models)
        session.add_all(product_models)
        session.commit()

    return temp_db_engine
| |
|
| |
|
def test_export_to_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Export a populated database and inspect the parquet files it produces."""
    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(populated_db, backup_dir)

        users_path = os.path.join(backup_dir, "dummyuser.parquet")
        products_path = os.path.join(backup_dir, "dummyproduct.parquet")

        # One parquet file per table must exist.
        assert os.path.exists(users_path)
        assert os.path.exists(products_path)

        users_df = pd.read_parquet(users_path)
        products_df = pd.read_parquet(products_path)

        # Row counts match what the fixture inserted.
        assert len(users_df) == len(sample_data["users"])
        assert len(products_df) == len(sample_data["products"])

        # The exported rows must already be in sorted order: sorting by every
        # column and renumbering the index has to leave each frame unchanged.
        for frame in (users_df, products_df):
            resorted = frame.sort_values(by=list(frame.columns)).reset_index(drop=True)
            assert frame.equals(resorted)
| |
|
| |
|
def test_export_to_parquet_empty_table(temp_db_engine: Engine):
    """Export a database whose tables hold no rows; files are still expected."""
    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(temp_db_engine, backup_dir)

        # An empty table must still produce its parquet file.
        for filename in ("dummyuser.parquet", "dummyproduct.parquet"):
            assert os.path.exists(os.path.join(backup_dir, filename))
| |
|
| |
|
def test_export_to_parquet_creates_directory(populated_db: Engine):
    """Test that export creates the backup directory if it doesn't exist.

    The target lives inside a fresh ``TemporaryDirectory``, so it is
    guaranteed to be absent before the export, cannot collide with leftovers
    from earlier runs or with parallel test workers, and is removed together
    with its parent when the test ends.
    """
    with tempfile.TemporaryDirectory() as parent_dir:
        backup_dir = os.path.join(parent_dir, "test_backup_dir")
        # Precondition: if the directory already existed, the assertions
        # below would say nothing about export_to_parquet.
        assert not os.path.exists(backup_dir)

        export_to_parquet(populated_db, backup_dir)

        assert os.path.exists(backup_dir)
        assert os.path.isdir(backup_dir)
| |
|
| |
|
def test_import_from_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Export, wipe both tables, import, and check the row counts come back."""

    def count_rows() -> tuple[int, int]:
        """Return (user count, product count) read through a fresh session."""
        with Session(populated_db) as session:
            user_row = session.exec(text("SELECT COUNT(*) FROM dummyuser")).first()
            product_row = session.exec(text("SELECT COUNT(*) FROM dummyproduct")).first()
            return user_row[0], product_row[0]

    with tempfile.TemporaryDirectory() as backup_dir:
        # Snapshot the populated tables first.
        export_to_parquet(populated_db, backup_dir)

        # Empty both tables so the import has something to restore.
        with Session(populated_db) as session:
            for statement in ("DELETE FROM dummyuser", "DELETE FROM dummyproduct"):
                session.exec(text(statement))
            session.commit()

        assert count_rows() == (0, 0)

        import_from_parquet(populated_db, backup_dir)

        expected_counts = (len(sample_data["users"]), len(sample_data["products"]))
        assert count_rows() == expected_counts
| |
|
| |
|
def test_import_from_parquet_missing_file(populated_db: Engine):
    """Import from a directory that contains no parquet files at all."""
    with tempfile.TemporaryDirectory() as empty_dir:
        # No assertions: the test passes as long as the call returns normally.
        import_from_parquet(populated_db, empty_dir)
| | |
| |
|
| |
|
def test_import_from_parquet_clears_existing_data(populated_db: Engine):
    """Check that a row changed after the export is restored by the import."""

    def name_of_first_user():
        """Read the name stored for the user with id 1 via a fresh session."""
        with Session(populated_db) as session:
            row = session.exec(
                text("SELECT name FROM dummyuser WHERE id = 1")
            ).first()
            return row[0]

    with tempfile.TemporaryDirectory() as backup_dir:
        # Snapshot the original data.
        export_to_parquet(populated_db, backup_dir)

        # Diverge from the snapshot.
        with Session(populated_db) as session:
            session.exec(text("UPDATE dummyuser SET name = 'Modified' WHERE id = 1"))
            session.commit()

        assert name_of_first_user() == "Modified"

        import_from_parquet(populated_db, backup_dir)

        # The snapshot value must have replaced the modified one.
        assert name_of_first_user() == "Alice"
| |
|
| |
|
def test_export_import_cycle(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Export, wipe, import, then compare every restored value to the sample data."""
    # Column names in the positional order the ``SELECT *`` rows are read in.
    user_fields = ("id", "name", "email", "age")
    product_fields = ("id", "name", "price", "category")

    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(populated_db, backup_dir)

        # Remove everything so only imported rows can satisfy the checks.
        with Session(populated_db) as session:
            for statement in ("DELETE FROM dummyuser", "DELETE FROM dummyproduct"):
                session.exec(text(statement))
            session.commit()

        import_from_parquet(populated_db, backup_dir)

        with Session(populated_db) as session:
            users_result = session.exec(
                text("SELECT * FROM dummyuser ORDER BY id")
            ).fetchall()
            expected_users = sample_data["users"]
            assert len(users_result) == len(expected_users)

            # Lengths are equal at this point, so zip pairs every row.
            for row, expected in zip(users_result, expected_users):
                for position, field in enumerate(user_fields):
                    assert row[position] == expected[field]

            products_result = session.exec(
                text("SELECT * FROM dummyproduct ORDER BY id")
            ).fetchall()
            expected_products = sample_data["products"]
            assert len(products_result) == len(expected_products)

            for row, expected in zip(products_result, expected_products):
                for position, field in enumerate(product_fields):
                    assert row[position] == expected[field]
| |
|