567 lines
22 KiB
Python
567 lines
22 KiB
Python
"""Resolve audiobook metadata with a controlled Ollama tool loop."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from dataclasses import asdict, dataclass, is_dataclass, replace
|
|
from os import PathLike
|
|
from typing import TYPE_CHECKING
|
|
|
|
import httpx
|
|
from sqlalchemy.orm import Session
|
|
|
|
from python.common import utcnow
|
|
from python.tools.audiobook.llm_tool_calling import (
|
|
CatalogToolRegistry,
|
|
MetadataResolutionError,
|
|
normalize_title_slug,
|
|
optional_int,
|
|
parse_tool_calls,
|
|
required_int,
|
|
required_string,
|
|
run_tool_calls,
|
|
validate_catalog_slug,
|
|
validate_title_slug,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
from sqlalchemy.engine import Engine
|
|
|
|
from python.orm.richie import AudiobookAuthor
|
|
|
|
# Matches a reply that is entirely one Markdown code fence (``` or ```json, case-insensitive)
# and captures the fenced body as group "json"; DOTALL lets that body span multiple lines.
FENCED_JSON_PATTERN = re.compile(r"^```(?:json)?\s*(?P<json>.*?)\s*```$", re.IGNORECASE | re.DOTALL)
|
|
|
|
|
|
@dataclass(frozen=True)
class AgentConfig:
    """Immutable runtime knobs for the audiobook metadata agent.

    Field order defines the positional constructor signature, so new fields belong at the end.
    """

    # Ollama model name sent with every chat request.
    model: str = "deepseek-v4-flash:cloud"
    # Chat endpoint the agent POSTs to.
    ollama_chat_url: str = "https://ollama.com/api/chat"
    # Per-request timeout handed to httpx, in seconds.
    http_timeout_seconds: int = 300
    # Tool-enabled turns allowed before a tool-less final answer is forced.
    max_agent_turns: int = 8
    # NOTE(review): not read in this module; presumably consumed by CatalogToolRegistry — confirm.
    max_tool_results: int = 10
    # Final answers whose confidence is below this value are flagged needs_review.
    min_confidence: float = 0.85
    # Invalid final answers that may be retried before review metadata is returned.
    invalid_final_retries: int = 1
    # Series slug reported for books that belong to no series.
    standalone_series: str = "standalone"
    # Tool names recorded in the "model_request" audit event when tools are enabled.
    tool_names: tuple[str, ...] = (
        "search_authors", "search_series", "search_books",
        "ensure_author", "ensure_series", "ensure_book",
    )
|
|
|
|
|
|
@dataclass(frozen=True)
class StandardBookMetadata:
    """Canonical, catalog-checked metadata for the final audiobook path.

    A placeholder instance with ``needs_review=True`` is produced whenever resolution fails.
    """

    # Catalog author id (0 in review placeholders).
    author_id: int
    # Author slug taken from the catalog row.
    author: str
    # Catalog book id, or None when no book row is referenced.
    book_id: int | None
    # Title slug.
    title: str
    # Catalog series id, or None for standalone books.
    series_id: int | None
    # Series slug, or the configured standalone slug.
    series: str
    # Position within the series; 0 for standalone books.
    series_index: int
    # Model-reported confidence between 0 and 1.
    confidence: float
    # True when a human must confirm the result before it is used.
    needs_review: bool
    # Short strings explaining what supports the answer.
    evidence: list[str]
|
|
|
|
|
|
@dataclass(frozen=True)
class FinalMetadataFields:
    """Typed view of the model's final JSON object after schema checks, before catalog checks."""

    # Author id the model selected.
    author_id: int
    # Existing book id, or None to request book resolution/creation.
    book_id: int | None
    # Title as written by the model (normalized to a slug later).
    title: str
    # Series id, or None for standalone books.
    series_id: int | None
    # Series position the model claimed.
    series_index: int
    # Confidence value already range-checked.
    confidence: float
    # Non-blank evidence strings.
    evidence: list[str]
|
|
|
|
|
|
@dataclass(frozen=True)
class ResolvedBookFields:
    """Book-related fields once an existing or ensured catalog book has been chosen."""

    # Catalog book id (None only if the resolved row carries none).
    book_id: int | None
    # Title slug from the catalog book row.
    title: str
    # Series id from the catalog book row, or None.
    series_id: int | None
    # Series position from the catalog book row.
    series_index: int
