added safe_insert

This commit is contained in:
2025-11-07 06:09:03 -05:00
parent ddba7d1068
commit f713b8d4fa
4 changed files with 133 additions and 1 deletions

View File

@@ -252,6 +252,7 @@
"schemeless",
"scrollback",
"SECUREFOX",
"sessionmaker",
"sessionstore",
"shellcheck",
"signon",
@@ -263,6 +264,7 @@
"socialtracking",
"sonarr",
"sponsorblock",
"sqlalchemy",
"sqltools",
"ssdp",
"SSHOPTS",
@@ -325,5 +327,10 @@
"zoxide",
"zram",
"zstd"
]
],
"python-envs.defaultEnvManager": "ms-python.python:system",
"python-envs.pythonProjects": [],
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@@ -29,6 +29,7 @@
pytest-xdist
requests
ruff
sqlalchemy
typer
types-requests
]

57
python/database.py Normal file
View File

@@ -0,0 +1,57 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from sqlalchemy import inspect
from sqlalchemy.exc import NoInspectionAvailable
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
def safe_insert(orm_objects: Sequence[object], session: Session) -> list[tuple[Exception, object]]:
"""Safer insert at allows for partial rollbacks.
Args:
orm_objects (Sequence[object]): Tables to insert.
session (Session): Database session.
"""
if unmapped := [orm_object for orm_object in orm_objects if not _is_mapped_instance(orm_object)]:
error = f"binary_search_insert expects ORM-mapped instances {unmapped}"
raise TypeError(error)
return _safe_insert(orm_objects, session)
def _safe_insert(objects: Sequence[object], session: Session) -> list[tuple[Exception, object]]:
exceptions: list[tuple[Exception, object]] = []
try:
session.add_all(objects)
session.commit()
except Exception as error:
session.rollback()
objects_len = len(objects)
if objects_len == 1:
logger.exception(objects)
return [(error, objects[0])]
middle = objects_len // 2
exceptions.extend(_safe_insert(objects=objects[:middle], session=session))
exceptions.extend(_safe_insert(objects=objects[middle:], session=session))
return exceptions
def _is_mapped_instance(obj: object) -> bool:
"""Return True if `obj` is a SQLAlchemy ORM-mapped instance."""
try:
inspect(obj) # raises NoInspectionAvailable if not mapped
except NoInspectionAvailable:
return False
else:
return True

67
tests/test_databasse.py Normal file
View File

@@ -0,0 +1,67 @@
"""test_database."""
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker
from python.database import safe_insert
if TYPE_CHECKING:
from collections.abc import Generator
class TestingBase(DeclarativeBase):
"""TestingBase."""
class Item(TestingBase):
"""Item."""
__tablename__ = "items"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50), nullable=False, unique=True)
@pytest.fixture
def session() -> Generator[Session]:
"""Fresh in-memory DB + tables for each test."""
engine = create_engine("sqlite+pysqlite:///:memory:", echo=False, future=True)
TestingBase.metadata.create_all(engine)
with sessionmaker(bind=engine, expire_on_commit=False, future=True)() as s:
yield s
def test_partial_failure_unique_constraint(session: Session) -> None:
"""Duplicate name should fail only for the conflicting row; others commit."""
objs = [Item(name="a"), Item(name="b"), Item(name="a"), Item(name="c")]
failures = safe_insert(objs, session)
assert len(failures) == 1
exc, failed_obj = failures[0]
assert isinstance(exc, Exception)
assert isinstance(failed_obj, Item)
assert failed_obj.name == "a"
rows = session.scalars(select(Item.name)).all()
assert sorted(rows) == ["a", "b", "c"]
assert rows.count("a") == 1
def test_all_good_inserts(session: Session) -> None:
"""No failures when all rows are valid."""
objs = [Item(name="x"), Item(name="y")]
failures = safe_insert(objs, session)
assert failures == []
rows = session.scalars(select(Item.name).where(Item.name.in_(("x", "y")))).all()
assert sorted(rows) == ["x", "y"]
def test_unmapped_object_raises(session: Session) -> None:
"""Non-ORM instances should raise TypeError immediately."""
with pytest.raises(TypeError):
safe_insert([object()], session)