171 lines
5.9 KiB
Python
171 lines
5.9 KiB
Python
"""Embedding model helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.dialects.postgresql import insert
|
|
|
|
from python.ebook_search.llm_interface import request_embeddings
|
|
from python.orm.richie import (
|
|
EbookChunk,
|
|
EbookChunkEmbedding1024,
|
|
EbookChunkEmbedding2560,
|
|
EbookChunkEmbedding4096,
|
|
EbookEmbeddingModel,
|
|
)
|
|
|
|
# Module logger; messages below use "event_name key=value" style with lazy %-args.
logger = logging.getLogger(__name__)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Sequence
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from python.ebook_search.config import EbookSearchConfig
|
|
|
|
# Supported embedding model names mapped to the length of the vectors they produce.
# embed_texts checks returned vectors against these lengths, ensure_embedding_models
# seeds one EbookEmbeddingModel row per entry, and embedding_model_stats calls
# get_embedding_table for each dimension, so every dimension listed here needs a
# matching embedding table there.
MODEL_DIMENSIONS: dict[str, int] = {
    "qwen3-embedding-0.6b": 1024,
    "qwen3-embedding-4b": 2560,
    "qwen3-embedding-8b": 4096,
}
|
|
|
|
|
|
def get_embedding_table(
    dimension: int,
) -> type[EbookChunkEmbedding1024 | EbookChunkEmbedding2560 | EbookChunkEmbedding4096]:
    """Look up the ORM table that stores embedding vectors of length ``dimension``.

    Args:
        dimension: Vector length; 1024, 2560 and 4096 have dedicated tables.

    Returns:
        The mapped embedding table class for that vector length.

    Raises:
        ValueError: If no embedding table exists for ``dimension``.
    """
    tables_by_dimension = {
        1024: EbookChunkEmbedding1024,
        2560: EbookChunkEmbedding2560,
        4096: EbookChunkEmbedding4096,
    }
    try:
        return tables_by_dimension[dimension]
    except KeyError:
        msg = f"Embedding dimension {dimension} is not supported"
        raise ValueError(msg) from None
|
|
|
|
|
|
@dataclass(frozen=True)
class EmbeddingModelStats:
    """Immutable snapshot of how many chunks one embedding model has covered.

    Field declaration order is also the positional constructor order.
    """

    # Name of the embedding model these counts belong to.
    model_name: str
    # Length of the vectors the model produces.
    dimension: int
    # Number of chunks that already have a vector from this model.
    embedded_chunks: int
    # Total number of ebook chunks counted.
    total_chunks: int

    @property
    def missing_chunks(self) -> int:
        """Chunks still lacking a vector from this model, clamped at zero."""
        remaining = self.total_chunks - self.embedded_chunks
        return remaining if remaining > 0 else 0
|
|
|
|
|
|
def embed_texts(texts: Sequence[str], config: EbookSearchConfig) -> list[list[float]]:
    """Embed ``texts`` with the configured vLLM embedding model.

    The configured model is validated against ``MODEL_DIMENSIONS`` *before* any
    request is sent. The response is then checked for both the number of
    vectors and the length of each vector.

    Args:
        texts: Strings to embed; exactly one vector is expected per string.
            Result order follows whatever ``request_embeddings`` returns
            (presumably input order -- confirm against that helper).
        config: Search config supplying ``embedding_model`` and
            ``embedding_base_url``.

    Returns:
        One embedding vector per input text.

    Raises:
        ValueError: If ``config.embedding_model`` is not a supported model, or
            if the response contains the wrong number of vectors or a vector of
            the wrong dimension.
    """
    # Validate up front so an unsupported model fails fast with a clear error
    # instead of a KeyError after a wasted network round trip.
    expected_dimension = MODEL_DIMENSIONS.get(config.embedding_model)
    if expected_dimension is None:
        supported_models = ", ".join(MODEL_DIMENSIONS)
        msg = f"Unknown embedding model: {config.embedding_model}. Supported models: {supported_models}"
        raise ValueError(msg)

    logger.info(
        "ebook_embed_request_start base_url=%s model=%s count=%s",
        config.embedding_base_url,
        config.embedding_model,
        len(texts),
    )
    vectors = request_embeddings(texts, config)

    # A short or long response would otherwise silently misalign vectors with
    # their texts (or surface later as an IndexError in callers).
    if len(vectors) != len(texts):
        msg = f"Expected {len(texts)} embeddings, got {len(vectors)}"
        raise ValueError(msg)
    for vector in vectors:
        if len(vector) != expected_dimension:
            msg = f"Expected {expected_dimension} dimensions, got {len(vector)}"
            raise ValueError(msg)

    logger.info(
        "ebook_embed_request_complete model=%s count=%s dimension=%s",
        config.embedding_model,
        len(vectors),
        expected_dimension,
    )
    return vectors
|
|
|
|
|
|
def embed_query(query: str, config: EbookSearchConfig) -> list[float]:
    """Return the embedding vector for a single search ``query``.

    The query is wrapped in an "Instruct: <task> / Query: <query>" prompt
    (two lines) before embedding, whereas ``embed_missing_chunks`` embeds
    chunk text as-is.

    Args:
        query: Raw user search text.
        config: Search config forwarded to ``embed_texts``.

    Returns:
        The first vector returned for the instructed prompt.
    """
    task = "Retrieve relevant passages for the query."
    prompt = f"Instruct: {task}\nQuery: {query}"
    vectors = embed_texts([prompt], config)
    return vectors[0]
|
|
|
|
|
|
def ensure_embedding_models(session: Session) -> None:
    """Add an ``EbookEmbeddingModel`` row for every supported model that lacks one.

    Rows that already exist are left untouched; their dimension and default
    flag are not updated. New rows are flushed but not committed by this
    function.

    Args:
        session: SQLAlchemy session to query and add rows to.
    """
    default_model = "qwen3-embedding-0.6b"
    # One round trip for all supported names instead of one SELECT per model.
    existing_names = set(
        session.scalars(
            select(EbookEmbeddingModel.name).where(EbookEmbeddingModel.name.in_(MODEL_DIMENSIONS)),
        ),
    )
    for name, dimension in MODEL_DIMENSIONS.items():
        if name in existing_names:
            continue
        session.add(EbookEmbeddingModel(name=name, dimension=dimension, is_default=name == default_model))
        logger.info("ebook_embedding_model_created model=%s dimension=%s", name, dimension)
    session.flush()
|
|
|
|
|
|
def embedding_model_stats(session: Session) -> list[EmbeddingModelStats]:
    """Return embedding coverage for every supported model, in ``MODEL_DIMENSIONS`` order.

    Models that have no ``EbookEmbeddingModel`` row are reported with zero
    embedded chunks and the dimension listed in ``MODEL_DIMENSIONS``. Models
    with a row are counted in the embedding table selected by the row's
    stored dimension. That is the same table ``embed_missing_chunks`` inserts
    into, so the counts reflect what the backfill actually wrote.

    Args:
        session: SQLAlchemy session used for the count queries.

    Returns:
        One ``EmbeddingModelStats`` per entry in ``MODEL_DIMENSIONS``.

    Raises:
        ValueError: If a stored model row has a dimension with no embedding table.
    """
    total_chunks = session.scalar(select(func.count(EbookChunk.id))) or 0
    # Rows are only looked up by name below, so no ORDER BY is needed.
    models = {
        model.name: model
        for model in session.scalars(
            select(EbookEmbeddingModel).where(EbookEmbeddingModel.name.in_(MODEL_DIMENSIONS)),
        )
    }

    stats: list[EmbeddingModelStats] = []
    for model_name, configured_dimension in MODEL_DIMENSIONS.items():
        model = models.get(model_name)
        if model is None:
            dimension = configured_dimension
            embedded_chunks = 0
        else:
            # Match embed_missing_chunks, which picks the table from the row's dimension.
            dimension = model.dimension
            table = get_embedding_table(dimension)
            embedded_chunks = session.scalar(select(func.count(table.id)).where(table.model_id == model.id)) or 0
        stats.append(
            EmbeddingModelStats(
                model_name=model_name,
                dimension=dimension,
                embedded_chunks=embedded_chunks,
                total_chunks=total_chunks,
            ),
        )
    return stats
|
|
|
|
|
|
def embed_missing_chunks(session: Session, config: EbookSearchConfig) -> int:
    """Embed one batch of chunks that have no vector yet for the configured model.

    The function does the following, in order:

    1. Ensures the supported model rows exist.
    2. Resolves ``config.embedding_model`` to its row.
    3. Selects up to ``config.embedding_batch_size`` chunks (lowest id first)
       that have no row for that model in its embedding table.
    4. Embeds their text and inserts the vectors.

    Rows that would conflict on ``(chunk_id, model_id)`` are skipped by the
    insert. The returned count is the size of the batch that was embedded,
    not the number of rows actually inserted.

    Args:
        session: SQLAlchemy session; changes are flushed, not committed.
        config: Search config supplying the model name and batch size.

    Returns:
        Number of chunks in the embedded batch, or 0 when nothing is missing.

    Raises:
        ValueError: If no ``EbookEmbeddingModel`` row matches ``config.embedding_model``.
    """
    ensure_embedding_models(session)

    model_name = config.embedding_model
    model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == model_name))
    if model is None:
        supported = ", ".join(MODEL_DIMENSIONS)
        msg = f"Unknown embedding model: {model_name}. Supported models: {supported}"
        raise ValueError(msg)

    table = get_embedding_table(model.dimension)
    # Anti-join: outer-join existing embeddings for this model, keep unmatched chunks.
    has_embedding = (table.chunk_id == EbookChunk.id) & (table.model_id == model.id)
    missing_query = (
        select(EbookChunk)
        .outerjoin(table, has_embedding)
        .where(table.id.is_(None))
        .order_by(EbookChunk.id)
        .limit(config.embedding_batch_size)
    )
    batch = list(session.scalars(missing_query))

    if not batch:
        logger.info("ebook_embed_missing_none model=%s", model_name)
        return 0

    logger.info("ebook_embed_missing_batch_start model=%s count=%s", model_name, len(batch))
    texts = [chunk.text for chunk in batch]
    vectors = embed_texts(texts, config)
    rows = [
        {"chunk_id": chunk.id, "model_id": model.id, "embedding": vector}
        for chunk, vector in zip(batch, vectors, strict=True)
    ]
    upsert = insert(table).values(rows).on_conflict_do_nothing(index_elements=["chunk_id", "model_id"])
    session.execute(upsert)
    session.flush()
    logger.info("ebook_embed_missing_batch_complete model=%s count=%s", model_name, len(rows))
    return len(rows)
|