|
|
|
|
|
|
@dataclass(frozen=True)
class AgentStepResult:
    """What the agent loop should do after one model response."""

    # Final or review metadata when the loop should stop; None while it continues.
    metadata: StandardBookMetadata | None
    # Running count of rejected final answers, carried into the next turn.
    invalid_final_count: int
    # True to run another model turn.
    should_continue: bool
|
|
|
|
|
|
def standard_book_metadata(
    aax_file_name: str,
    aax_metadata_from_ffprobe: dict[str, str],
    engine: Engine,
    log_path: Path,
    ollama_api_key: str,
    config: AgentConfig,
) -> StandardBookMetadata:
    """Resolve canonical metadata for one AAX file using the configured Ollama Cloud model.

    All catalog writes made by the agent's tools share one session. A result that
    needs review rolls that session back. Otherwise the registry's
    ``prune_unused_created_rows`` is called with the chosen ids (presumably removing
    rows the run created but did not use) and the session is committed.
    """
    with Session(engine) as session:
        registry = CatalogToolRegistry(session, log_path, config, write_agent_log)
        agent = AudiobookMetadataAgent(
            registry=registry,
            log_path=log_path,
            ollama_api_key=ollama_api_key,
            config=config,
        )
        resolved = agent.run(aax_file_name, aax_metadata_from_ffprobe)

        if resolved.needs_review:
            # Discard everything the tools wrote; a human must confirm this book first.
            session.rollback()
            return resolved

        registry.prune_unused_created_rows(
            author_id=resolved.author_id,
            book_id=resolved.book_id,
            series_id=resolved.series_id,
        )
        session.commit()
    return resolved
