"""LLM tool calling support for audiobook metadata resolution."""

from __future__ import annotations

import json
import re
import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING

from sqlalchemy import or_, select

from python.orm.richie import Audiobook, AudiobookAuthor, AudiobookSeries

if TYPE_CHECKING:
    from pathlib import Path

    from sqlalchemy.orm import Session

    from python.tools.audiobook.metadata_agent import AgentConfig

# Author and series slugs: lowercase alphanumeric words joined by single underscores.
CATALOG_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+)*$")
# Book title slugs: lowercase alphanumeric words joined by single hyphens.
TITLE_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")

# Audit-log writer; this module calls it as write_log(log_path, event_name, **fields).
LogWriter = Callable[..., None]


class MetadataResolutionError(ValueError):
    """Metadata resolution failed validation."""


@dataclass(frozen=True)
class EnsuredBook:
    """Book row plus whether it was created."""

    book: Audiobook
    # "created" when this call inserted the row, "existing" when it was already there.
    action: str


class CatalogToolRegistry:
    """Controlled catalog tools exposed to the metadata model.

    Each tool call is checked against ``config.tool_names`` and audited through
    ``write_log``. The registry records which catalog ids were returned to the model
    (``seen_*_ids``) and which rows it inserted (``created_*_ids``); the created ids
    drive ``prune_unused_created_rows``.
    """

    def __init__(
        self,
        session: Session,
        log_path: Path,
        config: AgentConfig,
        write_log: LogWriter,
    ) -> None:
        """Create a registry bound to one database session and audit log.

        Args:
            session: Session used for every catalog query and insert.
            log_path: Audit log destination, passed through to ``write_log``.
            config: Agent configuration; this class reads ``tool_names``,
                ``max_tool_results`` and ``standalone_series`` from it.
            write_log: Callable invoked as ``write_log(log_path, event, **fields)``.
        """
        self.session = session
        self.log_path = log_path
        self.config = config
        self.write_log = write_log
        # Ids the model has been shown by any tool result.
        self.seen_author_ids: set[int] = set()
        self.seen_series_ids: set[int] = set()
        self.seen_book_ids: set[int] = set()
        # Ids of rows inserted by this registry (candidates for pruning).
        self.created_author_ids: set[int] = set()
        self.created_series_ids: set[int] = set()
        self.created_book_ids: set[int] = set()

    def tool_schemas(self) -> list[dict[str, object]]:
        """Return Ollama tool schemas for the tools enabled in ``config.tool_names``."""
        schemas = [
            {
                "type": "function",
                "function": {
                    "name": "search_authors",
                    "description": "Search canonical audiobook authors by slug or noisy source text.",
                    "parameters": {
                        "type": "object",
                        "properties": {"query": {"type": "string"}},
                        "required": ["query"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "search_series",
                    "description": "Search canonical audiobook series by slug or noisy source text.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string"},
                            "author_id": {"type": ["integer", "null"]},
                        },
                        "required": ["query"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "search_books",
                    "description": "Search canonical audiobook titles with optional author and series filters.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string"},
                            "author_id": {"type": ["integer", "null"]},
                            "series_id": {"type": ["integer", "null"]},
                        },
                        "required": ["query"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "ensure_author",
                    "description": "Normalize an author name to a catalog slug, then return or create that author.",
                    "parameters": {
                        "type": "object",
                        "properties": {"name": {"type": "string"}},
                        "required": ["name"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "ensure_series",
                    "description": "Normalize a series name to a catalog slug, then return or create it for an author.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "author_id": {"type": "integer"},
                        },
                        "required": ["name", "author_id"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "ensure_book",
                    "description": "Normalize a title to a book slug, then return or create it for an author/series.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "author_id": {"type": "integer"},
                            "series_id": {"type": ["integer", "null"]},
                            "series_index": {"type": "integer"},
                        },
                        "required": ["title", "author_id", "series_id", "series_index"],
                    },
                },
            },
        ]
        enabled_tool_names = set(self.config.tool_names)
        return [schema for schema in schemas if schema["function"]["name"] in enabled_tool_names]

    def run(self, name: str, arguments: dict[str, object]) -> list[dict[str, object]]:
        """Run one catalog tool and audit the call.

        Logs ``tool_call`` before and ``tool_result`` (duration, count, first three
        rows) after a successful handler run.

        Args:
            name: Tool name requested by the model.
            arguments: Decoded tool arguments.

        Returns:
            The handler's list of result rows.

        Raises:
            MetadataResolutionError: The tool is unknown or not enabled (both logged as
                ``tool_error`` first), or the handler rejected the arguments.
        """
        handlers = {
            "search_authors": self.run_search_authors,
            "search_series": self.run_search_series,
            "search_books": self.run_search_books,
            "ensure_author": self.run_ensure_author,
            "ensure_series": self.run_ensure_series,
            "ensure_book": self.run_ensure_book,
        }
        handler = handlers.get(name)
        if handler is None:
            self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="unknown_tool")
            msg = f"Unknown audiobook metadata tool: {name}"
            raise MetadataResolutionError(msg)
        if name not in self.config.tool_names:
            self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="tool_not_enabled")
            msg = f"Audiobook metadata tool is not enabled: {name}"
            raise MetadataResolutionError(msg)

        started = time.perf_counter()
        self.write_log(self.log_path, "tool_call", tool=name, arguments=arguments)
        result = handler(arguments)
        duration_ms = round((time.perf_counter() - started) * 1000, 3)
        self.write_log(
            self.log_path,
            "tool_result",
            tool=name,
            duration_ms=duration_ms,
            result_count=len(result),
            preview=result[:3],
        )
        return result

    def get_author(self, author_id: int) -> AudiobookAuthor | None:
        """Return an author by id."""
        return self.session.get(AudiobookAuthor, author_id)

    def get_book(self, book_id: int) -> Audiobook | None:
        """Return a book by id."""
        return self.session.get(Audiobook, book_id)

    def get_series(self, series_id: int) -> AudiobookSeries | None:
        """Return a series by id."""
        return self.session.get(AudiobookSeries, series_id)

    def prune_unused_created_rows(self, *, author_id: int, book_id: int | None, series_id: int | None) -> None:
        """Remove catalog rows created during this run but not used by final metadata.

        Rows are pruned children-first: books, then series, then authors. A series or
        author is only removed once nothing in the database still references it. Rows
        that existed before this run are never touched.

        Args:
            author_id: Author chosen by the final metadata; always kept.
            book_id: Book chosen by the final metadata, or None.
            series_id: Series chosen by the final metadata, or None.
        """
        used_book_ids = {book_id} if book_id is not None else set()
        for created_book_id in self.created_book_ids - used_book_ids:
            if book := self.get_book(created_book_id):
                self.session.delete(book)
        # Push the book deletes to the database before checking what still references series.
        self.session.flush()

        used_series_ids = {series_id} if series_id is not None else set()
        for created_series_id in self.created_series_ids - used_series_ids:
            series = self.get_series(created_series_id)
            if series is None:
                continue
            # The ORM does not remove flushed-deleted rows from already-loaded
            # collections, so reload "books" from the database before trusting it.
            # Safe here: the flush above left no unflushed changes to discard.
            self.session.expire(series, ["books"])
            if not series.books:
                self.session.delete(series)
        self.session.flush()

        for created_author_id in self.created_author_ids - {author_id}:
            author = self.get_author(created_author_id)
            if author is None:
                continue
            # Same staleness concern as above for the author's collections.
            self.session.expire(author, ["books", "series"])
            if not author.books and not author.series:
                self.session.delete(author)

    def run_search_authors(self, arguments: dict[str, object]) -> list[dict[str, object]]:
        """Search authors from tool arguments and remember returned ids.

        Matches names case-insensitively against any variant from ``query_terms``,
        ordered by name and capped at ``config.max_tool_results`` rows.
        """
        query = required_string(arguments, "query")
        statement = select(AudiobookAuthor).order_by(AudiobookAuthor.name).limit(self.config.max_tool_results)
        if terms := query_terms(query):
            statement = statement.where(or_(*(AudiobookAuthor.name.ilike(f"%{term}%") for term in terms)))
        authors = self.session.scalars(statement).all()
        self.seen_author_ids.update(author.id for author in authors)
        return [{"id": author.id, "name": author.name} for author in authors]

    def run_search_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
        """Search series from tool arguments and remember returned ids.

        Optionally restricted to ``author_id``. Both the series ids and their author
        ids are recorded as seen.
        """
        query = required_string(arguments, "query")
        author_id = optional_int(arguments.get("author_id"), "author_id")
        statement = select(AudiobookSeries).order_by(AudiobookSeries.name).limit(self.config.max_tool_results)
        if terms := query_terms(query):
            statement = statement.where(or_(*(AudiobookSeries.name.ilike(f"%{term}%") for term in terms)))
        if author_id is not None:
            statement = statement.where(AudiobookSeries.author_id == author_id)
        series_rows = self.session.scalars(statement).all()
        self.seen_series_ids.update(series.id for series in series_rows)
        self.seen_author_ids.update(series.author_id for series in series_rows)
        return [
            {
                "id": series.id,
                "name": series.name,
                "author_id": series.author_id,
                "author": series.author.name,
            }
            for series in series_rows
        ]

    def run_search_books(self, arguments: dict[str, object]) -> list[dict[str, object]]:
        """Search books from tool arguments and remember returned ids.

        Optionally restricted to ``author_id`` and/or ``series_id``. Books without a
        series report ``config.standalone_series`` as their series name.
        """
        query = required_string(arguments, "query")
        author_id = optional_int(arguments.get("author_id"), "author_id")
        series_id = optional_int(arguments.get("series_id"), "series_id")
        statement = select(Audiobook).order_by(Audiobook.title).limit(self.config.max_tool_results)
        if terms := query_terms(query):
            statement = statement.where(or_(*(Audiobook.title.ilike(f"%{term}%") for term in terms)))
        if author_id is not None:
            statement = statement.where(Audiobook.author_id == author_id)
        if series_id is not None:
            statement = statement.where(Audiobook.series_id == series_id)
        books = self.session.scalars(statement).all()
        self.seen_book_ids.update(book.id for book in books)
        self.seen_author_ids.update(book.author_id for book in books)
        self.seen_series_ids.update(book.series_id for book in books if book.series_id is not None)
        return [
            {
                "id": book.id,
                "title": book.title,
                "author_id": book.author_id,
                "author": book.author.name,
                "series_id": book.series_id,
                "series": book.series.name if book.series else self.config.standalone_series,
                "series_index": book.series_index,
            }
            for book in books
        ]

    def run_ensure_author(self, arguments: dict[str, object]) -> list[dict[str, object]]:
        """Ensure an author from tool arguments and return a tool result.

        Raises:
            MetadataResolutionError: ``name`` is missing or does not normalize to a
                valid catalog slug.
        """
        name = normalize_catalog_slug(required_string(arguments, "name"))
        validate_catalog_slug(name, "author")
        author = self.session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
        action = "existing"
        if author is None:
            author = AudiobookAuthor(name=name)
            self.session.add(author)
            # Flush so the new row has an id to return and remember.
            self.session.flush()
            self.created_author_ids.add(author.id)
            action = "created"
        self.seen_author_ids.add(author.id)
        return [{"id": author.id, "name": author.name, "action": action}]

    def run_ensure_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
        """Ensure a series from tool arguments and return a tool result.

        Series are unique per (name, author); the author must already exist.

        Raises:
            MetadataResolutionError: Invalid arguments, invalid slug, or unknown author.
        """
        name = normalize_catalog_slug(required_string(arguments, "name"))
        author_id = required_int(arguments, "author_id")
        validate_catalog_slug(name, "series")
        author = self.required_author(author_id)
        series = self.session.scalar(
            select(AudiobookSeries).where(
                AudiobookSeries.name == name,
                AudiobookSeries.author_id == author.id,
            ),
        )
        action = "existing"
        if series is None:
            series = AudiobookSeries(name=name, author=author)
            self.session.add(series)
            self.session.flush()
            self.created_series_ids.add(series.id)
            action = "created"
        self.seen_author_ids.add(author.id)
        self.seen_series_ids.add(series.id)
        return [self.series_result(series, action)]

    def run_ensure_book(self, arguments: dict[str, object]) -> list[dict[str, object]]:
        """Ensure a book from tool arguments and return a tool result."""
        title = required_string(arguments, "title")
        author_id = required_int(arguments, "author_id")
        series_id = optional_int(arguments.get("series_id"), "series_id")
        series_index = required_int(arguments, "series_index")
        ensured = self.ensure_book(title, author_id, series_id, series_index)
        return [self.book_result(ensured.book, ensured.action)]

    def ensure_book(
        self,
        title: str,
        author_id: int,
        series_id: int | None,
        series_index: int,
    ) -> EnsuredBook:
        """Return an existing book row, or create it after validating ownership.

        Books are matched on (title slug, author, series or no series). When an
        existing row matches, it is returned unchanged, even if its ``series_index``
        differs from the requested one.

        Args:
            title: Raw title; normalized to a kebab-case slug.
            author_id: Owning author id; must exist.
            series_id: Owning series id, or None for a standalone book.
            series_index: 0 for standalone books, positive for series books.

        Raises:
            MetadataResolutionError: Invalid slug, unknown author or series, a series
                owned by another author, or a series_index that does not fit.
        """
        title = normalize_title_slug(title)
        validate_title_slug(title)
        author = self.required_author(author_id)
        series = None
        if series_id is None:
            if series_index != 0:
                msg = "standalone books must use series_index 0"
                raise MetadataResolutionError(msg)
        else:
            series = self.required_series(series_id)
            if series.author_id != author.id:
                msg = f"series_id {series_id} does not belong to author_id {author_id}"
                raise MetadataResolutionError(msg)
            if series_index <= 0:
                msg = "series books must use a positive series_index"
                raise MetadataResolutionError(msg)

        statement = select(Audiobook).where(
            Audiobook.title == title,
            Audiobook.author_id == author.id,
        )
        if series is None:
            statement = statement.where(Audiobook.series_id.is_(None))
        else:
            statement = statement.where(Audiobook.series_id == series.id)
        book = self.session.scalar(statement)
        if book is None:
            book = Audiobook(title=title, author=author, series=series, series_index=series_index)
            self.session.add(book)
            self.session.flush()
            self.created_book_ids.add(book.id)
            action = "created"
        else:
            action = "existing"
        self.seen_book_ids.add(book.id)
        self.seen_author_ids.add(author.id)
        if book.series_id is not None:
            self.seen_series_ids.add(book.series_id)
        return EnsuredBook(book=book, action=action)

    def required_author(self, author_id: int) -> AudiobookAuthor:
        """Return an author or fail metadata resolution."""
        author = self.get_author(author_id)
        if author is None:
            msg = f"author_id {author_id} does not exist"
            raise MetadataResolutionError(msg)
        return author

    def required_series(self, series_id: int) -> AudiobookSeries:
        """Return a series or fail metadata resolution."""
        series = self.get_series(series_id)
        if series is None:
            msg = f"series_id {series_id} does not exist"
            raise MetadataResolutionError(msg)
        return series

    def series_result(self, series: AudiobookSeries, action: str) -> dict[str, object]:
        """Build a normalized series tool result."""
        return {
            "id": series.id,
            "name": series.name,
            "author_id": series.author_id,
            "author": series.author.name,
            "action": action,
        }

    def book_result(self, book: Audiobook, action: str) -> dict[str, object]:
        """Build a normalized book tool result."""
        return {
            "id": book.id,
            "title": book.title,
            "author_id": book.author_id,
            "author": book.author.name,
            "series_id": book.series_id,
            "series": book.series.name if book.series else self.config.standalone_series,
            "series_index": book.series_index,
            "action": action,
        }


