Spaces:
Sleeping
Sleeping
| """ | |
| Test the parquet module. | |
| Mostly auto-generated by Cursor + GPT-5. | |
| """ | |
| import os | |
| import tempfile | |
| from typing import Any | |
| import pandas as pd | |
| import pytest | |
| from sqlalchemy import create_engine, text | |
| from sqlalchemy.engine import Engine | |
| from sqlmodel import Field, Session, SQLModel | |
| from parquet import export_to_parquet, import_from_parquet | |
# Test models used to create temporary tables in the fixture database.
class DummyUser(SQLModel, table=True):
    """User rows that the tests below export to and import from parquet.

    The tests refer to this table as ``dummyuser``, both in raw SQL and as the
    ``dummyuser.parquet`` filename. That name presumably comes from SQLModel's
    default (lowercased class name), so confirm before renaming the class.
    """

    # NOTE(review): declaration order presumably fixes the physical column
    # order (id, name, email, age) returned by ``SELECT *``; keep it stable.
    # Primary key; the sample rows use ids 1..3.
    id: int = Field(primary_key=True)
    name: str = Field(max_length=100)
    email: str = Field(max_length=255)
    age: int = Field()
class DummyProduct(SQLModel, table=True):
    """Product rows that the tests below export to and import from parquet.

    The tests refer to this table as ``dummyproduct``, both in raw SQL and as
    the ``dummyproduct.parquet`` filename. That name presumably comes from
    SQLModel's default (lowercased class name), so confirm before renaming.
    """

    # NOTE(review): declaration order presumably fixes the physical column
    # order (id, name, price, category) returned by ``SELECT *``; keep it stable.
    id: int = Field(primary_key=True)
    name: str = Field(max_length=200)
    # Floating-point price; the round-trip test compares it for exact equality.
    price: float = Field()
    category: str = Field(max_length=100)
@pytest.fixture
def temp_db_engine():
    """Yield an engine bound to a fresh, file-backed SQLite database.

    Every table registered on ``SQLModel.metadata`` (including ``DummyUser``
    and ``DummyProduct``) is created before the engine is handed to the test.
    On teardown the engine is disposed and the database file is deleted. The
    file is also deleted if engine creation or table creation fails partway.
    """
    # delete=False plus an immediate close(): SQLite opens the path itself, and
    # the file is removed explicitly during teardown below.
    temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
    temp_db.close()
    try:
        engine = create_engine(f"sqlite:///{temp_db.name}")
        try:
            SQLModel.metadata.create_all(engine)
            yield engine
        finally:
            # Release pooled connections before the file is removed.
            engine.dispose()
    finally:
        os.unlink(temp_db.name)
@pytest.fixture
def sample_data() -> dict[str, list[dict[str, Any]]]:
    """Return the rows that ``populated_db`` inserts, grouped by table.

    ``"users"`` rows carry the ``DummyUser`` fields and ``"products"`` rows
    carry the ``DummyProduct`` fields. Both lists are in ascending ``id``
    order, and the tests compare them index-by-index against rows read back
    with ``ORDER BY id``. A new structure is built on every call.
    """
    users_data = [
        {"id": 1, "name": "Alice", "email": "alice@example.com", "age": 30},
        {"id": 2, "name": "Bob", "email": "bob@example.com", "age": 25},
        {"id": 3, "name": "Charlie", "email": "charlie@example.com", "age": 35},
    ]
    products_data = [
        {"id": 1, "name": "Laptop", "price": 999.99, "category": "Electronics"},
        {"id": 2, "name": "Book", "price": 19.99, "category": "Education"},
        {"id": 3, "name": "Coffee", "price": 4.99, "category": "Food"},
    ]
    return {"users": users_data, "products": products_data}
@pytest.fixture
def populated_db(
    temp_db_engine: Engine, sample_data: dict[str, list[dict[str, Any]]]
) -> Engine:
    """Insert ``sample_data`` into the temporary database and return its engine.

    All users and products are added through the ORM models in a single
    session and committed together. The returned engine is the same object
    as ``temp_db_engine``, so its teardown still applies.
    """
    with Session(temp_db_engine) as session:
        session.add_all(DummyUser(**row) for row in sample_data["users"])
        session.add_all(DummyProduct(**row) for row in sample_data["products"])
        session.commit()
    return temp_db_engine