|
|
|
|
|
|
class AudiobookMetadataAgent:
    """Ollama-backed metadata resolver with a fixed local tool registry.

    Each model turn either requests catalog tool calls, which ``run_tool_calls``
    executes against the registry, or returns final JSON that is validated against
    ids the registry has seen and rows it can load. Unrecoverable problems are
    returned as review metadata (``needs_review=True``) rather than raised; errors
    from the HTTP request to Ollama still propagate.
    """

    def __init__(
        self,
        *,
        registry: CatalogToolRegistry,
        log_path: Path,
        ollama_api_key: str,
        config: AgentConfig,
    ) -> None:
        """Create an Ollama metadata agent.

        Args:
            registry: Catalog tool registry that runs tool calls and tracks the
                author/series/book ids it has returned.
            log_path: JSONL audit log that requests, responses, and decisions are appended to.
            ollama_api_key: Bearer token sent with every chat request.
            config: Model, endpoint, turn-limit, retry, and confidence settings.
        """
        self._registry = registry
        self._log_path = log_path
        self._ollama_api_key = ollama_api_key
        self._config = config

    def run(self, aax_file_name: str, aax_metadata_from_ffprobe: dict[str, str]) -> StandardBookMetadata:
        """Resolve metadata for one AAX file.

        Runs up to ``config.max_agent_turns`` tool-enabled turns. If none of them ends
        the loop, one extra tool-less turn is requested via ``force_final_response``.
        """
        # Typed as dict[str, object] to match every callee that appends to the transcript.
        messages: list[dict[str, object]] = [
            {"role": "system", "content": system_prompt()},
            {"role": "user", "content": user_prompt(aax_file_name, aax_metadata_from_ffprobe)},
        ]
        invalid_final_count = 0
        result: StandardBookMetadata | None = None

        for turn in range(1, self._config.max_agent_turns + 1):
            step = self.run_step(messages, turn, invalid_final_count)
            invalid_final_count = step.invalid_final_count
            if step.should_continue:
                continue
            result = step.metadata
            break

        if result is None:
            return self.force_final_response(messages)
        return result

    def run_step(
        self,
        messages: list[dict[str, object]],
        turn: int,
        invalid_final_count: int,
    ) -> AgentStepResult:
        """Run one model turn and return the next agent-loop action.

        Tool calls are executed and the loop continues. A response without a message,
        unparsable tool calls, or a fatal tool error stops the loop with review
        metadata. A message without tool calls is treated as the final answer.
        """
        data = self.chat(messages, turn)
        message = data.get("message")
        if not isinstance(message, dict):
            return AgentStepResult(
                metadata=review_metadata("Ollama response did not include a message", self._config),
                invalid_final_count=invalid_final_count,
                should_continue=False,
            )

        try:
            tool_calls = parse_tool_calls(message)
        except (json.JSONDecodeError, MetadataResolutionError) as error:
            return AgentStepResult(
                metadata=review_metadata(str(error), self._config),
                invalid_final_count=invalid_final_count,
                should_continue=False,
            )
        if tool_calls:
            fatal_error = run_tool_calls(messages, message, tool_calls, self._registry, self._log_path, write_agent_log)
            if fatal_error is not None:
                return AgentStepResult(
                    metadata=review_metadata(fatal_error, self._config),
                    invalid_final_count=invalid_final_count,
                    should_continue=False,
                )
            return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)
        return self.handle_final_message(messages, message, invalid_final_count)

    def handle_final_message(
        self,
        messages: list[dict[str, object]],
        message: dict[str, object],
        invalid_final_count: int,
    ) -> AgentStepResult:
        """Validate a final model message or request one retry.

        An answer that fails JSON parsing or catalog validation is appended to
        ``messages`` as the assistant turn before ``handle_invalid_final`` decides
        whether to retry.
        """
        content = message.get("content")
        if not isinstance(content, str):
            return AgentStepResult(
                metadata=review_metadata("Ollama final response did not include string content", self._config),
                invalid_final_count=invalid_final_count,
                should_continue=False,
            )

        try:
            resolved = self.validate_final(parse_final_json_content(content))
        except (json.JSONDecodeError, MetadataResolutionError) as error:
            # Keep the rejected answer in the transcript; otherwise the retry prompt
            # ("Your previous final answer was invalid") refers to a turn the model never sees.
            messages.append({"role": "assistant", "content": content})
            return self.handle_invalid_final(messages, error, invalid_final_count)

        write_agent_log(self._log_path, "final_metadata", metadata=resolved)
        return AgentStepResult(metadata=resolved, invalid_final_count=invalid_final_count, should_continue=False)

    def handle_invalid_final(
        self,
        messages: list[dict[str, object]],
        error: json.JSONDecodeError | MetadataResolutionError,
        invalid_final_count: int,
    ) -> AgentStepResult:
        """Log an invalid final answer and either queue a retry or return review metadata.

        A retry is allowed while the incremented count is at most
        ``config.invalid_final_retries``; it appends a user message quoting the
        validation error and asks the loop to continue.
        """
        invalid_final_count += 1
        write_agent_log(
            self._log_path,
            "final_validation_error",
            error=str(error),
            invalid_final_count=invalid_final_count,
        )
        if invalid_final_count > self._config.invalid_final_retries:
            return AgentStepResult(
                metadata=review_metadata(str(error), self._config),
                invalid_final_count=invalid_final_count,
                should_continue=False,
            )
        messages.append(
            {
                "role": "user",
                "content": (
                    "Your previous final answer was invalid. Return only valid JSON matching the required "
                    f"schema. Validation error: {error}"
                ),
            },
        )
        return AgentStepResult(metadata=None, invalid_final_count=invalid_final_count, should_continue=True)

    def force_final_response(self, messages: list[dict[str, object]]) -> StandardBookMetadata:
        """Request a no-tool final answer after the normal turn limit.

        The answer is validated once; any failure yields review metadata instead of a retry.
        """
        messages.append({"role": "user", "content": forced_final_prompt()})
        write_agent_log(self._log_path, "forced_final_request", reason="max_turns")
        data = self.chat(messages, self._config.max_agent_turns + 1, tools_enabled=False)
        message = data.get("message")
        if not isinstance(message, dict):
            return review_metadata("Ollama forced final response did not include a message", self._config)
        content = message.get("content")
        if not isinstance(content, str):
            return review_metadata("Ollama forced final response did not include string content", self._config)
        try:
            resolved = self.validate_final(parse_final_json_content(content))
        except (json.JSONDecodeError, MetadataResolutionError) as error:
            return review_metadata(f"Ollama forced final response was invalid: {error}", self._config)
        write_agent_log(self._log_path, "final_metadata", metadata=resolved)
        return resolved

    def chat(self, messages: list[dict[str, object]], turn: int, *, tools_enabled: bool = True) -> dict[str, object]:
        """Send one non-streaming chat request to Ollama and log the request and response.

        Returns:
            The decoded response with string keys, or an empty dict when the body is
            JSON but not an object.

        Raises:
            httpx.HTTPStatusError: If Ollama responds with an error status; transport
                errors raised by ``httpx.post`` also propagate.
        """
        payload: dict[str, object] = {
            "model": self._config.model,
            "messages": messages,
            "stream": False,
            "options": {"temperature": 0},
        }
        # Same type in both branches (the config stores a tuple).
        tool_names: tuple[str, ...] = ()
        if tools_enabled:
            payload["tools"] = self._registry.tool_schemas()
            tool_names = self._config.tool_names
        write_agent_log(
            self._log_path,
            "model_request",
            model=self._config.model,
            turn=turn,
            message_count=len(messages),
            tool_names=tool_names,
            tools_enabled=tools_enabled,
        )
        write_agent_log(
            self._log_path,
            "llm_messages_sent",
            model=self._config.model,
            turn=turn,
            messages=messages,
            tools_enabled=tools_enabled,
        )
        response = httpx.post(
            self._config.ollama_chat_url,
            headers={"Authorization": f"Bearer {self._ollama_api_key}"},
            json=payload,
            timeout=self._config.http_timeout_seconds,
        )
        response.raise_for_status()
        raw_data = response.json()
        if not isinstance(raw_data, dict):
            return {}
        data = {str(key): value for key, value in raw_data.items()}
        message = data.get("message", {})
        content = message.get("content") if isinstance(message, dict) else ""
        write_agent_log(
            self._log_path,
            "llm_message_received",
            model=self._config.model,
            turn=turn,
            message=message,
        )
        write_agent_log(
            self._log_path,
            "model_response",
            model=self._config.model,
            turn=turn,
            has_tool_calls=bool(isinstance(message, dict) and message.get("tool_calls")),
            content_chars=len(content) if isinstance(content, str) else 0,
        )
        return data

    def validate_final(self, raw_metadata: object) -> StandardBookMetadata:
        """Validate final model metadata against catalog rows.

        The title is slug-normalized first. When ``book_id`` is None the registry's
        ``ensure_book`` is called, which may create a catalog book. ``needs_review``
        is set when confidence is below ``config.min_confidence``.

        Raises:
            MetadataResolutionError: If any id, slug, or series constraint fails.
        """
        fields = parse_final_metadata_fields(raw_metadata)
        fields = replace(fields, title=normalize_title_slug(fields.title))
        author = self.validate_author(fields.author_id)
        validate_title_slug(fields.title)
        book_fields = self.resolve_book_fields(fields)
        series = self.validate_series(fields.author_id, book_fields.series_id, book_fields.series_index)

        return StandardBookMetadata(
            author_id=fields.author_id,
            author=author.name,
            book_id=book_fields.book_id,
            title=book_fields.title,
            series_id=book_fields.series_id,
            series=series,
            series_index=book_fields.series_index,
            confidence=fields.confidence,
            needs_review=fields.confidence < self._config.min_confidence,
            evidence=fields.evidence,
        )

    def validate_author(self, author_id: int) -> AudiobookAuthor:
        """Validate that an author id was returned by a tool, exists, and has a valid slug."""
        if author_id not in self._registry.seen_author_ids:
            msg = f"author_id {author_id} was not returned by search_authors"
            raise MetadataResolutionError(msg)
        author = self._registry.get_author(author_id)
        if author is None:
            msg = f"author_id {author_id} does not exist"
            raise MetadataResolutionError(msg)
        validate_catalog_slug(author.name, "author")
        return author

    def resolve_book_fields(self, fields: FinalMetadataFields) -> ResolvedBookFields:
        """Resolve final book fields from a seen book id or an ensured book.

        For an existing book, the catalog row's title and series fields win over the
        model's values.
        """
        if fields.book_id is None:
            ensured = self._registry.ensure_book(
                fields.title,
                fields.author_id,
                fields.series_id,
                fields.series_index,
            )
            return ResolvedBookFields(
                book_id=ensured.book.id,
                title=ensured.book.title,
                series_id=ensured.book.series_id,
                series_index=ensured.book.series_index,
            )

        if fields.book_id not in self._registry.seen_book_ids:
            msg = f"book_id {fields.book_id} was not returned by search_books"
            raise MetadataResolutionError(msg)
        book = self._registry.get_book(fields.book_id)
        if book is None:
            msg = f"book_id {fields.book_id} does not exist"
            raise MetadataResolutionError(msg)
        if book.author_id != fields.author_id:
            msg = f"book_id {fields.book_id} does not belong to author_id {fields.author_id}"
            raise MetadataResolutionError(msg)
        return ResolvedBookFields(
            book_id=fields.book_id,
            title=book.title,
            series_id=book.series_id,
            series_index=book.series_index,
        )

    def validate_series(self, author_id: int, series_id: int | None, series_index: int) -> str:
        """Validate final series fields and return the canonical series slug.

        Standalone books (``series_id`` None) must use index 0 and get
        ``config.standalone_series``. Series books must reference a seen series owned
        by ``author_id`` and use a positive index.
        """
        if series_id is None:
            if series_index != 0:
                msg = "standalone books must use series_index 0"
                raise MetadataResolutionError(msg)
            return self._config.standalone_series

        if series_id not in self._registry.seen_series_ids:
            msg = f"series_id {series_id} was not returned by search_series"
            raise MetadataResolutionError(msg)
        series = self._registry.get_series(series_id)
        if series is None:
            msg = f"series_id {series_id} does not exist"
            raise MetadataResolutionError(msg)
        if series.author_id != author_id:
            msg = f"series_id {series_id} does not belong to author_id {author_id}"
            raise MetadataResolutionError(msg)
        if series_index <= 0:
            msg = "series books must use a positive series_index"
            raise MetadataResolutionError(msg)
        validate_catalog_slug(series.name, "series")
        return series.name