def run_tool_calls(
    messages: list[dict[str, object]],
    message: dict[str, object],
    tool_calls: list[tuple[str, dict[str, object]]],
    registry: CatalogToolRegistry,
    log_path: Path,
    write_log: LogWriter,
) -> str | None:
    """Run tool calls, append tool messages, and return fatal error text when stopped.

    Appends ``message`` (the model turn that requested the calls) to ``messages``,
    then one ``{"role": "tool", ...}`` message per call. A non-fatal
    ``MetadataResolutionError`` is logged and returned to the model as a JSON
    ``{"error": ...}`` payload. A fatal one (see ``is_fatal_tool_error``) stops
    processing immediately, skipping any remaining calls.

    Returns:
        The fatal error text, or None when every call was handled.
    """
    messages.append(message)
    for tool_name, arguments in tool_calls:
        try:
            tool_result = registry.run(tool_name, arguments)
        except MetadataResolutionError as error:
            if is_fatal_tool_error(error):
                return str(error)
            write_log(log_path, "tool_error", tool=tool_name, arguments=arguments, error=str(error))
            messages.append(
                {
                    "role": "tool",
                    "tool_name": tool_name,
                    "content": json.dumps({"error": str(error)}, sort_keys=True),
                },
            )
            continue
        messages.append(
            {
                "role": "tool",
                "tool_name": tool_name,
                "content": json.dumps(tool_result, sort_keys=True),
            },
        )
    return None