def test_export_to_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Export writes one parquet file per table, with all rows, fully sorted."""
    with tempfile.TemporaryDirectory() as out_dir:
        export_to_parquet(populated_db, out_dir)

        user_path = os.path.join(out_dir, "dummyuser.parquet")
        product_path = os.path.join(out_dir, "dummyproduct.parquet")

        # Each table must have produced its own file.
        for path in (user_path, product_path):
            assert os.path.exists(path)

        checks = (
            (pd.read_parquet(user_path), sample_data["users"]),
            (pd.read_parquet(product_path), sample_data["products"]),
        )

        # Every exported row made it into the file.
        for frame, expected_rows in checks:
            assert len(frame) == len(expected_rows)

        # Rows are stored already sorted by all columns: re-sorting is a no-op.
        for frame, _ in checks:
            resorted = frame.sort_values(by=list(frame.columns))
            assert frame.equals(resorted.reset_index(drop=True))
def test_export_to_parquet_empty_table(temp_db_engine: Engine):
    """Export of tables that hold no rows still writes a file per table."""
    with tempfile.TemporaryDirectory() as out_dir:
        export_to_parquet(temp_db_engine, out_dir)
        # Empty tables are not skipped: each is expected to get its own file.
        for filename in ("dummyuser.parquet", "dummyproduct.parquet"):
            assert os.path.exists(os.path.join(out_dir, filename))
def test_export_to_parquet_creates_directory(populated_db: Engine):
    """Test that export creates the backup directory if it doesn't exist."""
    # Build the target inside a fresh, private parent directory. A fixed path
    # under the shared temp dir could survive from an earlier (or concurrent)
    # run, letting this pass without export_to_parquet creating anything, and
    # its cleanup could delete another run's directory.
    with tempfile.TemporaryDirectory() as parent_dir:
        backup_dir = os.path.join(parent_dir, "test_backup_dir")
        assert not os.path.exists(backup_dir)

        export_to_parquet(populated_db, backup_dir)

        # isdir() also implies the path exists.
        assert os.path.isdir(backup_dir)
def test_import_from_parquet_success(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Import repopulates tables that were emptied after an export."""

    def count_rows() -> tuple[int, int]:
        # A new session per call, so each count sees the latest commit.
        with Session(populated_db) as session:
            user_row = session.exec(text("SELECT COUNT(*) FROM dummyuser")).first()
            product_row = session.exec(
                text("SELECT COUNT(*) FROM dummyproduct")
            ).first()
        return user_row[0], product_row[0]

    with tempfile.TemporaryDirectory() as backup_dir:
        export_to_parquet(populated_db, backup_dir)

        # Wipe both tables so the import has to restore every row.
        with Session(populated_db) as session:
            for statement in ("DELETE FROM dummyuser", "DELETE FROM dummyproduct"):
                session.exec(text(statement))
            session.commit()
        assert count_rows() == (0, 0)

        import_from_parquet(populated_db, backup_dir)

        expected = (len(sample_data["users"]), len(sample_data["products"]))
        assert count_rows() == expected
def test_import_from_parquet_missing_file(populated_db: Engine):
    """Importing from a directory without parquet files must not raise."""
    with tempfile.TemporaryDirectory() as empty_dir:
        # Nothing was exported into empty_dir, so every table's file is
        # absent; the import is expected to skip those tables quietly.
        import_from_parquet(populated_db, empty_dir)
def test_import_from_parquet_clears_existing_data(populated_db: Engine):
    """Test that import clears existing data before inserting new data.

    After the export, two kinds of drift are introduced:

    - an exported row is modified;
    - a row that was never exported is inserted.

    A restore that really clears the table first must undo both. A restore
    that merely overwrites rows by primary key would leave the extra row
    behind, which the original single-row check could not detect.
    """

    def snapshot() -> tuple[str, int]:
        # Return (name of user 1, number of rows with the extra id).
        with Session(populated_db) as session:
            name_row = session.exec(
                text("SELECT name FROM dummyuser WHERE id = 1")
            ).first()
            extra_row = session.exec(
                text("SELECT COUNT(*) FROM dummyuser WHERE id = 99")
            ).first()
        return name_row[0], extra_row[0]

    with tempfile.TemporaryDirectory() as temp_dir:
        export_to_parquet(populated_db, temp_dir)

        # Drift the database away from the exported snapshot.
        with Session(populated_db) as session:
            session.exec(text("UPDATE dummyuser SET name = 'Modified' WHERE id = 1"))
            session.exec(
                text(
                    "INSERT INTO dummyuser (id, name, email, age) "
                    "VALUES (99, 'Extra', 'extra@example.com', 50)"
                )
            )
            session.commit()
        assert snapshot() == ("Modified", 1)

        # Import should clear the table and restore exactly the exported rows.
        import_from_parquet(populated_db, temp_dir)
        assert snapshot() == ("Alice", 0)
def test_export_import_cycle(
    populated_db: Engine, sample_data: dict[str, list[dict[str, Any]]]
):
    """Test complete export and import cycle maintains data integrity.

    Columns are selected by name rather than with ``SELECT *``. This keeps the
    comparison from depending on the physical column order of the restored
    tables, which could change if an import recreated them.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        export_to_parquet(populated_db, temp_dir)

        # Empty both tables so every compared row must come from the import.
        with Session(populated_db) as session:
            session.exec(text("DELETE FROM dummyuser"))
            session.exec(text("DELETE FROM dummyproduct"))
            session.commit()

        import_from_parquet(populated_db, temp_dir)

        with Session(populated_db) as session:
            users_result = session.exec(
                text("SELECT id, name, email, age FROM dummyuser ORDER BY id")
            ).fetchall()
            products_result = session.exec(
                text("SELECT id, name, price, category FROM dummyproduct ORDER BY id")
            ).fetchall()

        # sample_data lists rows in ascending id order, matching ORDER BY id.
        expected_users = [
            (u["id"], u["name"], u["email"], u["age"]) for u in sample_data["users"]
        ]
        expected_products = [
            (p["id"], p["name"], p["price"], p["category"])
            for p in sample_data["products"]
        ]
        # Whole-list comparison checks both the row count and every value.
        assert [tuple(row) for row in users_result] == expected_users
        assert [tuple(row) for row in products_result] == expected_products