Compare commits
8 Commits
19312962f8
...
cd38a0f277
| Author | SHA1 | Date | |
|---|---|---|---|
| cd38a0f277 | |||
| 61d86446ed | |||
| e14c20010f | |||
| bd87dd2015 | |||
| 94493647a6 | |||
| 70d65bbbe0 | |||
| 4dee9f76a7 | |||
| 0e874a3489 |
+1
-1
@@ -9,9 +9,9 @@ import typer
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from python.api.middleware import ZstdMiddleware
|
|
||||||
from python.api.routers import contact_router, views_router
|
from python.api.routers import contact_router, views_router
|
||||||
from python.common import configure_logger
|
from python.common import configure_logger
|
||||||
|
from python.fastapi_tools import ZstdMiddleware
|
||||||
from python.orm.common import get_postgres_engine
|
from python.orm.common import get_postgres_engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from pydantic import BaseModel
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from python.api.dependencies import DbSession
|
from python.fastapi_tools.db import DbSession
|
||||||
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||||
|
|
||||||
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from fastapi.templating import Jinja2Templates
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session, selectinload
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
from python.api.dependencies import DbSession
|
from python.fastapi_tools.db import DbSession
|
||||||
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
from python.orm.richie.contact import Contact, ContactRelationship, Need, RelationshipType
|
||||||
|
|
||||||
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
TEMPLATES_DIR = Path(__file__).parent.parent / "templates"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from python.ebook_search.bm25_corpus import refresh_bm25_corpus
|
from python.ebook_search.bm25_corpus import load_bm25_corpus, refresh_bm25_corpus
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@@ -56,3 +56,5 @@ def refresh_bm25_for_engine(engine: Engine, config: EbookSearchConfig) -> None:
|
|||||||
"""Refresh the BM25 corpus using a SQLAlchemy engine."""
|
"""Refresh the BM25 corpus using a SQLAlchemy engine."""
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
refresh_bm25_corpus(session, config)
|
refresh_bm25_corpus(session, config)
|
||||||
|
load_bm25_corpus.cache_clear()
|
||||||
|
logger.info("ebook_bm25_corpus_cache_cleared_after_refresh")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from python.common import configure_logger
|
from python.common import configure_logger
|
||||||
from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh
|
from python.ebook_search.api.bm25_tasks import cancel_bm25_refresh
|
||||||
from python.ebook_search.api.routes import register_admin_routes, register_page_routes, register_search_routes
|
from python.ebook_search.api.routes import admin_router, page_router, search_router
|
||||||
from python.ebook_search.api.web import STATIC_DIR
|
from python.ebook_search.api.web import STATIC_DIR
|
||||||
from python.ebook_search.bm25_corpus import ensure_bm25_corpus
|
from python.ebook_search.bm25_corpus import ensure_bm25_corpus
|
||||||
from python.ebook_search.config import load_config
|
from python.ebook_search.config import load_config
|
||||||
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
|||||||
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||||
"""Manage application startup and shutdown resources."""
|
"""Manage application startup and shutdown resources."""
|
||||||
logger.info("ebook_search_startup")
|
logger.info("ebook_search_startup")
|
||||||
app.state.engine = get_postgres_engine(name="RICHIE")
|
app.state.engine = get_postgres_engine(name="RICHIE", vector_engine=True)
|
||||||
with Session(app.state.engine) as session:
|
with Session(app.state.engine) as session:
|
||||||
ensure_bm25_corpus(session, app.state.config)
|
ensure_bm25_corpus(session, app.state.config)
|
||||||
try:
|
try:
|
||||||
@@ -55,9 +55,11 @@ def create_app() -> FastAPI:
|
|||||||
app.state.config.answer_enabled,
|
app.state.config.answer_enabled,
|
||||||
len(app.state.config.library_paths),
|
len(app.state.config.library_paths),
|
||||||
)
|
)
|
||||||
register_page_routes(app)
|
|
||||||
register_search_routes(app)
|
app.include_router(admin_router)
|
||||||
register_admin_routes(app)
|
app.include_router(page_router)
|
||||||
|
app.include_router(search_router)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,11 @@
|
|||||||
"""EPUB search web route modules."""
|
"""EPUB search web route modules."""
|
||||||
|
|
||||||
from python.ebook_search.api.routes import admin, page, search
|
from python.ebook_search.api.routes.admin import router as admin_router
|
||||||
|
from python.ebook_search.api.routes.page import router as page_router
|
||||||
register_admin_routes = admin.register_admin_routes
|
from python.ebook_search.api.routes.search import router as search_router
|
||||||
register_page_routes = page.register_page_routes
|
|
||||||
register_search_routes = search.register_search_routes
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"admin",
|
"admin_router",
|
||||||
"page",
|
"page_router",
|
||||||
"register_admin_routes",
|
"search_router",
|
||||||
"register_page_routes",
|
|
||||||
"register_search_routes",
|
|
||||||
"search",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
@@ -15,20 +14,12 @@ from python.ebook_search.api.web import templates
|
|||||||
from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats
|
from python.ebook_search.embeddings import embed_missing_chunks, embedding_model_stats
|
||||||
from python.ebook_search.ingest import ingest_configured_paths
|
from python.ebook_search.ingest import ingest_configured_paths
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin")
|
router = APIRouter(prefix="/admin")
|
||||||
EMBED_ALL_BATCH_SIZE = 32
|
EMBED_ALL_BATCH_SIZE = 32
|
||||||
|
|
||||||
|
|
||||||
def register_admin_routes(app: FastAPI) -> None:
|
|
||||||
"""Register admin routes on the app."""
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_class=HTMLResponse)
|
@router.get("", response_class=HTMLResponse)
|
||||||
def admin(request: Request) -> HTMLResponse:
|
def admin(request: Request) -> HTMLResponse:
|
||||||
"""Render the admin page."""
|
"""Render the admin page."""
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
@@ -13,19 +12,11 @@ from sqlalchemy.orm import Session
|
|||||||
from python.ebook_search.api.web import templates
|
from python.ebook_search.api.web import templates
|
||||||
from python.orm.richie import EbookSource
|
from python.orm.richie import EbookSource
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def register_page_routes(app: FastAPI) -> None:
|
|
||||||
"""Register page routes on the app."""
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_class=HTMLResponse)
|
@router.get("/", response_class=HTMLResponse)
|
||||||
def index(request: Request) -> HTMLResponse:
|
def index(request: Request) -> HTMLResponse:
|
||||||
"""Render the search page."""
|
"""Render the search page."""
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
from typing import TYPE_CHECKING, Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Form, Request
|
from fastapi import APIRouter, Form, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
@@ -15,19 +15,11 @@ from python.ebook_search.api.web import templates
|
|||||||
from python.ebook_search.search import search_ebooks
|
from python.ebook_search.search import search_ebooks
|
||||||
from python.ebook_search.timing import runtime_step_from_start
|
from python.ebook_search.timing import runtime_step_from_start
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def register_search_routes(app: FastAPI) -> None:
|
|
||||||
"""Register search routes on the app."""
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search", response_class=HTMLResponse)
|
@router.post("/search", response_class=HTMLResponse)
|
||||||
def search(
|
def search(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@@ -108,11 +108,10 @@ def refresh_bm25_corpus(
|
|||||||
)
|
)
|
||||||
write_bm25_corpus(index_path, records, manifest)
|
write_bm25_corpus(index_path, records, manifest)
|
||||||
logger.info(
|
logger.info(
|
||||||
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s note=%s",
|
"ebook_bm25_index_refreshed path=%s chunks=%s created_at=%s",
|
||||||
index_path,
|
index_path,
|
||||||
manifest.chunk_count,
|
manifest.chunk_count,
|
||||||
manifest.created_at.isoformat(),
|
manifest.created_at.isoformat(),
|
||||||
"restart_service_to_use_refreshed_bm25_cache",
|
|
||||||
)
|
)
|
||||||
return manifest
|
return manifest
|
||||||
|
|
||||||
@@ -121,15 +120,10 @@ def refresh_bm25_corpus(
|
|||||||
def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
|
def load_bm25_corpus(config: EbookSearchConfig) -> BM25Corpus:
|
||||||
"""Load the BM25 corpus into memory once per process.
|
"""Load the BM25 corpus into memory once per process.
|
||||||
|
|
||||||
This cache intentionally does not notice later on-disk corpus refreshes. Restart the service after rebuilding the
|
Background refresh tasks clear this cache after rebuilding the on-disk corpus.
|
||||||
BM25 corpus for searches to use the new index.
|
|
||||||
"""
|
"""
|
||||||
index_path = bm25_index_path(config)
|
index_path = bm25_index_path(config)
|
||||||
logger.info(
|
logger.info("ebook_bm25_corpus_cache_load path=%s", index_path)
|
||||||
"ebook_bm25_corpus_cache_load path=%s note=%s",
|
|
||||||
index_path,
|
|
||||||
"restart_service_after_bm25_refresh",
|
|
||||||
)
|
|
||||||
manifest = read_bm25_manifest(index_path)
|
manifest = read_bm25_manifest(index_path)
|
||||||
if manifest is None or not bm25_index_exists(index_path, manifest):
|
if manifest is None or not bm25_index_exists(index_path, manifest):
|
||||||
msg = f"BM25 corpus is not available: {index_path}"
|
msg = f"BM25 corpus is not available: {index_path}"
|
||||||
@@ -172,13 +166,7 @@ def fetch_bm25_corpus_records(session: Session) -> list[dict[str, object]]:
|
|||||||
EbookSource.author.label("source_author"),
|
EbookSource.author.label("source_author"),
|
||||||
EbookChapter.title.label("chapter_title"),
|
EbookChapter.title.label("chapter_title"),
|
||||||
EbookChunk.page_label.label("page_label"),
|
EbookChunk.page_label.label("page_label"),
|
||||||
func.concat_ws(
|
EbookChunk.search_text.label("bm25_text"),
|
||||||
" ",
|
|
||||||
EbookSource.title,
|
|
||||||
EbookSource.author,
|
|
||||||
EbookChapter.title,
|
|
||||||
EbookChunk.search_text,
|
|
||||||
).label("bm25_text"),
|
|
||||||
)
|
)
|
||||||
.select_from(EbookChunk)
|
.select_from(EbookChunk)
|
||||||
.join(EbookSource, EbookSource.id == EbookChunk.source_id)
|
.join(EbookSource, EbookSource.id == EbookChunk.source_id)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ if TYPE_CHECKING:
|
|||||||
from python.ebook_search.search import SearchResult
|
from python.ebook_search.search import SearchResult
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
RERANK_SCORE_WEIGHT = 0.7
|
||||||
|
HYBRID_SCORE_WEIGHT = 0.3
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -110,7 +112,7 @@ def clamp_score(score: float) -> float:
|
|||||||
|
|
||||||
def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float:
|
def final_rerank_score(result: SearchResult, rerank_score: float, candidates: list[SearchResult]) -> float:
|
||||||
"""Combine rerank relevance with normalized hybrid retrieval evidence."""
|
"""Combine rerank relevance with normalized hybrid retrieval evidence."""
|
||||||
return rerank_score * normalized_hybrid_score(result, candidates)
|
return (RERANK_SCORE_WEIGHT * rerank_score) + (HYBRID_SCORE_WEIGHT * normalized_hybrid_score(result, candidates))
|
||||||
|
|
||||||
|
|
||||||
def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float:
|
def normalized_hybrid_score(result: SearchResult, candidates: list[SearchResult]) -> float:
|
||||||
|
|||||||
@@ -93,13 +93,14 @@ def search_ebooks(
|
|||||||
|
|
||||||
logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
|
logger.info("ebook_search_start query_length=%s rerank=%s", len(query), rerank)
|
||||||
timings: list[RuntimeStep] = []
|
timings: list[RuntimeStep] = []
|
||||||
retrieval_query, timing = timed_result("Query preparation", retrieval_query_from_text, query)
|
bm25_query, timing = timed_result("BM25 query preparation", retrieval_query_from_text, query)
|
||||||
timings.append(timing)
|
timings.append(timing)
|
||||||
retrieval, timing = timed_result(
|
retrieval, timing = timed_result(
|
||||||
"Hybrid retrieval",
|
"Hybrid retrieval",
|
||||||
parallel_retrieval,
|
parallel_retrieval,
|
||||||
engine,
|
engine,
|
||||||
retrieval_query,
|
query,
|
||||||
|
bm25_query,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
timings.extend(retrieval.timings)
|
timings.extend(retrieval.timings)
|
||||||
@@ -130,7 +131,12 @@ def search_ebooks(
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def parallel_retrieval(engine: Engine, query: str, config: EbookSearchConfig) -> RetrievalResponse:
|
def parallel_retrieval(
|
||||||
|
engine: Engine,
|
||||||
|
vector_query: str,
|
||||||
|
bm25_query: str,
|
||||||
|
config: EbookSearchConfig,
|
||||||
|
) -> RetrievalResponse:
|
||||||
"""Run vector and BM25 candidate retrieval concurrently with separate database sessions."""
|
"""Run vector and BM25 candidate retrieval concurrently with separate database sessions."""
|
||||||
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor:
|
with ThreadPoolExecutor(max_workers=2, thread_name_prefix="ebook-search") as executor:
|
||||||
vector_future = executor.submit(
|
vector_future = executor.submit(
|
||||||
@@ -138,14 +144,14 @@ def parallel_retrieval(engine: Engine, query: str, config: EbookSearchConfig) ->
|
|||||||
"Embedding + vector search",
|
"Embedding + vector search",
|
||||||
vector_candidates,
|
vector_candidates,
|
||||||
engine,
|
engine,
|
||||||
query,
|
vector_query,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
bm25_future = executor.submit(
|
bm25_future = executor.submit(
|
||||||
timed_result,
|
timed_result,
|
||||||
"BM25 search",
|
"BM25 search",
|
||||||
bm25_candidates,
|
bm25_candidates,
|
||||||
query,
|
bm25_query,
|
||||||
config,
|
config,
|
||||||
)
|
)
|
||||||
vector_results, vector_timing = vector_future.result()
|
vector_results, vector_timing = vector_future.result()
|
||||||
@@ -196,7 +202,7 @@ def apply_rerank(
|
|||||||
|
|
||||||
|
|
||||||
def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
|
def vector_candidates(engine: Engine, query: str, config: EbookSearchConfig) -> list[SearchResult]:
|
||||||
"""Return pgvector cosine candidates for a normalized query."""
|
"""Return pgvector cosine candidates for a natural-language query."""
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
|
model = session.scalar(select(EbookEmbeddingModel).where(EbookEmbeddingModel.name == config.embedding_model))
|
||||||
if model is None:
|
if model is None:
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Reusable FastAPI tools."""
|
||||||
|
|
||||||
|
from python.fastapi_tools.db import DbSession, get_db
|
||||||
|
from python.fastapi_tools.zstd_middleware import ZstdMiddleware
|
||||||
|
|
||||||
|
__all__ = ["DbSession", "ZstdMiddleware", "get_db"]
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Middleware for the FastAPI application."""
|
"""Zstd response compression middleware."""
|
||||||
|
|
||||||
from compression import zstd
|
from compression import zstd
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
+24
-2
@@ -31,8 +31,24 @@ def get_connection_info(name: str) -> tuple[str, str, str, str, str | None]:
|
|||||||
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
|
return cast("tuple[str, str, str, str, str | None]", (database, host, port, username, password))
|
||||||
|
|
||||||
|
|
||||||
def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -> Engine:
|
def get_postgres_engine(
|
||||||
"""Create a SQLAlchemy engine from environment variables."""
|
*,
|
||||||
|
name: str = "POSTGRES",
|
||||||
|
pool_pre_ping: bool = True,
|
||||||
|
vector_engine: bool = False,
|
||||||
|
) -> Engine:
|
||||||
|
"""Create a SQLAlchemy engine from environment variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str, optional): The name of the environment variable prefix. Defaults to "POSTGRES".
|
||||||
|
pool_pre_ping (bool, optional): Whether to ping the database before each connection. Defaults to True.
|
||||||
|
This fixes the issue of trying to use a conection that has timed out on the database side.
|
||||||
|
vector_engine (bool, optional): Whether to use the vector search schema. Defaults to False.
|
||||||
|
This updates the search path the incldued the vecore types and operators.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Engine: The SQLAlchemy engine.
|
||||||
|
"""
|
||||||
database, host, port, username, password = get_connection_info(name)
|
database, host, port, username, password = get_connection_info(name)
|
||||||
|
|
||||||
url = URL.create(
|
url = URL.create(
|
||||||
@@ -44,8 +60,14 @@ def get_postgres_engine(*, name: str = "POSTGRES", pool_pre_ping: bool = True) -
|
|||||||
database=database,
|
database=database,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
connect_args = {}
|
||||||
|
# There more better way to do this is with separate PG account and a dedicated vector schema for the vector types
|
||||||
|
if vector_engine:
|
||||||
|
connect_args["options"] = "-csearch_path=main,public"
|
||||||
|
|
||||||
return create_engine(
|
return create_engine(
|
||||||
url=url,
|
url=url,
|
||||||
pool_pre_ping=pool_pre_ping,
|
pool_pre_ping=pool_pre_ping,
|
||||||
pool_recycle=1800,
|
pool_recycle=1800,
|
||||||
|
connect_args=connect_args,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from python.ebook_search.bm25_corpus import (
|
|||||||
BM25CorpusUnavailableError,
|
BM25CorpusUnavailableError,
|
||||||
BM25Manifest,
|
BM25Manifest,
|
||||||
ensure_bm25_corpus,
|
ensure_bm25_corpus,
|
||||||
|
fetch_bm25_corpus_records,
|
||||||
load_bm25_corpus,
|
load_bm25_corpus,
|
||||||
)
|
)
|
||||||
from python.ebook_search.config import EbookSearchConfig, RerankConfig, load_config, normalize_embedding_model
|
from python.ebook_search.config import EbookSearchConfig, RerankConfig, load_config, normalize_embedding_model
|
||||||
@@ -33,7 +34,7 @@ from python.ebook_search.search import (
|
|||||||
search_ebooks,
|
search_ebooks,
|
||||||
)
|
)
|
||||||
from python.ebook_search.timing import RuntimeStep
|
from python.ebook_search.timing import RuntimeStep
|
||||||
from python.orm.richie import EbookEmbeddingModel, EbookSource, RichieBase
|
from python.orm.richie import EbookChapter, EbookChunk, EbookEmbeddingModel, EbookSource, RichieBase
|
||||||
|
|
||||||
|
|
||||||
def test_chunk_text_uses_overlap() -> None:
|
def test_chunk_text_uses_overlap() -> None:
|
||||||
@@ -86,6 +87,47 @@ def test_find_existing_source_matches_path_or_hash() -> None:
|
|||||||
assert find_existing_source(session, Path("/new/book.epub"), "a" * 64) == source
|
assert find_existing_source(session, Path("/new/book.epub"), "a" * 64) == source
|
||||||
|
|
||||||
|
|
||||||
|
def test_bm25_corpus_uses_existing_search_text_without_duplicate_metadata() -> None:
|
||||||
|
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
|
||||||
|
RichieBase.metadata.create_all(engine)
|
||||||
|
with sessionmaker(bind=engine, expire_on_commit=False, future=True)() as session:
|
||||||
|
source = EbookSource(
|
||||||
|
title="Book",
|
||||||
|
author="Author",
|
||||||
|
language=None,
|
||||||
|
publisher=None,
|
||||||
|
identifier=None,
|
||||||
|
file_path="/book.epub",
|
||||||
|
file_sha256="a" * 64,
|
||||||
|
file_mtime=datetime.now(tz=UTC),
|
||||||
|
file_size=10,
|
||||||
|
)
|
||||||
|
session.add(source)
|
||||||
|
session.flush()
|
||||||
|
chapter = EbookChapter(source_id=source.id, spine_index=0, title="Chapter", href=None)
|
||||||
|
session.add(chapter)
|
||||||
|
session.flush()
|
||||||
|
session.add(
|
||||||
|
EbookChunk(
|
||||||
|
id=1,
|
||||||
|
source_id=source.id,
|
||||||
|
chapter_id=chapter.id,
|
||||||
|
chunk_index=0,
|
||||||
|
text="content",
|
||||||
|
token_start=0,
|
||||||
|
token_count=1,
|
||||||
|
page_label=None,
|
||||||
|
content_sha256="b" * 64,
|
||||||
|
search_text="Book Author Chapter content",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
records = fetch_bm25_corpus_records(session)
|
||||||
|
|
||||||
|
assert records[0]["bm25_text"] == "Book Author Chapter content"
|
||||||
|
|
||||||
|
|
||||||
def test_reciprocal_rank_fusion_marks_hybrid_source() -> None:
|
def test_reciprocal_rank_fusion_marks_hybrid_source() -> None:
|
||||||
vector_results = [SearchResult(chunk_id=1, text="a", source_title="A")]
|
vector_results = [SearchResult(chunk_id=1, text="a", source_title="A")]
|
||||||
lexical_results = [SearchResult(chunk_id=2, text="b", source_title="B")]
|
lexical_results = [SearchResult(chunk_id=2, text="b", source_title="B")]
|
||||||
@@ -119,7 +161,7 @@ def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
|
|||||||
def fake_vector_candidates(received_engine, query, _config):
|
def fake_vector_candidates(received_engine, query, _config):
|
||||||
"""Return vector candidates after confirming BM25 has started."""
|
"""Return vector candidates after confirming BM25 has started."""
|
||||||
received_engines.append(received_engine)
|
received_engines.append(received_engine)
|
||||||
assert query == "parallel"
|
assert query == "what is parallel"
|
||||||
vector_started.set()
|
vector_started.set()
|
||||||
assert bm25_started.wait(timeout=2)
|
assert bm25_started.wait(timeout=2)
|
||||||
return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)]
|
return [SearchResult(chunk_id=1, text="vector", source_title="Vector", vector_score=0.9)]
|
||||||
@@ -135,13 +177,14 @@ def test_search_ebooks_runs_vector_and_bm25_in_parallel(monkeypatch) -> None:
|
|||||||
monkeypatch.setattr("python.ebook_search.search.bm25_candidates", fake_bm25_candidates)
|
monkeypatch.setattr("python.ebook_search.search.bm25_candidates", fake_bm25_candidates)
|
||||||
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
|
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
|
||||||
|
|
||||||
response = search_ebooks(engine, "parallel", config)
|
response = search_ebooks(engine, "what is parallel", config)
|
||||||
|
|
||||||
timings = {step.name: step for step in response.timings}
|
timings = {step.name: step for step in response.timings}
|
||||||
assert [result.chunk_id for result in response.results] == [1, 2]
|
assert [result.chunk_id for result in response.results] == [1, 2]
|
||||||
assert timings["Embedding + vector search"].counts_toward_total is False
|
assert timings["Embedding + vector search"].counts_toward_total is False
|
||||||
assert timings["BM25 search"].counts_toward_total is False
|
assert timings["BM25 search"].counts_toward_total is False
|
||||||
assert timings["Hybrid retrieval"].counts_toward_total is True
|
assert timings["Hybrid retrieval"].counts_toward_total is True
|
||||||
|
assert timings["BM25 query preparation"].counts_toward_total is True
|
||||||
assert received_engines == [engine]
|
assert received_engines == [engine]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ def test_reranking_enabled_reorders_candidates(monkeypatch: pytest.MonkeyPatch)
|
|||||||
results = rerank_chunks("query", candidates(), RerankConfig())
|
results = rerank_chunks("query", candidates(), RerankConfig())
|
||||||
|
|
||||||
assert [result.chunk_id for result in results] == [2, 1, 3]
|
assert [result.chunk_id for result in results] == [2, 1, 3]
|
||||||
assert [round(result.score, 3) for result in results] == [0.45, 0.1, 0.0]
|
assert [round(result.score, 3) for result in results] == [0.78, 0.37, 0.28]
|
||||||
assert [result.rerank_score for result in results] == [0.9, 0.1, 0.4]
|
assert [result.rerank_score for result in results] == [0.9, 0.1, 0.4]
|
||||||
|
|
||||||
|
|
||||||
@@ -100,8 +100,8 @@ def test_reranking_cannot_ignore_hybrid_score(monkeypatch: pytest.MonkeyPatch) -
|
|||||||
results = rerank_chunks("query", candidates, RerankConfig())
|
results = rerank_chunks("query", candidates, RerankConfig())
|
||||||
|
|
||||||
assert [result.chunk_id for result in results] == [1, 2]
|
assert [result.chunk_id for result in results] == [1, 2]
|
||||||
assert results[0].score == 0.7
|
assert results[0].score == pytest.approx(0.79)
|
||||||
assert results[1].score == 0.0
|
assert results[1].score == 0.7
|
||||||
assert results[1].rerank_score == 1.0
|
assert results[1].rerank_score == 1.0
|
||||||
|
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ def test_malformed_vllm_rerank_json_does_not_crash_search(monkeypatch: pytest.Mo
|
|||||||
|
|
||||||
results = rerank_chunks("query", candidates()[:1], RerankConfig())
|
results = rerank_chunks("query", candidates()[:1], RerankConfig())
|
||||||
|
|
||||||
assert results[0].score == 0.0
|
assert results[0].score == 0.3
|
||||||
|
|
||||||
|
|
||||||
def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
@@ -147,4 +147,4 @@ def test_vllm_rerank_scores_are_clamped(monkeypatch: pytest.MonkeyPatch) -> None
|
|||||||
|
|
||||||
results = rerank_chunks("query", candidates()[:2], RerankConfig())
|
results = rerank_chunks("query", candidates()[:2], RerankConfig())
|
||||||
|
|
||||||
assert [result.rerank_score for result in results] == [0.0, 1.0]
|
assert {result.chunk_id: result.rerank_score for result in results} == {1: 0.0, 2: 1.0}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
from python.ebook_search.api.bm25_tasks import refresh_bm25_for_engine
|
||||||
from python.ebook_search.api.main import create_app
|
from python.ebook_search.api.main import create_app
|
||||||
from python.ebook_search.config import EbookSearchConfig, RerankConfig
|
from python.ebook_search.config import EbookSearchConfig, RerankConfig
|
||||||
from python.ebook_search.embeddings import EmbeddingModelStats
|
from python.ebook_search.embeddings import EmbeddingModelStats
|
||||||
@@ -232,6 +233,29 @@ def test_ui_scan_schedules_bm25_refresh_after_database_change(monkeypatch) -> No
|
|||||||
assert scheduled is True
|
assert scheduled is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_bm25_refresh_clears_loaded_corpus_cache(monkeypatch) -> None:
|
||||||
|
refreshed: list[object] = []
|
||||||
|
cache_cleared = False
|
||||||
|
|
||||||
|
def fake_refresh_bm25_corpus(session, config):
|
||||||
|
refreshed.append((session, config))
|
||||||
|
|
||||||
|
def fake_cache_clear():
|
||||||
|
nonlocal cache_cleared
|
||||||
|
cache_cleared = True
|
||||||
|
|
||||||
|
monkeypatch.setattr("python.ebook_search.api.bm25_tasks.refresh_bm25_corpus", fake_refresh_bm25_corpus)
|
||||||
|
monkeypatch.setattr("python.ebook_search.api.bm25_tasks.load_bm25_corpus.cache_clear", fake_cache_clear)
|
||||||
|
engine = create_engine("sqlite+pysqlite:///:memory:", future=True)
|
||||||
|
config = EbookSearchConfig(rerank=RerankConfig(enabled=False))
|
||||||
|
|
||||||
|
refresh_bm25_for_engine(engine, config)
|
||||||
|
|
||||||
|
assert len(refreshed) == 1
|
||||||
|
assert refreshed[0][1] == config
|
||||||
|
assert cache_cleared is True
|
||||||
|
|
||||||
|
|
||||||
def test_admin_page_shows_embedding_counts_by_model(monkeypatch) -> None:
|
def test_admin_page_shows_embedding_counts_by_model(monkeypatch) -> None:
|
||||||
def fake_embedding_model_stats(_session):
|
def fake_embedding_model_stats(_session):
|
||||||
return [
|
return [
|
||||||
|
|||||||
Reference in New Issue
Block a user