def parse_tool_calls(message: dict[str, object]) -> list[tuple[str, dict[str, object]]]:
    """Parse Ollama tool calls from a response message.

    Returns:
        ``(name, arguments)`` pairs in the order the model requested them; empty when
        the message has no ``tool_calls``.

    Raises:
        MetadataResolutionError: Any call is structurally malformed.
    """
    raw_tool_calls = message.get("tool_calls") or []
    if not isinstance(raw_tool_calls, list):
        msg = "tool_calls must be a list"
        raise MetadataResolutionError(msg)
    tool_calls: list[tuple[str, dict[str, object]]] = []
    for raw_call in raw_tool_calls:
        if not isinstance(raw_call, dict):
            msg = "tool call must be an object"
            raise MetadataResolutionError(msg)
        function = raw_call.get("function")
        if not isinstance(function, dict):
            msg = "tool call is missing function"
            raise MetadataResolutionError(msg)
        name = function.get("name")
        if not isinstance(name, str) or not name:
            msg = "tool call is missing function name"
            raise MetadataResolutionError(msg)
        arguments = parse_tool_arguments(function.get("arguments", {}))
        tool_calls.append((name, arguments))
    return tool_calls


def parse_tool_arguments(raw_arguments: object) -> dict[str, object]:
    """Parse tool call arguments returned by Ollama.

    Arguments may arrive already decoded (a dict) or as a JSON-encoded string. An
    explicit ``None`` or a blank string means "no arguments", matching how a missing
    ``arguments`` key is treated by ``parse_tool_calls``.

    Raises:
        MetadataResolutionError: The string is not valid JSON, or the arguments do
            not decode to an object.
    """
    if raw_arguments is None:
        return {}
    if isinstance(raw_arguments, dict):
        return {str(key): value for key, value in raw_arguments.items()}
    if isinstance(raw_arguments, str):
        if not raw_arguments.strip():
            return {}
        try:
            parsed = json.loads(raw_arguments)
        except json.JSONDecodeError as error:
            # Keep malformed model output inside this module's error type so callers
            # only need to handle MetadataResolutionError.
            msg = f"tool arguments are not valid JSON: {error.msg}"
            raise MetadataResolutionError(msg) from error
        if isinstance(parsed, dict):
            return {str(key): value for key, value in parsed.items()}
    msg = "tool arguments must be an object"
    raise MetadataResolutionError(msg)


