191 lines
6.3 KiB
Python
191 lines
6.3 KiB
Python
"""EPUB ingestion into Richie DB."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import tiktoken
|
|
from sqlalchemy import or_, select
|
|
|
|
from python.ebook_search.epub_parse import parse_epub
|
|
from python.orm.richie import EbookChapter, EbookChunk, EbookSource
|
|
|
|
# Module logger; ingestion progress is emitted as "event key=value" messages.
logger = logging.getLogger(__name__)

# Default maximum size of one chunk, measured in tokens.
DEFAULT_CHUNK_TOKENS: int = 700
# Default number of tokens repeated between consecutive chunks.
DEFAULT_CHUNK_OVERLAP: int = 100
|
|
|
|
if TYPE_CHECKING:
|
|
from sqlalchemy.orm import Session
|
|
|
|
from python.ebook_search.config import EbookSearchConfig
|
|
from python.ebook_search.epub_parse import ParsedChapter
|
|
|
|
|
|
@dataclass(frozen=True)
class TextChunk:
    """Immutable piece of a document whose size is bounded in tokens.

    Produced by ``chunk_text``. Being a frozen dataclass, instances compare
    by value, are hashable, and reject attribute assignment.
    """

    # Decoded text of the chunk.
    text: str
    # Offset of the chunk's first token within the whole token sequence.
    token_start: int
    # Number of tokens covered by the chunk window.
    token_count: int
|
|
|
|
|
|
def chunk_text(
    text: str,
    *,
    chunk_tokens: int = DEFAULT_CHUNK_TOKENS,
    overlap_tokens: int = DEFAULT_CHUNK_OVERLAP,
    encoding_name: str = "cl100k_base",
) -> list[TextChunk]:
    """Split text into overlapping windows of at most ``chunk_tokens`` tokens.

    Consecutive windows start ``chunk_tokens - overlap_tokens`` tokens apart,
    so each window repeats the last ``overlap_tokens`` tokens of the previous
    one. Iteration stops at the first window that reaches the end of the
    token sequence, so no trailing window lies entirely inside an overlap.

    Args:
        text: Text to split.
        chunk_tokens: Maximum number of tokens per window; must be positive.
        overlap_tokens: Tokens shared by consecutive windows; must be
            non-negative and smaller than ``chunk_tokens``.
        encoding_name: Name of the tiktoken encoding used to tokenize ``text``.

    Returns:
        Chunks in document order. Chunk text is whitespace-stripped, and
        windows whose text is empty after stripping are omitted. Empty input
        yields an empty list.

    Raises:
        ValueError: If ``chunk_tokens`` or ``overlap_tokens`` is out of range.
    """
    if chunk_tokens <= 0:
        msg = "chunk_tokens must be positive"
        raise ValueError(msg)
    if overlap_tokens < 0 or overlap_tokens >= chunk_tokens:
        msg = "overlap_tokens must be non-negative and smaller than chunk_tokens"
        raise ValueError(msg)

    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    step = chunk_tokens - overlap_tokens

    chunks: list[TextChunk] = []
    for start in range(0, len(tokens), step):
        # start < len(tokens), so the window is never empty.
        window = tokens[start : start + chunk_tokens]
        # A token window may begin or end in the middle of a multi-byte UTF-8
        # character. tiktoken's default errors="replace" would store U+FFFD
        # there (and hash/index it); "ignore" drops only those edge
        # fragments. With overlap > 0, the cut character is intact in the
        # neighbouring window.
        piece = encoding.decode(window, errors="ignore").strip()
        if piece:
            chunks.append(TextChunk(text=piece, token_start=start, token_count=len(window)))
        if start + chunk_tokens >= len(tokens):
            # This window already reaches the end; another would only
            # repeat tokens from the overlap.
            break
    return chunks
|
|
|
|
|
|
def ingest_configured_paths(session: Session, config: EbookSearchConfig) -> int:
    """Ingest every EPUB found under the configured library paths.

    Each entry of ``config.library_paths`` is user-expanded and may be either
    a single ``.epub`` file or a directory searched recursively. Anything
    else is logged as missing and skipped.

    Args:
        session: Database session passed through to ``ingest_file``.
        config: Configuration providing ``library_paths``.

    Returns:
        Number of files for which ``ingest_file`` reported a database change.
    """
    count = 0
    for library_path in config.library_paths:
        path = Path(library_path).expanduser()
        logger.info("ebook_ingest_path_start path=%s", path)
        if path.is_file() and _is_epub(path):
            count += int(ingest_file(session, path))
        elif path.is_dir():
            for epub_path in _iter_epub_files(path):
                count += int(ingest_file(session, epub_path))
        else:
            logger.warning("ebook_ingest_path_missing path=%s", path)
    logger.info("ebook_ingest_paths_complete changed_files=%s configured_paths=%s", count, len(config.library_paths))
    return count


def _is_epub(path: Path) -> bool:
    """Return True when ``path`` has an ``.epub`` suffix, ignoring case."""
    return path.suffix.lower() == ".epub"


def _iter_epub_files(directory: Path) -> list[Path]:
    """Return regular ``.epub`` files anywhere below ``directory``, sorted.

    The suffix match is case-insensitive, to agree with the single-file check.
    Directories whose names end in ``.epub`` (e.g. unpacked EPUB bundles) are
    excluded, because ``ingest_file`` hashes the path by opening it as a file.
    """
    return sorted(candidate for candidate in directory.rglob("*") if candidate.is_file() and _is_epub(candidate))
|
|
|
|
|
|
def ingest_file(session: Session, path: Path) -> bool:
    """Ingest a single EPUB file and report whether the database changed.

    The path is user-expanded and resolved, hashed with SHA-256, and matched
    against stored sources by path or content hash.

    - If a matching source has the same hash, only its stored path, mtime and
      size are refreshed, and False is returned.
    - Otherwise any matching source is deleted. The EPUB is then parsed and
      a new source row is written, with one chapter row per parsed chapter
      and chunk rows for each chapter's text. True is returned.

    Args:
        session: Database session; this function flushes but does not commit.
        path: Location of the EPUB file.

    Returns:
        True when a new source was written, False when content was unchanged.
    """
    epub_path = path.expanduser().resolve()
    logger.info("ebook_ingest_file_start path=%s", epub_path)
    digest = sha256_file(epub_path)
    previous = find_existing_source(session, epub_path, digest)

    if previous is not None:
        if previous.file_sha256 == digest:
            # Same bytes: keep the parsed rows, refresh only file metadata.
            mtime, size = _file_stat_fields(epub_path)
            previous.file_path = str(epub_path)
            previous.file_mtime = mtime
            previous.file_size = size
            session.flush()
            logger.info("ebook_ingest_file_unchanged source_id=%s path=%s", previous.id, epub_path)
            return False
        # Content changed: remove the old source before inserting its
        # replacement (flushed first — presumably so the new row cannot clash
        # with the old one on path/hash constraints; confirm against schema).
        logger.info("ebook_ingest_file_replacing source_id=%s path=%s", previous.id, epub_path)
        session.delete(previous)
        session.flush()

    mtime, size = _file_stat_fields(epub_path)
    book = parse_epub(epub_path)
    source = EbookSource(
        title=book.title,
        author=book.author,
        language=book.language,
        publisher=book.publisher,
        identifier=book.identifier,
        file_path=str(epub_path),
        file_sha256=digest,
        file_mtime=mtime,
        file_size=size,
    )
    session.add(source)
    session.flush()  # source.id is read below for the child rows

    next_chunk_index = 0
    for spine_index, parsed_chapter in enumerate(book.chapters):
        chapter = EbookChapter(
            source_id=source.id,
            spine_index=spine_index,
            title=parsed_chapter.title,
            href=parsed_chapter.href,
        )
        session.add(chapter)
        session.flush()  # chapter.id is read by add_chapter_chunks
        # Chunk indices run continuously across all chapters of the source.
        next_chunk_index = add_chapter_chunks(session, source, chapter, parsed_chapter, next_chunk_index)

    session.flush()
    logger.info(
        "ebook_ingest_file_complete source_id=%s path=%s chapters=%s chunks=%s",
        source.id,
        epub_path,
        len(book.chapters),
        next_chunk_index,
    )
    return True


def _file_stat_fields(path: Path) -> tuple[datetime, int]:
    """Return ``path``'s modification time as an aware UTC datetime, and its size in bytes."""
    info = path.stat()
    return datetime.fromtimestamp(info.st_mtime, tz=UTC), info.st_size
