added llm_tool_calling.py
This commit is contained in:
@@ -0,0 +1,565 @@
|
|||||||
|
"""LLM tool calling support for audiobook metadata resolution."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import or_, select
|
||||||
|
|
||||||
|
from python.orm.richie import Audiobook, AudiobookAuthor, AudiobookSeries
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from python.tools.audiobook.metadata_agent import AgentConfig
|
||||||
|
|
||||||
|
CATALOG_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+)*$")
|
||||||
|
TITLE_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
|
||||||
|
|
||||||
|
LogWriter = Callable[..., None]
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataResolutionError(ValueError):
|
||||||
|
"""Metadata resolution failed validation."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class EnsuredBook:
|
||||||
|
"""Book row plus whether it was created."""
|
||||||
|
|
||||||
|
book: Audiobook
|
||||||
|
action: str
|
||||||
|
|
||||||
|
|
||||||
|
class CatalogToolRegistry:
|
||||||
|
"""Controlled catalog tools exposed to the metadata model."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
log_path: Path,
|
||||||
|
config: AgentConfig,
|
||||||
|
write_log: LogWriter,
|
||||||
|
) -> None:
|
||||||
|
"""Create a registry bound to one database session and audit log."""
|
||||||
|
self.session = session
|
||||||
|
self.log_path = log_path
|
||||||
|
self.config = config
|
||||||
|
self.write_log = write_log
|
||||||
|
self.seen_author_ids: set[int] = set()
|
||||||
|
self.seen_series_ids: set[int] = set()
|
||||||
|
self.seen_book_ids: set[int] = set()
|
||||||
|
self.created_author_ids: set[int] = set()
|
||||||
|
self.created_series_ids: set[int] = set()
|
||||||
|
self.created_book_ids: set[int] = set()
|
||||||
|
|
||||||
|
def tool_schemas(self) -> list[dict[str, object]]:
|
||||||
|
"""Return Ollama tool schemas."""
|
||||||
|
schemas = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "search_authors",
|
||||||
|
"description": "Search canonical audiobook authors by slug or noisy source text.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"query": {"type": "string"}},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "search_series",
|
||||||
|
"description": "Search canonical audiobook series by slug or noisy source text.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"author_id": {"type": ["integer", "null"]},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "search_books",
|
||||||
|
"description": "Search canonical audiobook titles with optional author and series filters.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"author_id": {"type": ["integer", "null"]},
|
||||||
|
"series_id": {"type": ["integer", "null"]},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "ensure_author",
|
||||||
|
"description": "Normalize an author name to a catalog slug, then return or create that author.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"type": "string"}},
|
||||||
|
"required": ["name"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "ensure_series",
|
||||||
|
"description": "Normalize a series name to a catalog slug, then return or create it for an author.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"author_id": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["name", "author_id"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "ensure_book",
|
||||||
|
"description": "Normalize a title to a book slug, then return or create it for an author/series.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"title": {"type": "string"},
|
||||||
|
"author_id": {"type": "integer"},
|
||||||
|
"series_id": {"type": ["integer", "null"]},
|
||||||
|
"series_index": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["title", "author_id", "series_id", "series_index"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
enabled_tool_names = set(self.config.tool_names)
|
||||||
|
return [schema for schema in schemas if schema["function"]["name"] in enabled_tool_names]
|
||||||
|
|
||||||
|
def run(self, name: str, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||||
|
"""Run one catalog tool and audit the call."""
|
||||||
|
handlers = {
|
||||||
|
"search_authors": self.run_search_authors,
|
||||||
|
"search_series": self.run_search_series,
|
||||||
|
"search_books": self.run_search_books,
|
||||||
|
"ensure_author": self.run_ensure_author,
|
||||||
|
"ensure_series": self.run_ensure_series,
|
||||||
|
"ensure_book": self.run_ensure_book,
|
||||||
|
}
|
||||||
|
handler = handlers.get(name)
|
||||||
|
if handler is None:
|
||||||
|
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="unknown_tool")
|
||||||
|
msg = f"Unknown audiobook metadata tool: {name}"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
if name not in self.config.tool_names:
|
||||||
|
self.write_log(self.log_path, "tool_error", tool=name, arguments=arguments, error="tool_not_enabled")
|
||||||
|
msg = f"Audiobook metadata tool is not enabled: {name}"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
|
||||||
|
started = time.perf_counter()
|
||||||
|
self.write_log(self.log_path, "tool_call", tool=name, arguments=arguments)
|
||||||
|
result = handler(arguments)
|
||||||
|
duration_ms = round((time.perf_counter() - started) * 1000, 3)
|
||||||
|
self.write_log(
|
||||||
|
self.log_path,
|
||||||
|
"tool_result",
|
||||||
|
tool=name,
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
result_count=len(result),
|
||||||
|
preview=result[:3],
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_author(self, author_id: int) -> AudiobookAuthor | None:
|
||||||
|
"""Return an author by id."""
|
||||||
|
return self.session.get(AudiobookAuthor, author_id)
|
||||||
|
|
||||||
|
def get_book(self, book_id: int) -> Audiobook | None:
|
||||||
|
"""Return a book by id."""
|
||||||
|
return self.session.get(Audiobook, book_id)
|
||||||
|
|
||||||
|
def get_series(self, series_id: int) -> AudiobookSeries | None:
|
||||||
|
"""Return a series by id."""
|
||||||
|
return self.session.get(AudiobookSeries, series_id)
|
||||||
|
|
||||||
|
def prune_unused_created_rows(self, *, author_id: int, book_id: int | None, series_id: int | None) -> None:
|
||||||
|
"""Remove catalog rows created during this run but not used by final metadata."""
|
||||||
|
used_book_ids = {book_id} if book_id is not None else set()
|
||||||
|
for created_book_id in self.created_book_ids - used_book_ids:
|
||||||
|
if book := self.get_book(created_book_id):
|
||||||
|
self.session.delete(book)
|
||||||
|
|
||||||
|
self.session.flush()
|
||||||
|
used_series_ids = {series_id} if series_id is not None else set()
|
||||||
|
for created_series_id in self.created_series_ids - used_series_ids:
|
||||||
|
series = self.get_series(created_series_id)
|
||||||
|
if series and not series.books:
|
||||||
|
self.session.delete(series)
|
||||||
|
|
||||||
|
self.session.flush()
|
||||||
|
for created_author_id in self.created_author_ids - {author_id}:
|
||||||
|
author = self.get_author(created_author_id)
|
||||||
|
if author and not author.books and not author.series:
|
||||||
|
self.session.delete(author)
|
||||||
|
|
||||||
|
def run_search_authors(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||||
|
"""Search authors from tool arguments and remember returned ids."""
|
||||||
|
query = required_string(arguments, "query")
|
||||||
|
statement = select(AudiobookAuthor).order_by(AudiobookAuthor.name).limit(self.config.max_tool_results)
|
||||||
|
if terms := query_terms(query):
|
||||||
|
statement = statement.where(or_(*(AudiobookAuthor.name.ilike(f"%{term}%") for term in terms)))
|
||||||
|
|
||||||
|
authors = self.session.scalars(statement).all()
|
||||||
|
self.seen_author_ids.update(author.id for author in authors)
|
||||||
|
return [{"id": author.id, "name": author.name} for author in authors]
|
||||||
|
|
||||||
|
def run_search_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||||
|
"""Search series from tool arguments and remember returned ids."""
|
||||||
|
query = required_string(arguments, "query")
|
||||||
|
author_id = optional_int(arguments.get("author_id"), "author_id")
|
||||||
|
statement = select(AudiobookSeries).order_by(AudiobookSeries.name).limit(self.config.max_tool_results)
|
||||||
|
if terms := query_terms(query):
|
||||||
|
statement = statement.where(or_(*(AudiobookSeries.name.ilike(f"%{term}%") for term in terms)))
|
||||||
|
if author_id is not None:
|
||||||
|
statement = statement.where(AudiobookSeries.author_id == author_id)
|
||||||
|
|
||||||
|
series_rows = self.session.scalars(statement).all()
|
||||||
|
self.seen_series_ids.update(series.id for series in series_rows)
|
||||||
|
self.seen_author_ids.update(series.author_id for series in series_rows)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": series.id,
|
||||||
|
"name": series.name,
|
||||||
|
"author_id": series.author_id,
|
||||||
|
"author": series.author.name,
|
||||||
|
}
|
||||||
|
for series in series_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
def run_search_books(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||||
|
"""Search books from tool arguments and remember returned ids."""
|
||||||
|
query = required_string(arguments, "query")
|
||||||
|
author_id = optional_int(arguments.get("author_id"), "author_id")
|
||||||
|
series_id = optional_int(arguments.get("series_id"), "series_id")
|
||||||
|
statement = select(Audiobook).order_by(Audiobook.title).limit(self.config.max_tool_results)
|
||||||
|
if terms := query_terms(query):
|
||||||
|
statement = statement.where(or_(*(Audiobook.title.ilike(f"%{term}%") for term in terms)))
|
||||||
|
if author_id is not None:
|
||||||
|
statement = statement.where(Audiobook.author_id == author_id)
|
||||||
|
if series_id is not None:
|
||||||
|
statement = statement.where(Audiobook.series_id == series_id)
|
||||||
|
|
||||||
|
books = self.session.scalars(statement).all()
|
||||||
|
self.seen_book_ids.update(book.id for book in books)
|
||||||
|
self.seen_author_ids.update(book.author_id for book in books)
|
||||||
|
self.seen_series_ids.update(book.series_id for book in books if book.series_id is not None)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": book.id,
|
||||||
|
"title": book.title,
|
||||||
|
"author_id": book.author_id,
|
||||||
|
"author": book.author.name,
|
||||||
|
"series_id": book.series_id,
|
||||||
|
"series": book.series.name if book.series else self.config.standalone_series,
|
||||||
|
"series_index": book.series_index,
|
||||||
|
}
|
||||||
|
for book in books
|
||||||
|
]
|
||||||
|
|
||||||
|
def run_ensure_author(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||||
|
"""Ensure an author from tool arguments and return a tool result."""
|
||||||
|
name = normalize_catalog_slug(required_string(arguments, "name"))
|
||||||
|
validate_catalog_slug(name, "author")
|
||||||
|
author = self.session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
|
||||||
|
action = "existing"
|
||||||
|
if author is None:
|
||||||
|
author = AudiobookAuthor(name=name)
|
||||||
|
self.session.add(author)
|
||||||
|
self.session.flush()
|
||||||
|
self.created_author_ids.add(author.id)
|
||||||
|
action = "created"
|
||||||
|
|
||||||
|
self.seen_author_ids.add(author.id)
|
||||||
|
return [{"id": author.id, "name": author.name, "action": action}]
|
||||||
|
|
||||||
|
def run_ensure_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||||
|
"""Ensure a series from tool arguments and return a tool result."""
|
||||||
|
name = normalize_catalog_slug(required_string(arguments, "name"))
|
||||||
|
author_id = required_int(arguments, "author_id")
|
||||||
|
validate_catalog_slug(name, "series")
|
||||||
|
author = self.required_author(author_id)
|
||||||
|
series = self.session.scalar(
|
||||||
|
select(AudiobookSeries).where(
|
||||||
|
AudiobookSeries.name == name,
|
||||||
|
AudiobookSeries.author_id == author.id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
action = "existing"
|
||||||
|
if series is None:
|
||||||
|
series = AudiobookSeries(name=name, author=author)
|
||||||
|
self.session.add(series)
|
||||||
|
self.session.flush()
|
||||||
|
self.created_series_ids.add(series.id)
|
||||||
|
action = "created"
|
||||||
|
|
||||||
|
self.seen_author_ids.add(author.id)
|
||||||
|
self.seen_series_ids.add(series.id)
|
||||||
|
return [self.series_result(series, action)]
|
||||||
|
|
||||||
|
def run_ensure_book(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
||||||
|
"""Ensure a book from tool arguments and return a tool result."""
|
||||||
|
title = required_string(arguments, "title")
|
||||||
|
author_id = required_int(arguments, "author_id")
|
||||||
|
series_id = optional_int(arguments.get("series_id"), "series_id")
|
||||||
|
series_index = required_int(arguments, "series_index")
|
||||||
|
ensured = self.ensure_book(title, author_id, series_id, series_index)
|
||||||
|
return [self.book_result(ensured.book, ensured.action)]
|
||||||
|
|
||||||
|
def ensure_book(
|
||||||
|
self,
|
||||||
|
title: str,
|
||||||
|
author_id: int,
|
||||||
|
series_id: int | None,
|
||||||
|
series_index: int,
|
||||||
|
) -> EnsuredBook:
|
||||||
|
"""Return an existing book row, or create it after validating ownership."""
|
||||||
|
title = normalize_title_slug(title)
|
||||||
|
validate_title_slug(title)
|
||||||
|
author = self.required_author(author_id)
|
||||||
|
series = None
|
||||||
|
if series_id is None:
|
||||||
|
if series_index != 0:
|
||||||
|
msg = "standalone books must use series_index 0"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
else:
|
||||||
|
series = self.required_series(series_id)
|
||||||
|
if series.author_id != author.id:
|
||||||
|
msg = f"series_id {series_id} does not belong to author_id {author_id}"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
if series_index <= 0:
|
||||||
|
msg = "series books must use a positive series_index"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
|
||||||
|
statement = select(Audiobook).where(
|
||||||
|
Audiobook.title == title,
|
||||||
|
Audiobook.author_id == author.id,
|
||||||
|
)
|
||||||
|
if series is None:
|
||||||
|
statement = statement.where(Audiobook.series_id.is_(None))
|
||||||
|
else:
|
||||||
|
statement = statement.where(Audiobook.series_id == series.id)
|
||||||
|
book = self.session.scalar(statement)
|
||||||
|
if book is None:
|
||||||
|
book = Audiobook(title=title, author=author, series=series, series_index=series_index)
|
||||||
|
self.session.add(book)
|
||||||
|
self.session.flush()
|
||||||
|
self.created_book_ids.add(book.id)
|
||||||
|
action = "created"
|
||||||
|
else:
|
||||||
|
action = "existing"
|
||||||
|
|
||||||
|
self.seen_book_ids.add(book.id)
|
||||||
|
self.seen_author_ids.add(author.id)
|
||||||
|
if book.series_id is not None:
|
||||||
|
self.seen_series_ids.add(book.series_id)
|
||||||
|
return EnsuredBook(book=book, action=action)
|
||||||
|
|
||||||
|
def required_author(self, author_id: int) -> AudiobookAuthor:
|
||||||
|
"""Return an author or fail metadata resolution."""
|
||||||
|
author = self.get_author(author_id)
|
||||||
|
if author is None:
|
||||||
|
msg = f"author_id {author_id} does not exist"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
return author
|
||||||
|
|
||||||
|
def required_series(self, series_id: int) -> AudiobookSeries:
|
||||||
|
"""Return a series or fail metadata resolution."""
|
||||||
|
series = self.get_series(series_id)
|
||||||
|
if series is None:
|
||||||
|
msg = f"series_id {series_id} does not exist"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
return series
|
||||||
|
|
||||||
|
def series_result(self, series: AudiobookSeries, action: str) -> dict[str, object]:
|
||||||
|
"""Build a normalized series tool result."""
|
||||||
|
return {
|
||||||
|
"id": series.id,
|
||||||
|
"name": series.name,
|
||||||
|
"author_id": series.author_id,
|
||||||
|
"author": series.author.name,
|
||||||
|
"action": action,
|
||||||
|
}
|
||||||
|
|
||||||
|
def book_result(self, book: Audiobook, action: str) -> dict[str, object]:
|
||||||
|
"""Build a normalized book tool result."""
|
||||||
|
return {
|
||||||
|
"id": book.id,
|
||||||
|
"title": book.title,
|
||||||
|
"author_id": book.author_id,
|
||||||
|
"author": book.author.name,
|
||||||
|
"series_id": book.series_id,
|
||||||
|
"series": book.series.name if book.series else self.config.standalone_series,
|
||||||
|
"series_index": book.series_index,
|
||||||
|
"action": action,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_tool_calls(
|
||||||
|
messages: list[dict[str, object]],
|
||||||
|
message: dict[str, object],
|
||||||
|
tool_calls: list[tuple[str, dict[str, object]]],
|
||||||
|
registry: CatalogToolRegistry,
|
||||||
|
log_path: Path,
|
||||||
|
write_log: LogWriter,
|
||||||
|
) -> str | None:
|
||||||
|
"""Run tool calls, append tool messages, and return fatal error text when stopped."""
|
||||||
|
messages.append(message)
|
||||||
|
for tool_name, arguments in tool_calls:
|
||||||
|
try:
|
||||||
|
tool_result = registry.run(tool_name, arguments)
|
||||||
|
except MetadataResolutionError as error:
|
||||||
|
if is_fatal_tool_error(error):
|
||||||
|
return str(error)
|
||||||
|
write_log(log_path, "tool_error", tool=tool_name, arguments=arguments, error=str(error))
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"content": json.dumps({"error": str(error)}, sort_keys=True),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"content": json.dumps(tool_result, sort_keys=True),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_calls(message: dict[str, object]) -> list[tuple[str, dict[str, object]]]:
|
||||||
|
"""Parse Ollama tool calls from a response message."""
|
||||||
|
raw_tool_calls = message.get("tool_calls") or []
|
||||||
|
if not isinstance(raw_tool_calls, list):
|
||||||
|
msg = "tool_calls must be a list"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
|
||||||
|
tool_calls = []
|
||||||
|
for raw_call in raw_tool_calls:
|
||||||
|
if not isinstance(raw_call, dict):
|
||||||
|
msg = "tool call must be an object"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
function = raw_call.get("function")
|
||||||
|
if not isinstance(function, dict):
|
||||||
|
msg = "tool call is missing function"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
name = function.get("name")
|
||||||
|
if not isinstance(name, str) or not name:
|
||||||
|
msg = "tool call is missing function name"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
arguments = parse_tool_arguments(function.get("arguments", {}))
|
||||||
|
tool_calls.append((name, arguments))
|
||||||
|
return tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_arguments(raw_arguments: object) -> dict[str, object]:
|
||||||
|
"""Parse tool call arguments returned by Ollama."""
|
||||||
|
if isinstance(raw_arguments, dict):
|
||||||
|
return {str(key): value for key, value in raw_arguments.items()}
|
||||||
|
if isinstance(raw_arguments, str):
|
||||||
|
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return {str(key): value for key, value in parsed.items()}
|
||||||
|
msg = "tool arguments must be an object"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_title_slug(title: str) -> None:
|
||||||
|
"""Validate a canonical book title slug."""
|
||||||
|
if not TITLE_SLUG_PATTERN.fullmatch(title):
|
||||||
|
msg = f"title slug is invalid: {title}"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_catalog_slug(value: str, label: str) -> None:
|
||||||
|
"""Validate a canonical catalog slug."""
|
||||||
|
if not CATALOG_SLUG_PATTERN.fullmatch(value):
|
||||||
|
msg = f"{label} slug is invalid: {value}"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_catalog_slug(value: str) -> str:
|
||||||
|
"""Normalize noisy catalog names into lower snake-case slugs."""
|
||||||
|
return re.sub(r"[^a-z0-9]+", "_", value.strip().casefold()).strip("_")
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_title_slug(value: str) -> str:
|
||||||
|
"""Normalize noisy book titles into lower kebab-case slugs."""
|
||||||
|
return re.sub(r"[^a-z0-9]+", "-", value.strip().casefold()).strip("-")
|
||||||
|
|
||||||
|
|
||||||
|
def is_fatal_tool_error(error: MetadataResolutionError) -> bool:
|
||||||
|
"""Return whether a tool error should stop the agent immediately."""
|
||||||
|
message = str(error)
|
||||||
|
return message.startswith(
|
||||||
|
(
|
||||||
|
"Unknown audiobook metadata tool",
|
||||||
|
"Audiobook metadata tool is not enabled",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def query_terms(query: str) -> tuple[str, ...]:
|
||||||
|
"""Return text variants useful for matching noisy audiobook metadata."""
|
||||||
|
normalized = query.strip().casefold()
|
||||||
|
underscore_slug = normalize_catalog_slug(normalized)
|
||||||
|
hyphen_slug = normalize_title_slug(normalized)
|
||||||
|
return tuple(dict.fromkeys(term for term in (normalized, underscore_slug, hyphen_slug) if term))
|
||||||
|
|
||||||
|
|
||||||
|
def required_string(data: dict[str, object], key: str) -> str:
|
||||||
|
"""Read a required string field."""
|
||||||
|
value = data.get(key)
|
||||||
|
if not isinstance(value, str) or not value.strip():
|
||||||
|
msg = f"{key} must be a non-empty string"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def required_int(data: dict[str, object], key: str) -> int:
|
||||||
|
"""Read a required integer field."""
|
||||||
|
value = data.get(key)
|
||||||
|
if isinstance(value, bool) or not isinstance(value, int):
|
||||||
|
msg = f"{key} must be an integer"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def optional_int(value: object, key: str) -> int | None:
|
||||||
|
"""Read an optional integer field."""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, bool) or not isinstance(value, int):
|
||||||
|
msg = f"{key} must be an integer or null"
|
||||||
|
raise MetadataResolutionError(msg)
|
||||||
|
return value
|
||||||
@@ -4,30 +4,35 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from dataclasses import asdict, dataclass, is_dataclass, replace
|
from dataclasses import asdict, dataclass, is_dataclass, replace
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from sqlalchemy import or_, select
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from python.common import utcnow
|
from python.common import utcnow
|
||||||
from python.orm.richie import Audiobook, AudiobookAuthor, AudiobookSeries
|
from python.tools.audiobook.llm_tool_calling import (
|
||||||
|
CatalogToolRegistry,
|
||||||
|
MetadataResolutionError,
|
||||||
|
normalize_title_slug,
|
||||||
|
optional_int,
|
||||||
|
parse_tool_calls,
|
||||||
|
required_int,
|
||||||
|
required_string,
|
||||||
|
run_tool_calls,
|
||||||
|
validate_catalog_slug,
|
||||||
|
validate_title_slug,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
CATALOG_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:_[a-z0-9]+)*$")
|
from python.orm.richie import AudiobookAuthor
|
||||||
|
|
||||||
FENCED_JSON_PATTERN = re.compile(r"^```(?:json)?\s*(?P<json>.*?)\s*```$", re.IGNORECASE | re.DOTALL)
|
FENCED_JSON_PATTERN = re.compile(r"^```(?:json)?\s*(?P<json>.*?)\s*```$", re.IGNORECASE | re.DOTALL)
|
||||||
TITLE_SLUG_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataResolutionError(ValueError):
|
|
||||||
"""Metadata resolution failed validation."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -91,14 +96,6 @@ class ResolvedBookFields:
|
|||||||
series_index: int
|
series_index: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class EnsuredBook:
|
|
||||||
"""Book row plus whether it was created."""
|
|
||||||
|
|
||||||
book: Audiobook
|
|
||||||
action: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class AgentStepResult:
|
class AgentStepResult:
|
||||||
"""Outcome from one model response."""
|
"""Outcome from one model response."""
|
||||||
@@ -118,7 +115,7 @@ def standard_book_metadata(
|
|||||||
) -> StandardBookMetadata:
|
) -> StandardBookMetadata:
|
||||||
"""Resolve canonical audiobook metadata with the configured Ollama Cloud model."""
|
"""Resolve canonical audiobook metadata with the configured Ollama Cloud model."""
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
registry = CatalogToolRegistry(session, log_path, config)
|
registry = CatalogToolRegistry(session, log_path, config, write_agent_log)
|
||||||
agent = AudiobookMetadataAgent(
|
agent = AudiobookMetadataAgent(
|
||||||
registry=registry, log_path=log_path, ollama_api_key=ollama_api_key, config=config
|
registry=registry, log_path=log_path, ollama_api_key=ollama_api_key, config=config
|
||||||
)
|
)
|
||||||
@@ -126,388 +123,15 @@ def standard_book_metadata(
|
|||||||
if metadata.needs_review:
|
if metadata.needs_review:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
else:
|
else:
|
||||||
registry.prune_unused_created_rows(metadata)
|
registry.prune_unused_created_rows(
|
||||||
|
author_id=metadata.author_id,
|
||||||
|
book_id=metadata.book_id,
|
||||||
|
series_id=metadata.series_id,
|
||||||
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
class CatalogToolRegistry:
|
|
||||||
"""Controlled catalog tools exposed to the metadata model."""
|
|
||||||
|
|
||||||
def __init__(self, session: Session, log_path: Path, config: AgentConfig) -> None:
|
|
||||||
"""Create a registry bound to one database session and audit log."""
|
|
||||||
self._session = session
|
|
||||||
self._log_path = log_path
|
|
||||||
self._config = config
|
|
||||||
self.seen_author_ids: set[int] = set()
|
|
||||||
self.seen_series_ids: set[int] = set()
|
|
||||||
self.seen_book_ids: set[int] = set()
|
|
||||||
self.created_author_ids: set[int] = set()
|
|
||||||
self.created_series_ids: set[int] = set()
|
|
||||||
self.created_book_ids: set[int] = set()
|
|
||||||
|
|
||||||
def tool_schemas(self) -> list[dict[str, object]]:
|
|
||||||
"""Return Ollama tool schemas."""
|
|
||||||
schemas = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "search_authors",
|
|
||||||
"description": "Search canonical audiobook authors by slug or noisy source text.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "search_series",
|
|
||||||
"description": "Search canonical audiobook series by slug or noisy source text.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string"},
|
|
||||||
"author_id": {"type": ["integer", "null"]},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "search_books",
|
|
||||||
"description": "Search canonical audiobook titles with optional author and series filters.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {"type": "string"},
|
|
||||||
"author_id": {"type": ["integer", "null"]},
|
|
||||||
"series_id": {"type": ["integer", "null"]},
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "ensure_author",
|
|
||||||
"description": "Normalize an author name to a catalog slug, then return or create that author.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"name": {"type": "string"}},
|
|
||||||
"required": ["name"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "ensure_series",
|
|
||||||
"description": "Normalize a series name to a catalog slug, then return or create it for an author.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"name": {"type": "string"},
|
|
||||||
"author_id": {"type": "integer"},
|
|
||||||
},
|
|
||||||
"required": ["name", "author_id"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "ensure_book",
|
|
||||||
"description": "Normalize a title to a book slug, then return or create it for an author/series.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"title": {"type": "string"},
|
|
||||||
"author_id": {"type": "integer"},
|
|
||||||
"series_id": {"type": ["integer", "null"]},
|
|
||||||
"series_index": {"type": "integer"},
|
|
||||||
},
|
|
||||||
"required": ["title", "author_id", "series_id", "series_index"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
enabled_tool_names = set(self._config.tool_names)
|
|
||||||
return [schema for schema in schemas if schema["function"]["name"] in enabled_tool_names]
|
|
||||||
|
|
||||||
def run(self, name: str, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Run a single read-only catalog tool."""
|
|
||||||
handlers = {
|
|
||||||
"search_authors": self.run_search_authors,
|
|
||||||
"search_series": self.run_search_series,
|
|
||||||
"search_books": self.run_search_books,
|
|
||||||
"ensure_author": self.run_ensure_author,
|
|
||||||
"ensure_series": self.run_ensure_series,
|
|
||||||
"ensure_book": self.run_ensure_book,
|
|
||||||
}
|
|
||||||
handler = handlers.get(name)
|
|
||||||
if handler is None:
|
|
||||||
write_agent_log(self._log_path, "tool_error", tool=name, arguments=arguments, error="unknown_tool")
|
|
||||||
msg = f"Unknown audiobook metadata tool: {name}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
if name not in self._config.tool_names:
|
|
||||||
write_agent_log(self._log_path, "tool_error", tool=name, arguments=arguments, error="tool_not_enabled")
|
|
||||||
msg = f"Audiobook metadata tool is not enabled: {name}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
started = time.perf_counter()
|
|
||||||
write_agent_log(self._log_path, "tool_call", tool=name, arguments=arguments)
|
|
||||||
result = handler(arguments)
|
|
||||||
duration_ms = round((time.perf_counter() - started) * 1000, 3)
|
|
||||||
write_agent_log(
|
|
||||||
self._log_path,
|
|
||||||
"tool_result",
|
|
||||||
tool=name,
|
|
||||||
duration_ms=duration_ms,
|
|
||||||
result_count=len(result),
|
|
||||||
preview=result[:3],
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_author(self, author_id: int) -> AudiobookAuthor | None:
|
|
||||||
"""Return an author by id."""
|
|
||||||
return self._session.get(AudiobookAuthor, author_id)
|
|
||||||
|
|
||||||
def get_book(self, book_id: int) -> Audiobook | None:
|
|
||||||
"""Return a book by id."""
|
|
||||||
return self._session.get(Audiobook, book_id)
|
|
||||||
|
|
||||||
def get_series(self, series_id: int) -> AudiobookSeries | None:
|
|
||||||
"""Return a series by id."""
|
|
||||||
return self._session.get(AudiobookSeries, series_id)
|
|
||||||
|
|
||||||
def prune_unused_created_rows(self, metadata: StandardBookMetadata) -> None:
|
|
||||||
"""Remove catalog rows created during this run but not used by final metadata."""
|
|
||||||
used_book_ids = {metadata.book_id} if metadata.book_id is not None else set()
|
|
||||||
for book_id in self.created_book_ids - used_book_ids:
|
|
||||||
if book := self.get_book(book_id):
|
|
||||||
self._session.delete(book)
|
|
||||||
|
|
||||||
self._session.flush()
|
|
||||||
used_series_ids = {metadata.series_id} if metadata.series_id is not None else set()
|
|
||||||
for series_id in self.created_series_ids - used_series_ids:
|
|
||||||
series = self.get_series(series_id)
|
|
||||||
if series and not series.books:
|
|
||||||
self._session.delete(series)
|
|
||||||
|
|
||||||
self._session.flush()
|
|
||||||
for author_id in self.created_author_ids - {metadata.author_id}:
|
|
||||||
author = self.get_author(author_id)
|
|
||||||
if author and not author.books and not author.series:
|
|
||||||
self._session.delete(author)
|
|
||||||
|
|
||||||
def run_search_authors(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Search authors from tool arguments and remember returned ids."""
|
|
||||||
query = required_string(arguments, "query")
|
|
||||||
statement = select(AudiobookAuthor).order_by(AudiobookAuthor.name).limit(self._config.max_tool_results)
|
|
||||||
if terms := query_terms(query):
|
|
||||||
statement = statement.where(or_(*(AudiobookAuthor.name.ilike(f"%{term}%") for term in terms)))
|
|
||||||
|
|
||||||
authors = self._session.scalars(statement).all()
|
|
||||||
self.seen_author_ids.update(author.id for author in authors)
|
|
||||||
return [{"id": author.id, "name": author.name} for author in authors]
|
|
||||||
|
|
||||||
def run_search_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Search series from tool arguments and remember returned ids."""
|
|
||||||
query = required_string(arguments, "query")
|
|
||||||
author_id = optional_int(arguments.get("author_id"), "author_id")
|
|
||||||
statement = select(AudiobookSeries).order_by(AudiobookSeries.name).limit(self._config.max_tool_results)
|
|
||||||
if terms := query_terms(query):
|
|
||||||
statement = statement.where(or_(*(AudiobookSeries.name.ilike(f"%{term}%") for term in terms)))
|
|
||||||
if author_id is not None:
|
|
||||||
statement = statement.where(AudiobookSeries.author_id == author_id)
|
|
||||||
|
|
||||||
series_rows = self._session.scalars(statement).all()
|
|
||||||
self.seen_series_ids.update(series.id for series in series_rows)
|
|
||||||
self.seen_author_ids.update(series.author_id for series in series_rows)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": series.id,
|
|
||||||
"name": series.name,
|
|
||||||
"author_id": series.author_id,
|
|
||||||
"author": series.author.name,
|
|
||||||
}
|
|
||||||
for series in series_rows
|
|
||||||
]
|
|
||||||
|
|
||||||
def run_search_books(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Search books from tool arguments and remember returned ids."""
|
|
||||||
query = required_string(arguments, "query")
|
|
||||||
author_id = optional_int(arguments.get("author_id"), "author_id")
|
|
||||||
series_id = optional_int(arguments.get("series_id"), "series_id")
|
|
||||||
statement = select(Audiobook).order_by(Audiobook.title).limit(self._config.max_tool_results)
|
|
||||||
if terms := query_terms(query):
|
|
||||||
statement = statement.where(or_(*(Audiobook.title.ilike(f"%{term}%") for term in terms)))
|
|
||||||
if author_id is not None:
|
|
||||||
statement = statement.where(Audiobook.author_id == author_id)
|
|
||||||
if series_id is not None:
|
|
||||||
statement = statement.where(Audiobook.series_id == series_id)
|
|
||||||
|
|
||||||
books = self._session.scalars(statement).all()
|
|
||||||
self.seen_book_ids.update(book.id for book in books)
|
|
||||||
self.seen_author_ids.update(book.author_id for book in books)
|
|
||||||
self.seen_series_ids.update(book.series_id for book in books if book.series_id is not None)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": book.id,
|
|
||||||
"title": book.title,
|
|
||||||
"author_id": book.author_id,
|
|
||||||
"author": book.author.name,
|
|
||||||
"series_id": book.series_id,
|
|
||||||
"series": book.series.name if book.series else self._config.standalone_series,
|
|
||||||
"series_index": book.series_index,
|
|
||||||
}
|
|
||||||
for book in books
|
|
||||||
]
|
|
||||||
|
|
||||||
def run_ensure_author(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Ensure an author from tool arguments and return a tool result."""
|
|
||||||
name = normalize_catalog_slug(required_string(arguments, "name"))
|
|
||||||
validate_catalog_slug(name, "author")
|
|
||||||
author = self._session.scalar(select(AudiobookAuthor).where(AudiobookAuthor.name == name))
|
|
||||||
action = "existing"
|
|
||||||
if author is None:
|
|
||||||
author = AudiobookAuthor(name=name)
|
|
||||||
self._session.add(author)
|
|
||||||
self._session.flush()
|
|
||||||
self.created_author_ids.add(author.id)
|
|
||||||
action = "created"
|
|
||||||
|
|
||||||
self.seen_author_ids.add(author.id)
|
|
||||||
return [{"id": author.id, "name": author.name, "action": action}]
|
|
||||||
|
|
||||||
def run_ensure_series(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Ensure a series from tool arguments and return a tool result."""
|
|
||||||
name = normalize_catalog_slug(required_string(arguments, "name"))
|
|
||||||
author_id = required_int(arguments, "author_id")
|
|
||||||
validate_catalog_slug(name, "series")
|
|
||||||
author = self.required_author(author_id)
|
|
||||||
series = self._session.scalar(
|
|
||||||
select(AudiobookSeries).where(
|
|
||||||
AudiobookSeries.name == name,
|
|
||||||
AudiobookSeries.author_id == author.id,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
action = "existing"
|
|
||||||
if series is None:
|
|
||||||
series = AudiobookSeries(name=name, author=author)
|
|
||||||
self._session.add(series)
|
|
||||||
self._session.flush()
|
|
||||||
self.created_series_ids.add(series.id)
|
|
||||||
action = "created"
|
|
||||||
|
|
||||||
self.seen_author_ids.add(author.id)
|
|
||||||
self.seen_series_ids.add(series.id)
|
|
||||||
return [self.series_result(series, action)]
|
|
||||||
|
|
||||||
def run_ensure_book(self, arguments: dict[str, object]) -> list[dict[str, object]]:
|
|
||||||
"""Ensure a book from tool arguments and return a tool result."""
|
|
||||||
title = required_string(arguments, "title")
|
|
||||||
author_id = required_int(arguments, "author_id")
|
|
||||||
series_id = optional_int(arguments.get("series_id"), "series_id")
|
|
||||||
series_index = required_int(arguments, "series_index")
|
|
||||||
ensured = self.ensure_book(title, author_id, series_id, series_index)
|
|
||||||
return [self.book_result(ensured.book, ensured.action)]
|
|
||||||
|
|
||||||
def ensure_book(
|
|
||||||
self,
|
|
||||||
title: str,
|
|
||||||
author_id: int,
|
|
||||||
series_id: int | None,
|
|
||||||
series_index: int,
|
|
||||||
) -> EnsuredBook:
|
|
||||||
"""Return an existing book row, or create it after validating ownership."""
|
|
||||||
title = normalize_title_slug(title)
|
|
||||||
validate_title_slug(title)
|
|
||||||
author = self.required_author(author_id)
|
|
||||||
series = None
|
|
||||||
if series_id is None:
|
|
||||||
if series_index != 0:
|
|
||||||
msg = "standalone books must use series_index 0"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
else:
|
|
||||||
series = self.required_series(series_id)
|
|
||||||
if series.author_id != author.id:
|
|
||||||
msg = f"series_id {series_id} does not belong to author_id {author_id}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
if series_index <= 0:
|
|
||||||
msg = "series books must use a positive series_index"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
statement = select(Audiobook).where(
|
|
||||||
Audiobook.title == title,
|
|
||||||
Audiobook.author_id == author.id,
|
|
||||||
)
|
|
||||||
if series is None:
|
|
||||||
statement = statement.where(Audiobook.series_id.is_(None))
|
|
||||||
else:
|
|
||||||
statement = statement.where(Audiobook.series_id == series.id)
|
|
||||||
book = self._session.scalar(statement)
|
|
||||||
if book is None:
|
|
||||||
book = Audiobook(title=title, author=author, series=series, series_index=series_index)
|
|
||||||
self._session.add(book)
|
|
||||||
self._session.flush()
|
|
||||||
self.created_book_ids.add(book.id)
|
|
||||||
action = "created"
|
|
||||||
else:
|
|
||||||
action = "existing"
|
|
||||||
|
|
||||||
self.seen_book_ids.add(book.id)
|
|
||||||
self.seen_author_ids.add(author.id)
|
|
||||||
if book.series_id is not None:
|
|
||||||
self.seen_series_ids.add(book.series_id)
|
|
||||||
return EnsuredBook(book=book, action=action)
|
|
||||||
|
|
||||||
def required_author(self, author_id: int) -> AudiobookAuthor:
|
|
||||||
"""Return an author or fail metadata resolution."""
|
|
||||||
author = self.get_author(author_id)
|
|
||||||
if author is None:
|
|
||||||
msg = f"author_id {author_id} does not exist"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return author
|
|
||||||
|
|
||||||
def required_series(self, series_id: int) -> AudiobookSeries:
|
|
||||||
"""Return a series or fail metadata resolution."""
|
|
||||||
series = self.get_series(series_id)
|
|
||||||
if series is None:
|
|
||||||
msg = f"series_id {series_id} does not exist"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return series
|
|
||||||
|
|
||||||
def series_result(self, series: AudiobookSeries, action: str) -> dict[str, object]:
|
|
||||||
"""Build a normalized series tool result."""
|
|
||||||
return {
|
|
||||||
"id": series.id,
|
|
||||||
"name": series.name,
|
|
||||||
"author_id": series.author_id,
|
|
||||||
"author": series.author.name,
|
|
||||||
"action": action,
|
|
||||||
}
|
|
||||||
|
|
||||||
def book_result(self, book: Audiobook, action: str) -> dict[str, object]:
|
|
||||||
"""Build a normalized book tool result."""
|
|
||||||
return {
|
|
||||||
"id": book.id,
|
|
||||||
"title": book.title,
|
|
||||||
"author_id": book.author_id,
|
|
||||||
"author": book.author.name,
|
|
||||||
"series_id": book.series_id,
|
|
||||||
"series": book.series.name if book.series else self._config.standalone_series,
|
|
||||||
"series_index": book.series_index,
|
|
||||||
"action": action,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AudiobookMetadataAgent:
|
class AudiobookMetadataAgent:
|
||||||
"""Ollama-backed metadata resolver with a fixed local tool registry."""
|
"""Ollama-backed metadata resolver with a fixed local tool registry."""
|
||||||
|
|
||||||
@@ -571,45 +195,15 @@ class AudiobookMetadataAgent:
|
|||||||
should_continue=False,
|
should_continue=False,
|
||||||
)
|
)
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
return self.handle_tool_calls(messages, message, tool_calls, invalid_final_count)
|
fatal_error = run_tool_calls(messages, message, tool_calls, self._registry, self._log_path, write_agent_log)
|
||||||
return self.handle_final_message(messages, message, invalid_final_count)
|
if fatal_error is not None:
|
||||||
|
|
||||||
def handle_tool_calls(
|
|
||||||
self,
|
|
||||||
messages: list[dict[str, object]],
|
|
||||||
message: dict[str, object],
|
|
||||||
tool_calls: list[tuple[str, dict[str, object]]],
|
|
||||||
invalid_final_count: int,
|
|
||||||
) -> AgentStepResult:
|
|
||||||
"""Run tool calls from one model response and append tool results."""
|
|
||||||
messages.append(message)
|
|
||||||
for tool_name, arguments in tool_calls:
|
|
||||||
try:
|
|
||||||
tool_result = self._registry.run(tool_name, arguments)
|
|
||||||
except MetadataResolutionError as error:
|
|
||||||
if is_fatal_tool_error(error):
|
|
||||||
return AgentStepResult(
|
return AgentStepResult(
|
||||||
metadata=review_metadata(str(error), self._config),
|
metadata=review_metadata(fatal_error, self._config),
|
||||||
invalid_final_count=invalid_final_count,
|
invalid_final_count=invalid_final_count,
|
||||||
should_continue=False,
|
should_continue=False,
|
||||||
)
|
)
|
||||||
write_agent_log(self._log_path, "tool_error", tool=tool_name, arguments=arguments, error=str(error))
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"content": json.dumps({"error": str(error)}, sort_keys=True),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"content": json.dumps(tool_result, sort_keys=True),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
|
return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
|
||||||
|
return self.handle_final_message(messages, message, invalid_final_count)
|
||||||
|
|
||||||
def handle_final_message(
|
def handle_final_message(
|
||||||
self,
|
self,
|
||||||
@@ -905,43 +499,6 @@ def user_prompt(aax_file_name: str, metadata: dict[str, str]) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_tool_calls(message: dict[str, object]) -> list[tuple[str, dict[str, object]]]:
|
|
||||||
"""Parse Ollama tool calls from a response message."""
|
|
||||||
raw_tool_calls = message.get("tool_calls") or []
|
|
||||||
if not isinstance(raw_tool_calls, list):
|
|
||||||
msg = "tool_calls must be a list"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
for raw_call in raw_tool_calls:
|
|
||||||
if not isinstance(raw_call, dict):
|
|
||||||
msg = "tool call must be an object"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
function = raw_call.get("function")
|
|
||||||
if not isinstance(function, dict):
|
|
||||||
msg = "tool call is missing function"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
name = function.get("name")
|
|
||||||
if not isinstance(name, str) or not name:
|
|
||||||
msg = "tool call is missing function name"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
arguments = parse_tool_arguments(function.get("arguments", {}))
|
|
||||||
tool_calls.append((name, arguments))
|
|
||||||
return tool_calls
|
|
||||||
|
|
||||||
|
|
||||||
def parse_tool_arguments(raw_arguments: object) -> dict[str, object]:
|
|
||||||
"""Parse tool call arguments returned by Ollama."""
|
|
||||||
if isinstance(raw_arguments, dict):
|
|
||||||
return {str(key): value for key, value in raw_arguments.items()}
|
|
||||||
if isinstance(raw_arguments, str):
|
|
||||||
parsed = json.loads(raw_arguments) if raw_arguments else {}
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
return {str(key): value for key, value in parsed.items()}
|
|
||||||
msg = "tool arguments must be an object"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_final_json_content(content: str) -> object:
|
def parse_final_json_content(content: str) -> object:
|
||||||
"""Parse final model content, accepting bare or fenced JSON."""
|
"""Parse final model content, accepting bare or fenced JSON."""
|
||||||
stripped = content.strip()
|
stripped = content.strip()
|
||||||
@@ -967,41 +524,6 @@ def parse_final_metadata_fields(raw_metadata: object) -> FinalMetadataFields:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_title_slug(title: str) -> None:
|
|
||||||
"""Validate a canonical book title slug."""
|
|
||||||
if not TITLE_SLUG_PATTERN.fullmatch(title):
|
|
||||||
msg = f"title slug is invalid: {title}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_catalog_slug(value: str, label: str) -> None:
|
|
||||||
"""Validate a canonical catalog slug."""
|
|
||||||
if not CATALOG_SLUG_PATTERN.fullmatch(value):
|
|
||||||
msg = f"{label} slug is invalid: {value}"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_catalog_slug(value: str) -> str:
|
|
||||||
"""Normalize noisy catalog names into lower snake-case slugs."""
|
|
||||||
return re.sub(r"[^a-z0-9]+", "_", value.strip().casefold()).strip("_")
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_title_slug(value: str) -> str:
|
|
||||||
"""Normalize noisy book titles into lower kebab-case slugs."""
|
|
||||||
return re.sub(r"[^a-z0-9]+", "-", value.strip().casefold()).strip("-")
|
|
||||||
|
|
||||||
|
|
||||||
def is_fatal_tool_error(error: MetadataResolutionError) -> bool:
|
|
||||||
"""Return whether a tool error should stop the agent immediately."""
|
|
||||||
message = str(error)
|
|
||||||
return message.startswith(
|
|
||||||
(
|
|
||||||
"Unknown audiobook metadata tool",
|
|
||||||
"Audiobook metadata tool is not enabled",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def review_metadata(reason: str, config: AgentConfig) -> StandardBookMetadata:
|
def review_metadata(reason: str, config: AgentConfig) -> StandardBookMetadata:
|
||||||
"""Return a metadata result that must be reviewed manually."""
|
"""Return a metadata result that must be reviewed manually."""
|
||||||
return StandardBookMetadata(
|
return StandardBookMetadata(
|
||||||
@@ -1018,42 +540,6 @@ def review_metadata(reason: str, config: AgentConfig) -> StandardBookMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def query_terms(query: str) -> tuple[str, ...]:
|
|
||||||
"""Return text variants useful for matching noisy audiobook metadata."""
|
|
||||||
normalized = query.strip().casefold()
|
|
||||||
underscore_slug = normalize_catalog_slug(normalized)
|
|
||||||
hyphen_slug = normalize_title_slug(normalized)
|
|
||||||
return tuple(dict.fromkeys(term for term in (normalized, underscore_slug, hyphen_slug) if term))
|
|
||||||
|
|
||||||
|
|
||||||
def required_string(data: dict[str, object], key: str) -> str:
|
|
||||||
"""Read a required string field."""
|
|
||||||
value = data.get(key)
|
|
||||||
if not isinstance(value, str) or not value.strip():
|
|
||||||
msg = f"{key} must be a non-empty string"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return value.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def required_int(data: dict[str, object], key: str) -> int:
|
|
||||||
"""Read a required integer field."""
|
|
||||||
value = data.get(key)
|
|
||||||
if isinstance(value, bool) or not isinstance(value, int):
|
|
||||||
msg = f"{key} must be an integer"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def optional_int(value: object, key: str) -> int | None:
|
|
||||||
"""Read an optional integer field."""
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
if isinstance(value, bool) or not isinstance(value, int):
|
|
||||||
msg = f"{key} must be an integer or null"
|
|
||||||
raise MetadataResolutionError(msg)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def required_float(data: dict[str, object], key: str) -> float:
|
def required_float(data: dict[str, object], key: str) -> float:
|
||||||
"""Read a required float field."""
|
"""Read a required float field."""
|
||||||
value = data.get(key)
|
value = data.get(key)
|
||||||
|
|||||||
Reference in New Issue
Block a user