def validate_title_slug(title: str) -> None:
    """Validate a canonical book title slug."""
    if not TITLE_SLUG_PATTERN.fullmatch(title):
        msg = f"title slug is invalid: {title}"
        raise MetadataResolutionError(msg)


def validate_catalog_slug(value: str, label: str) -> None:
    """Validate a canonical catalog slug; ``label`` names the field in the error."""
    if not CATALOG_SLUG_PATTERN.fullmatch(value):
        msg = f"{label} slug is invalid: {value}"
        raise MetadataResolutionError(msg)


def normalize_catalog_slug(value: str) -> str:
    """Normalize noisy catalog names into lower snake-case slugs.

    Every run of characters outside ``[a-z0-9]`` (after casefolding) becomes a single
    underscore, and leading/trailing underscores are dropped. Non-ASCII letters are
    treated as separators, not transliterated.
    """
    return re.sub(r"[^a-z0-9]+", "_", value.strip().casefold()).strip("_")


def normalize_title_slug(value: str) -> str:
    """Normalize noisy book titles into lower kebab-case slugs (hyphen-separated)."""
    return re.sub(r"[^a-z0-9]+", "-", value.strip().casefold()).strip("-")


def is_fatal_tool_error(error: MetadataResolutionError) -> bool:
    """Return whether a tool error should stop the agent immediately.

    Only the unknown-tool and tool-not-enabled errors raised by
    ``CatalogToolRegistry.run`` are fatal; they are recognized by message prefix.
    """
    message = str(error)
    return message.startswith(
        (
            "Unknown audiobook metadata tool",
            "Audiobook metadata tool is not enabled",
        ),
    )