|
|
|
|
|
|
def write_agent_log(log_path: Path, event: str, **fields: object) -> None:
    """Append a single audit event as one JSON line to ``log_path``.

    The record holds a ``created`` timestamp, the ``event`` name, and each keyword
    field converted with ``json_log_value``; keys are sorted on output. Missing
    parent directories are created first.
    """
    log_path.parent.mkdir(parents=True, exist_ok=True)
    record: dict[str, object] = {"created": utcnow().isoformat(), "event": event}
    record.update({name: json_log_value(value) for name, value in fields.items()})
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(record, sort_keys=True) + "\n")
|
|
|
|
|
|
def json_log_value(value: object) -> object:
|
|
"""Return a JSON-serializable value for audit logs."""
|
|
if is_dataclass(value) and not isinstance(value, type):
|
|
return json_log_value(asdict(value))
|
|
if isinstance(value, dict):
|
|
return {str(key): json_log_value(item) for key, item in value.items()}
|
|
if isinstance(value, list | tuple):
|
|
return [json_log_value(item) for item in value]
|
|
if isinstance(value, set):
|
|
return [json_log_value(item) for item in sorted(value, key=str)]
|
|
if isinstance(value, PathLike):
|
|
return str(value)
|
|
return value
|
|
|
|
|
|
def system_prompt() -> str:
    """Return the stable system prompt.

    The text is an intro sentence, a blank line, a ``Rules:`` heading, and one
    newline-separated bullet per rule; it does not depend on any input.
    """
    intro = "You standardize Audible audiobook metadata against a private catalog."
    rules = (
        "- You must use the provided tools before returning final metadata.",
        "- Only use author_id, series_id, or book_id values returned by tools.",
        "- Return final metadata as JSON only. Do not wrap it in Markdown.",
        "- The final JSON object must contain author_id, book_id, title, series_id, series_index, confidence, and "
        "evidence.",
        "- title must be a canonical title slug using lower-case words separated by hyphens.",
        "- Use series_id null and series_index 0 for standalone books.",
        "- If you use a series_id, series_index must be an integer greater than or equal to 1.",
        "- Do not create publisher collections or author collections as series unless the book metadata clearly gives a\n"
        "numbered series.",
        "- Series belong to authors. Use a series_id only when it belongs to the selected author_id.",
        "- Always search for the author before creating one. If no exact author slug exists, call ensure_author.",
        "- Always search for a series with author_id before creating one. If no exact series slug exists, call "
        "ensure_series.",
        "- Always search for a book before creating one. If no exact title slug exists, call ensure_book.",
        "- If a tool returns an error, correct your tool arguments or final metadata before continuing.",
        "- confidence must be a number from 0 to 1.",
        "- evidence must be a short list of strings explaining which filename, tags, and catalog rows support the "
        "answer.",
    )
    return "\n".join((intro, "", "Rules:", *rules))
