From f713b8d4fa80fe84bee28f8a409d6778590ffee0 Mon Sep 17 00:00:00 2001 From: Richie Cahill Date: Fri, 7 Nov 2025 06:09:03 -0500 Subject: [PATCH] added safe_insert --- .vscode/settings.json | 9 +++++- overlays/default.nix | 1 + python/database.py | 57 +++++++++++++++++++++++++++++++++++ tests/test_databasse.py | 67 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+), 1 deletion(-) create mode 100644 python/database.py create mode 100644 tests/test_databasse.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 093f01a..4bbd226 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -252,6 +252,7 @@ "schemeless", "scrollback", "SECUREFOX", + "sessionmaker", "sessionstore", "shellcheck", "signon", @@ -263,6 +264,7 @@ "socialtracking", "sonarr", "sponsorblock", + "sqlalchemy", "sqltools", "ssdp", "SSHOPTS", @@ -325,5 +327,10 @@ "zoxide", "zram", "zstd" - ] + ], + "python-envs.defaultEnvManager": "ms-python.python:system", + "python-envs.pythonProjects": [], + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/overlays/default.nix b/overlays/default.nix index dded516..2113d41 100644 --- a/overlays/default.nix +++ b/overlays/default.nix @@ -29,6 +29,7 @@ pytest-xdist requests ruff + sqlalchemy typer types-requests ] diff --git a/python/database.py b/python/database.py new file mode 100644 index 0000000..e510715 --- /dev/null +++ b/python/database.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from sqlalchemy import inspect +from sqlalchemy.exc import NoInspectionAvailable + +if TYPE_CHECKING: + from collections.abc import Sequence + + from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + + +def safe_insert(orm_objects: Sequence[object], session: Session) -> list[tuple[Exception, object]]: + """Safer insert at allows for partial rollbacks. + + Args: + orm_objects (Sequence[object]): Tables to insert. + session (Session): Database session. + """ + if unmapped := [orm_object for orm_object in orm_objects if not _is_mapped_instance(orm_object)]: + error = f"binary_search_insert expects ORM-mapped instances {unmapped}" + raise TypeError(error) + return _safe_insert(orm_objects, session) + + +def _safe_insert(objects: Sequence[object], session: Session) -> list[tuple[Exception, object]]: + exceptions: list[tuple[Exception, object]] = [] + try: + session.add_all(objects) + session.commit() + + except Exception as error: + session.rollback() + + objects_len = len(objects) + if objects_len == 1: + logger.exception(objects) + return [(error, objects[0])] + + middle = objects_len // 2 + exceptions.extend(_safe_insert(objects=objects[:middle], session=session)) + exceptions.extend(_safe_insert(objects=objects[middle:], session=session)) + return exceptions + + +def _is_mapped_instance(obj: object) -> bool: + """Return True if `obj` is a SQLAlchemy ORM-mapped instance.""" + try: + inspect(obj) # raises NoInspectionAvailable if not mapped + except NoInspectionAvailable: + return False + else: + return True diff --git a/tests/test_databasse.py b/tests/test_databasse.py new file mode 100644 index 0000000..af24847 --- /dev/null +++ b/tests/test_databasse.py @@ -0,0 +1,67 @@ +"""test_database.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from sqlalchemy import Integer, String, create_engine, select +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker + +from python.database import safe_insert + +if TYPE_CHECKING: + from collections.abc import Generator + + +class TestingBase(DeclarativeBase): + """TestingBase.""" + + +class Item(TestingBase): + """Item.""" + + __tablename__ = "items" + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(50), nullable=False, unique=True) + + +@pytest.fixture +def session() -> Generator[Session]: + """Fresh in-memory DB + tables for each test.""" + engine = create_engine("sqlite+pysqlite:///:memory:", echo=False, future=True) + TestingBase.metadata.create_all(engine) + with sessionmaker(bind=engine, expire_on_commit=False, future=True)() as s: + yield s + + +def test_partial_failure_unique_constraint(session: Session) -> None: + """Duplicate name should fail only for the conflicting row; others commit.""" + objs = [Item(name="a"), Item(name="b"), Item(name="a"), Item(name="c")] + failures = safe_insert(objs, session) + + assert len(failures) == 1 + exc, failed_obj = failures[0] + assert isinstance(exc, Exception) + assert isinstance(failed_obj, Item) + assert failed_obj.name == "a" + + rows = session.scalars(select(Item.name)).all() + assert sorted(rows) == ["a", "b", "c"] + assert rows.count("a") == 1 + + +def test_all_good_inserts(session: Session) -> None: + """No failures when all rows are valid.""" + objs = [Item(name="x"), Item(name="y")] + failures = safe_insert(objs, session) + assert failures == [] + + rows = session.scalars(select(Item.name).where(Item.name.in_(("x", "y")))).all() + assert sorted(rows) == ["x", "y"] + + +def test_unmapped_object_raises(session: Session) -> None: + """Non-ORM instances should raise TypeError immediately.""" + with pytest.raises(TypeError): + safe_insert([object()], session)