added bert_topic train.py and infer.py

This commit is contained in:
2026-04-28 22:20:18 -04:00
parent 3056c19f69
commit 2038a90b3c
3 changed files with 292 additions and 0 deletions
+116
View File
@@ -0,0 +1,116 @@
"""Nornsight — BERTopic POC Inference Script.
Loads the trained model and labels a small batch of posts,
writing results to main.post_topic for inspection.
POC: processes a single batch of 1k posts to validate the pipeline end-to-end.
"""
from __future__ import annotations
import logging
import time
from collections import Counter
from pathlib import Path
from bertopic import BERTopic
from sqlalchemy import Engine, func, insert, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicInferConfig, get_bertopic_infer_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import PostTopic, Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
"""Run BERTopic inference against a sample of posts."""
configure_logger()
config = get_bertopic_infer_config()
run_inference(config)
logger.info(
"POC inference complete. Check main.post_topic in DBeaver to inspect results."
)
def run_inference(config: BertTopicInferConfig) -> None:
model_save_path = Path(config.model_save_path)
logger.info(f"Loading BERTopic model from {model_save_path}")
topic_model = BERTopic.load(str(model_save_path))
topic_info = topic_model.get_topic_info()
label_map: dict[int, str] = dict(zip(topic_info["Topic"], topic_info["Name"]))
logger.info(f"Model loaded with {len(label_map)} topics")
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
post_ids, texts = get_post_ids_and_test(engine, config)
logger.info(f"Fetched {len(texts)} posts")
logger.info("Running BERTopic transform")
start = time.perf_counter()
topics, _probabilities = topic_model.transform(texts)
elapsed = time.perf_counter() - start
logger.info(f"Transform complete in {elapsed:.1f}s")
# Write results to main.post_topic
records = [
{
"post_id": pid,
"topic_id": int(topic_id),
"topic_label": label_map.get(int(topic_id), "unknown"),
"model_version": config.model_version,
}
for pid, topic_id in zip(post_ids, topics)
]
with Session(engine) as session:
session.execute(insert(PostTopic), records)
session.commit()
count_topics(records)
logger.info(f"Wrote {len(records)} topic labels to main.post_topic")
def get_post_ids_and_test(
engine: Engine,
config: BertTopicInferConfig,
) -> None | tuple[list[int], list[str]]:
with Session(engine) as session:
logger.info(f"Fetching {config.poc_batch_size} posts for inference")
# Pull a fresh batch for inference — distinct from training sample
# using a fixed seed offset so we're not re-labeling training posts
stmt = select(Posts).where(
Posts.text.is_not(None),
Posts.langs.in_(ENGLISH_LANGS),
func.length(Posts.text) > config.min_text_length,
)
if config.poc_batch_size > 0:
stmt = stmt.limit(config.poc_batch_size)
posts = session.scalars(stmt).all()
if not posts:
logger.warning("No posts were selected for inference")
return [], []
post_ids = [post.post_id for post in posts]
texts = [post.text.strip() for post in posts]
return post_ids, texts
def count_topics(records: list[dict]) -> None:
topic_counts = Counter(record.get("topic_label", "unknown") for record in records)
logger.info("Topic distribution in this batch:")
for label, count in topic_counts.most_common(10):
logger.info(" %s: %d", label, count)
if __name__ == "__main__":
main()
+119
View File
@@ -0,0 +1,119 @@
"""Nornsight — BERTopic POC Training Script.
Pulls a small stratified sample (~11.5k posts) from main.posts,
trains BERTopic with MiniBatchKMeans on Jeeves, and saves the model locally.
POC sample rate: random() < 0.00005 (~0.005% of 230M = ~11.5k posts)
Full training rate will be: random() < 0.005 (~1.08M posts)
"""
from __future__ import annotations
import logging
import time
from pathlib import Path
from bertopic import BERTopic
from sklearn.cluster import MiniBatchKMeans
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from pipelines.config import BertTopicTrainConfig, get_bertopic_train_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger
logger = logging.getLogger(__name__)
def main() -> None:
"""Train and persist the BERTopic model."""
configure_logger()
config = get_bertopic_train_config()
docs = load_sample(config)
if not docs:
logger.warning("No training documents were selected")
return
train(docs, config)
logger.info(f"Done. Model saved as version {config.model_version}")
logger.info("Next: run infer.py to label a sample of posts in the database")
def load_sample(config: BertTopicTrainConfig) -> list[str]:
logger.info("Connecting to PostgreSQL via SQLAlchemy")
engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
logger.info(f"Pulling sample from main.posts (sample_rate={config.sample_rate})")
start = time.perf_counter()
with Session(engine) as session:
texts = session.scalars(
select(Posts.text).where(
Posts.text.is_not(None),
Posts.langs.in_(ENGLISH_LANGS),
func.length(Posts.text) > config.min_text_length,
func.random() < config.sample_rate,
)
).all()
elapsed = time.perf_counter() - start
logger.info(f"Fetched {len(texts)} rows in {elapsed:.1f}s")
# Basic cleaning — strip whitespace and deduplicate
docs = list({text.strip() for text in texts})
logger.info(f"After cleaning and dedup: {len(docs)} posts")
return docs
def train(docs: list[str], config: BertTopicTrainConfig) -> None:
logger.info(
f"Initialising BERTopic with MiniBatchKMeans (n_topics={config.n_topics})"
)
cluster_model = MiniBatchKMeans(
n_clusters=config.n_topics,
random_state=42,
batch_size=1024,
n_init=3,
verbose=1,
)
topic_model = BERTopic(
hdbscan_model=cluster_model,
language="english",
calculate_probabilities=False, # saves memory
verbose=True,
)
logger.info(f"Starting fit_transform on {len(docs)} posts (CPU)")
start = time.perf_counter()
topic_model.fit_transform(docs)
elapsed = time.perf_counter() - start
logger.info(f"Training complete in {elapsed:.1f}s ({elapsed / 60:.1f} min)")
# Log topic summary for quick inspection
topic_info = topic_model.get_topic_info()
logger.info(f"Topics found: {len(topic_info)}")
logger.info(f"\n{topic_info.to_string()}")
model_save_path = Path(config.model_save_path)
model_save_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving model to {model_save_path}")
topic_model.save(
str(model_save_path),
serialization="safetensors",
save_ctfidf=True,
save_embedding_model=True,
)
logger.info("Model saved")
if __name__ == "__main__":
main()