mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 04:58:19 -04:00
68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
"""test_database."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from sqlalchemy import Integer, String, create_engine, select
|
|
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker
|
|
|
|
from python.database import safe_insert
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
|
|
class TestingBase(DeclarativeBase):
    """Declarative base whose metadata holds only this module's test tables."""
|
|
|
|
|
|
class Item(TestingBase):
    """Minimal ORM row used to exercise ``safe_insert``.

    ``name`` carries a UNIQUE constraint, so the tests can trigger a
    per-row insert failure by submitting the same name twice.
    """

    __tablename__ = "items"

    # Integer surrogate primary key, marked autoincrement.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
    # Declared VARCHAR(50); NOT NULL and unique across the table.
    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
|
|
|
|
|
@pytest.fixture
def session() -> Generator[Session]:
    """Yield a session bound to a brand-new in-memory SQLite database.

    Each test gets its own engine, so the tables created here are never
    shared between tests. The engine is disposed during teardown, even when
    the test fails, so its pooled connection is released.

    Yields:
        An open ``Session`` created with ``expire_on_commit=False``. ORM
        objects therefore keep their loaded attribute values after a commit,
        which lets tests read e.g. ``failed_obj.name`` after ``safe_insert``
        returns.
    """
    # ``future=True`` is omitted: SQLAlchemy 2.0 (required here by
    # ``DeclarativeBase``/``mapped_column``) only supports future mode.
    engine = create_engine("sqlite+pysqlite:///:memory:", echo=False)
    try:
        TestingBase.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine, expire_on_commit=False)
        with session_factory() as s:
            yield s
    finally:
        # Release pooled connections; for ``:memory:`` this also discards the DB.
        engine.dispose()
|
|
|
|
|
|
def test_partial_failure_unique_constraint(session: Session) -> None:
    """A duplicated name is rejected on its own while the other rows persist."""
    batch = [Item(name="a"), Item(name="b"), Item(name="a"), Item(name="c")]

    failures = safe_insert(batch, session)

    # Exactly one (exception, object) entry is reported, for one of the "a" rows.
    assert len(failures) == 1
    error, rejected = failures[0]
    assert isinstance(error, Exception)
    assert isinstance(rejected, Item)
    assert rejected.name == "a"

    stored_names = session.scalars(select(Item.name)).all()
    # The surviving rows are a, b and c, with "a" stored only once.
    assert sorted(stored_names) == ["a", "b", "c"]
    assert stored_names.count("a") == 1
|
|
|
|
|
|
def test_all_good_inserts(session: Session) -> None:
    """When every row is valid, ``safe_insert`` reports an empty failure list."""
    wanted = ("x", "y")

    failures = safe_insert([Item(name=label) for label in wanted], session)
    assert failures == []

    # Both rows are visible through the same session afterwards.
    query = select(Item.name).where(Item.name.in_(wanted))
    assert sorted(session.scalars(query).all()) == ["x", "y"]
|
|
|
|
|
|
def test_unmapped_object_raises(session: Session) -> None:
    """Handing ``safe_insert`` a plain, non-ORM object raises TypeError."""
    not_a_model = object()
    with pytest.raises(TypeError):
        safe_insert([not_a_model], session)
|