"""Nornsight — BERTopic POC Training Script. Pulls a small stratified sample (~11.5k posts) from main.posts, trains BERTopic with MiniBatchKMeans on Jeeves, and saves the model locally. POC sample rate: random() < 0.00005 (~0.005% of 230M = ~11.5k posts) Full training rate will be: random() < 0.005 (~1.08M posts) """ from __future__ import annotations import logging import time from pathlib import Path from bertopic import BERTopic from sklearn.cluster import MiniBatchKMeans from sqlalchemy import func, select from sqlalchemy.orm import Session from pipelines.config import BertTopicTrainConfig, get_bertopic_train_config from pipelines.orm.common import get_postgres_engine from pipelines.orm.data_science_dev.posts import Posts from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS from pipelines.pipelines.common import configure_logger logger = logging.getLogger(__name__) def main() -> None: """Train and persist the BERTopic model.""" configure_logger() config = get_bertopic_train_config() docs = load_sample(config) if not docs: logger.warning("No training documents were selected") return train(docs, config) logger.info(f"Done. Model saved as version {config.model_version}") logger.info("Next: run infer.py to label a sample of posts in the database") def load_sample(config: BertTopicTrainConfig) -> list[str]: logger.info("Connecting to PostgreSQL via SQLAlchemy") engine = get_postgres_engine(name="DATA_SCIENCE_DEV") logger.info(f"Pulling sample from main.posts (sample_rate={config.sample_rate})") start = time.perf_counter() with Session(engine) as session: texts = session.scalars( select(Posts.text).where( Posts.text.is_not(None), Posts.langs.in_(ENGLISH_LANGS), func.length(Posts.text) > config.min_text_length, func.random() < config.sample_rate, ) ).all() elapsed = time.perf_counter() - start logger.info(f"Fetched {len(texts)} rows in {elapsed:.1f}s") # Basic cleaning — strip whitespace and deduplicate docs = list({text.strip() for text in texts}) logger.info(f"After cleaning and dedup: {len(docs)} posts") return docs def train(docs: list[str], config: BertTopicTrainConfig) -> None: logger.info( f"Initialising BERTopic with MiniBatchKMeans (n_topics={config.n_topics})" ) cluster_model = MiniBatchKMeans( n_clusters=config.n_topics, random_state=42, batch_size=1024, n_init=3, verbose=1, ) topic_model = BERTopic( hdbscan_model=cluster_model, language="english", calculate_probabilities=False, # saves memory verbose=True, ) logger.info(f"Starting fit_transform on {len(docs)} posts (CPU)") start = time.perf_counter() topic_model.fit_transform(docs) elapsed = time.perf_counter() - start logger.info(f"Training complete in {elapsed:.1f}s ({elapsed / 60:.1f} min)") # Log topic summary for quick inspection topic_info = topic_model.get_topic_info() logger.info(f"Topics found: {len(topic_info)}") logger.info(f"\n{topic_info.to_string()}") model_save_path = Path(config.model_save_path) model_save_path.mkdir(parents=True, exist_ok=True) logger.info(f"Saving model to {model_save_path}") topic_model.save( str(model_save_path), serialization="safetensors", save_ctfidf=True, save_embedding_model=True, ) logger.info("Model saved") if __name__ == "__main__": main()