"""Embedding model helpers.""" from __future__ import annotations import logging from dataclasses import dataclass from typing import TYPE_CHECKING from sqlalchemy import func, select from sqlalchemy.dialects.postgresql import insert from python.ebook_search.llm_interface import request_embeddings from python.orm.richie import ( EbookChunk, EbookChunkEmbedding1024, EbookChunkEmbedding2560, EbookChunkEmbedding4096, EbookEmbeddingModel, ) logger = logging.getLogger(__name__) if TYPE_CHECKING: from collections.abc import Sequence from sqlalchemy.orm import Session from python.ebook_search.config import EbookSearchConfig MODEL_DIMENSIONS = { "qwen3-embedding-0.6b": 1024, "qwen3-embedding-4b": 2560, "qwen3-embedding-8b": 4096, } def get_embedding_table( dimension: int, ) -> type[EbookChunkEmbedding1024 | EbookChunkEmbedding2560 | EbookChunkEmbedding4096]: """Return the embedding table mapped to an embedding dimension.""" embedding_tables = { 1024: EbookChunkEmbedding1024, 2560: EbookChunkEmbedding2560, 4096: EbookChunkEmbedding4096, } table = embedding_tables.get(dimension) if not table: msg = f"Embedding dimension {dimension} is not supported" raise ValueError(msg) return table @dataclass(frozen=True) class EmbeddingModelStats: """Embedding coverage for one model.""" model_name: str dimension: int embedded_chunks: int total_chunks: int @property def missing_chunks(self) -> int: """Return chunks missing this embedding model.""" return max(self.total_chunks - self.embedded_chunks, 0) def embed_texts(texts: Sequence[str], config: EbookSearchConfig) -> list[list[float]]: """Embed text with the configured vLLM embedding model.""" logger.info( "ebook_embed_request_start base_url=%s model=%s count=%s", config.embedding_base_url, config.embedding_model, len(texts), ) vectors = request_embeddings(texts, config) expected_dimension = MODEL_DIMENSIONS[config.embedding_model] for vector in vectors: if len(vector) != expected_dimension: msg = f"Expected {expected_dimension} dimensions, got {len(vector)}" raise ValueError(msg) logger.info( "ebook_embed_request_complete model=%s count=%s dimension=%s", config.embedding_model, len(vectors), expected_dimension, ) return vectors def embed_query(query: str, config: EbookSearchConfig) -> list[float]: """Embed a search query with the Qwen retrieval instruction.""" instructed_query = f"Instruct: Retrieve relevant passages for the query.\nQuery: {query}" return embed_texts([instructed_query], config)[0] def ensure_embedding_models(session: Session) -> None: """Ensure supported embedding model rows exist.""" for name, dimension in MODEL_DIMENSIONS.items(): existing = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == name)) if existing is None: session.add(EbookEmbeddingModel(name=name, dimension=dimension, is_default=name == "qwen3-embedding-0.6b")) logger.info("ebook_embedding_model_created model=%s dimension=%s", name, dimension) session.flush() def embedding_model_stats(session: Session) -> list[EmbeddingModelStats]: """Return embedding coverage counts for every supported model.""" total_chunks = session.scalar(select(func.count(EbookChunk.id))) or 0 models = { model.name: model for model in session.scalars( select(EbookEmbeddingModel) .where(EbookEmbeddingModel.name.in_(MODEL_DIMENSIONS)) .order_by(EbookEmbeddingModel.name) ) } stats: list[EmbeddingModelStats] = [] for model_name, dimension in MODEL_DIMENSIONS.items(): model = models.get(model_name) embedded_chunks = 0 if model is not None: table = get_embedding_table(dimension) embedded_chunks = session.scalar(select(func.count(table.id)).where(table.model_id == model.id)) or 0 stats.append( EmbeddingModelStats( model_name=model_name, dimension=dimension, embedded_chunks=embedded_chunks, total_chunks=total_chunks, ) ) return stats def embed_missing_chunks(session: Session, config: EbookSearchConfig) -> int: """Embed chunks missing embeddings for the configured model.""" ensure_embedding_models(session) model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model)) if model is None: supported_models = ", ".join(MODEL_DIMENSIONS) msg = f"Unknown embedding model: {config.embedding_model}. Supported models: {supported_models}" raise ValueError(msg) table = get_embedding_table(model.dimension) chunks = list( session.scalars( select(EbookChunk) .outerjoin(table, (table.chunk_id == EbookChunk.id) & (table.model_id == model.id)) .where(table.id.is_(None)) .order_by(EbookChunk.id) .limit(config.embedding_batch_size) ) ) if not chunks: logger.info("ebook_embed_missing_none model=%s", config.embedding_model) return 0 logger.info("ebook_embed_missing_batch_start model=%s count=%s", config.embedding_model, len(chunks)) vectors = embed_texts([chunk.text for chunk in chunks], config) rows = [ {"chunk_id": chunk.id, "model_id": model.id, "embedding": vector} for chunk, vector in zip(chunks, vectors, strict=True) ] statement = insert(table).values(rows).on_conflict_do_nothing(index_elements=["chunk_id", "model_id"]) session.execute(statement) session.flush() logger.info("ebook_embed_missing_batch_complete model=%s count=%s", config.embedding_model, len(rows)) return len(rows)