|
|
|
|
|
|
def forced_final_prompt() -> str:
    """Return the no-tools finalization prompt sent once the turn limit is reached.

    It tells the model to stop calling tools and to answer with ``book_id`` null
    when no book matched, relying on the validator to create the book.
    """
    sentences = (
        "Stop calling tools.",
        "Return final metadata as JSON only using the tool results already provided.",
        "If search_books returned no matching rows but author and series are known, use book_id null and resolve "
        "the title slug from the AAX filename and ffprobe tags.",
        "The validator will create the missing book.",
        "Use only author_id and series_id values returned by earlier tool results.",
    )
    return " ".join(sentences)
|
|
|
|
|
|
def user_prompt(aax_file_name: str, metadata: dict[str, str]) -> str:
    """Build the first user message describing the AAX file to resolve.

    The ffprobe tags are rendered as indented, key-sorted JSON. Non-ASCII characters
    are kept verbatim (``ensure_ascii=False``) so accented author names and titles
    reach the model as readable text instead of JSON unicode escapes.
    """
    tags_json = json.dumps(metadata, indent=2, sort_keys=True, ensure_ascii=False)
    return (
        "Resolve this Audible audiobook.\n\n"
        f"AAX file name: {aax_file_name}\n\n"
        "ffprobe format tags:\n"
        f"{tags_json}"
    )
