mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
added safe_insert
This commit is contained in:
9
.vscode/settings.json
vendored
9
.vscode/settings.json
vendored
@@ -252,6 +252,7 @@
|
|||||||
"schemeless",
|
"schemeless",
|
||||||
"scrollback",
|
"scrollback",
|
||||||
"SECUREFOX",
|
"SECUREFOX",
|
||||||
|
"sessionmaker",
|
||||||
"sessionstore",
|
"sessionstore",
|
||||||
"shellcheck",
|
"shellcheck",
|
||||||
"signon",
|
"signon",
|
||||||
@@ -263,6 +264,7 @@
|
|||||||
"socialtracking",
|
"socialtracking",
|
||||||
"sonarr",
|
"sonarr",
|
||||||
"sponsorblock",
|
"sponsorblock",
|
||||||
|
"sqlalchemy",
|
||||||
"sqltools",
|
"sqltools",
|
||||||
"ssdp",
|
"ssdp",
|
||||||
"SSHOPTS",
|
"SSHOPTS",
|
||||||
@@ -325,5 +327,10 @@
|
|||||||
"zoxide",
|
"zoxide",
|
||||||
"zram",
|
"zram",
|
||||||
"zstd"
|
"zstd"
|
||||||
]
|
],
|
||||||
|
"python-envs.defaultEnvManager": "ms-python.python:system",
|
||||||
|
"python-envs.pythonProjects": [],
|
||||||
|
"python.testing.pytestArgs": ["tests"],
|
||||||
|
"python.testing.unittestEnabled": false,
|
||||||
|
"python.testing.pytestEnabled": true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,7 @@
|
|||||||
pytest-xdist
|
pytest-xdist
|
||||||
requests
|
requests
|
||||||
ruff
|
ruff
|
||||||
|
sqlalchemy
|
||||||
typer
|
typer
|
||||||
types-requests
|
types-requests
|
||||||
]
|
]
|
||||||
|
|||||||
57
python/database.py
Normal file
57
python/database.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
from sqlalchemy.exc import NoInspectionAvailable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def safe_insert(orm_objects: Sequence[object], session: Session) -> list[tuple[Exception, object]]:
|
||||||
|
"""Safer insert at allows for partial rollbacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orm_objects (Sequence[object]): Tables to insert.
|
||||||
|
session (Session): Database session.
|
||||||
|
"""
|
||||||
|
if unmapped := [orm_object for orm_object in orm_objects if not _is_mapped_instance(orm_object)]:
|
||||||
|
error = f"binary_search_insert expects ORM-mapped instances {unmapped}"
|
||||||
|
raise TypeError(error)
|
||||||
|
return _safe_insert(orm_objects, session)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_insert(objects: Sequence[object], session: Session) -> list[tuple[Exception, object]]:
|
||||||
|
exceptions: list[tuple[Exception, object]] = []
|
||||||
|
try:
|
||||||
|
session.add_all(objects)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
except Exception as error:
|
||||||
|
session.rollback()
|
||||||
|
|
||||||
|
objects_len = len(objects)
|
||||||
|
if objects_len == 1:
|
||||||
|
logger.exception(objects)
|
||||||
|
return [(error, objects[0])]
|
||||||
|
|
||||||
|
middle = objects_len // 2
|
||||||
|
exceptions.extend(_safe_insert(objects=objects[:middle], session=session))
|
||||||
|
exceptions.extend(_safe_insert(objects=objects[middle:], session=session))
|
||||||
|
return exceptions
|
||||||
|
|
||||||
|
|
||||||
|
def _is_mapped_instance(obj: object) -> bool:
|
||||||
|
"""Return True if `obj` is a SQLAlchemy ORM-mapped instance."""
|
||||||
|
try:
|
||||||
|
inspect(obj) # raises NoInspectionAvailable if not mapped
|
||||||
|
except NoInspectionAvailable:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
67
tests/test_databasse.py
Normal file
67
tests/test_databasse.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""test_database."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import Integer, String, create_engine, select
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker
|
||||||
|
|
||||||
|
from python.database import safe_insert
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
|
||||||
|
class TestingBase(DeclarativeBase):
|
||||||
|
"""TestingBase."""
|
||||||
|
|
||||||
|
|
||||||
|
class Item(TestingBase):
|
||||||
|
"""Item."""
|
||||||
|
|
||||||
|
__tablename__ = "items"
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(50), nullable=False, unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session() -> Generator[Session]:
|
||||||
|
"""Fresh in-memory DB + tables for each test."""
|
||||||
|
engine = create_engine("sqlite+pysqlite:///:memory:", echo=False, future=True)
|
||||||
|
TestingBase.metadata.create_all(engine)
|
||||||
|
with sessionmaker(bind=engine, expire_on_commit=False, future=True)() as s:
|
||||||
|
yield s
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_failure_unique_constraint(session: Session) -> None:
|
||||||
|
"""Duplicate name should fail only for the conflicting row; others commit."""
|
||||||
|
objs = [Item(name="a"), Item(name="b"), Item(name="a"), Item(name="c")]
|
||||||
|
failures = safe_insert(objs, session)
|
||||||
|
|
||||||
|
assert len(failures) == 1
|
||||||
|
exc, failed_obj = failures[0]
|
||||||
|
assert isinstance(exc, Exception)
|
||||||
|
assert isinstance(failed_obj, Item)
|
||||||
|
assert failed_obj.name == "a"
|
||||||
|
|
||||||
|
rows = session.scalars(select(Item.name)).all()
|
||||||
|
assert sorted(rows) == ["a", "b", "c"]
|
||||||
|
assert rows.count("a") == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_good_inserts(session: Session) -> None:
|
||||||
|
"""No failures when all rows are valid."""
|
||||||
|
objs = [Item(name="x"), Item(name="y")]
|
||||||
|
failures = safe_insert(objs, session)
|
||||||
|
assert failures == []
|
||||||
|
|
||||||
|
rows = session.scalars(select(Item.name).where(Item.name.in_(("x", "y")))).all()
|
||||||
|
assert sorted(rows) == ["x", "y"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_unmapped_object_raises(session: Session) -> None:
|
||||||
|
"""Non-ORM instances should raise TypeError immediately."""
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
safe_insert([object()], session)
|
||||||
Reference in New Issue
Block a user