def query_terms(query: str) -> tuple[str, ...]:
    """Return text variants useful for matching noisy audiobook metadata.

    The variants are the casefolded text, its snake-case slug and its kebab-case
    slug. Empty variants are dropped and duplicates removed, preserving order.
    """
    normalized = query.strip().casefold()
    underscore_slug = normalize_catalog_slug(normalized)
    hyphen_slug = normalize_title_slug(normalized)
    return tuple(dict.fromkeys(term for term in (normalized, underscore_slug, hyphen_slug) if term))


def required_string(data: dict[str, object], key: str) -> str:
    """Read a required string field and return it stripped of surrounding whitespace."""
    value = data.get(key)
    if not isinstance(value, str) or not value.strip():
        msg = f"{key} must be a non-empty string"
        raise MetadataResolutionError(msg)
    return value.strip()


def required_int(data: dict[str, object], key: str) -> int:
    """Read a required integer field (booleans are rejected even though bool is an int)."""
    value = data.get(key)
    if isinstance(value, bool) or not isinstance(value, int):
        msg = f"{key} must be an integer"
        raise MetadataResolutionError(msg)
    return value


def optional_int(value: object, key: str) -> int | None:
    """Read an optional integer field; None passes through, booleans are rejected."""
    if value is None:
        return None
    if isinstance(value, bool) or not isinstance(value, int):
        msg = f"{key} must be an integer or null"
        raise MetadataResolutionError(msg)
    return value