|
|
|
|
|
|
def parse_final_json_content(content: str) -> object:
    """Decode the model's final content as JSON, unwrapping a single surrounding code fence.

    Raises:
        json.JSONDecodeError: If the (unwrapped) text is not valid JSON.
    """
    text = content.strip()
    fenced = FENCED_JSON_PATTERN.fullmatch(text)
    if fenced is not None:
        text = fenced.group("json").strip()
    return json.loads(text)
|
|
|
|
|
|
def parse_final_metadata_fields(raw_metadata: object) -> FinalMetadataFields:
    """Check the shape of the model's final JSON object and return its typed fields.

    Fields are read in a fixed order (author_id, book_id, title, series_id,
    series_index, confidence, evidence), so the first invalid one determines the error.

    Raises:
        MetadataResolutionError: If the value is not an object or a field is invalid.
    """
    if not isinstance(raw_metadata, dict):
        msg = "Final metadata must be a JSON object"
        raise MetadataResolutionError(msg)
    payload = {str(name): item for name, item in raw_metadata.items()}

    author_id = required_int(payload, "author_id")
    book_id = optional_int(payload.get("book_id"), "book_id")
    title = required_string(payload, "title")
    series_id = optional_int(payload.get("series_id"), "series_id")
    series_index = required_int(payload, "series_index")
    confidence = required_float(payload, "confidence")
    evidence = required_string_list(payload, "evidence")

    return FinalMetadataFields(
        author_id=author_id,
        book_id=book_id,
        title=title,
        series_id=series_id,
        series_index=series_index,
        confidence=confidence,
        evidence=evidence,
    )
|
|
|
|
|
|
def review_metadata(reason: str, config: AgentConfig) -> StandardBookMetadata:
    """Build a placeholder result that must be reviewed manually.

    No catalog ids are set (author_id 0, book_id and series_id None), the book is
    reported as standalone with index 0 and confidence 0, and ``reason`` is the only
    evidence.
    """
    return StandardBookMetadata(
        needs_review=True,
        confidence=0,
        evidence=[reason],
        author_id=0,
        author="unknown_author",
        book_id=None,
        title="unknown-title",
        series_id=None,
        series=config.standalone_series,
        series_index=0,
    )
|
|
|
|
|
|
def required_float(data: dict[str, object], key: str) -> float:
|
|
"""Read a required float field."""
|
|
value = data.get(key)
|
|
if isinstance(value, bool) or not isinstance(value, int | float):
|
|
msg = f"{key} must be a number"
|
|
raise MetadataResolutionError(msg)
|
|
confidence = float(value)
|
|
if confidence < 0 or confidence > 1:
|
|
msg = f"{key} must be between 0 and 1"
|
|
raise MetadataResolutionError(msg)
|
|
return confidence
|
|
|
|
|
|
def required_string_list(data: dict[str, object], key: str) -> list[str]:
    """Read a required, non-empty list of strings and return its stripped, non-blank items.

    Raises:
        MetadataResolutionError: If the value is not a non-empty list made only of
            strings, or if every item is blank after stripping whitespace.
    """
    value = data.get(key)
    if not (isinstance(value, list) and value and all(isinstance(item, str) for item in value)):
        msg = f"{key} must be a non-empty list of strings"
        raise MetadataResolutionError(msg)
    cleaned = [text for text in (item.strip() for item in value) if text]
    if not cleaned:
        msg = f"{key} must include at least one non-empty string"
        raise MetadataResolutionError(msg)
    return cleaned
|