|
|
|
|
|
|
def find_existing_source(session: Session, path: Path, file_hash: str) -> EbookSource | None:
    """Return a stored source whose path or content hash matches, if any.

    A row matches when its ``file_path`` equals ``str(path)`` or its
    ``file_sha256`` equals ``file_hash``. Only the first result row is
    returned, or None when nothing matches.

    NOTE(review): the query has no ORDER BY. If one row matches by path and a
    different row matches by hash, which one is returned is unspecified.
    """
    same_path = EbookSource.file_path == str(path)
    same_content = EbookSource.file_sha256 == file_hash
    query = select(EbookSource).where(or_(same_path, same_content))
    return session.scalar(query)
|
|
|
|
|
|
def add_chapter_chunks(
    session: Session,
    source: EbookSource,
    chapter: EbookChapter,
    parsed_chapter: ParsedChapter,
    chunk_index: int,
) -> int:
    """Stage one ``EbookChunk`` row per text chunk of a parsed chapter.

    Rows are added to ``session`` but not flushed here.

    Args:
        session: Session the new rows are added to.
        source: Owning source; its id, title and author are copied into rows.
        chapter: Chapter row; its id must already be assigned.
        parsed_chapter: Parsed chapter supplying text and page labels.
        chunk_index: Index given to this chapter's first chunk.

    Returns:
        The index the next chunk (e.g. of the following chapter) should use.
    """
    # Every chunk of the chapter shares the chapter's first page label.
    first_page = parsed_chapter.page_labels[0] if parsed_chapter.page_labels else None
    pieces = chunk_text(parsed_chapter.text)
    for position, piece in enumerate(pieces, start=chunk_index):
        content_digest = hashlib.sha256(piece.text.encode()).hexdigest()
        # search_text combines book and chapter metadata with the chunk body.
        searchable = f"{source.title} {source.author or ''} {chapter.title or ''} {piece.text}"
        session.add(
            EbookChunk(
                source_id=source.id,
                chapter_id=chapter.id,
                chunk_index=position,
                text=piece.text,
                token_start=piece.token_start,
                token_count=piece.token_count,
                page_label=first_page,
                content_sha256=content_digest,
                search_text=searchable,
            )
        )
    return chunk_index + len(pieces)
|
|
|
|
|
|
def sha256_file(path: Path) -> str:
    """Return the lowercase hex SHA-256 digest of the file at ``path``.

    The file is read in 1 MiB blocks, so it is never held in memory whole.
    """
    block_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while block := stream.read(block_size):
            hasher.update(block)
    return hasher.hexdigest()
|