From 45bdd7b6293c611098182978803e65dad928389c Mon Sep 17 00:00:00 2001
From: Richie Cahill
Date: Tue, 28 Apr 2026 22:20:18 -0400
Subject: [PATCH] added bert_topic train.py and infer.py

---
Review notes (text here is ignored by git am):
- BertTopicTrainConfig gains the @dataclass decorator that from_toml's
  cls(**raw) call relies on.
- run_inference returns early when no posts are fetched.
- Context lines in the pipelines/config.py hunks (blank lines, indentation)
  are reconstructed and the blob hashes on the index lines are stale; verify
  them against the working tree before applying.

 pipelines/bert_topic/infer.py | 131 ++++++++++++++++++++++++++++++++++
 pipelines/bert_topic/train.py | 130 +++++++++++++++++++++++++++++++++
 pipelines/config.py           |  58 +++++++++++++++
 3 files changed, 319 insertions(+)
 create mode 100644 pipelines/bert_topic/infer.py
 create mode 100644 pipelines/bert_topic/train.py

diff --git a/pipelines/bert_topic/infer.py b/pipelines/bert_topic/infer.py
new file mode 100644
index 0000000..9491660
--- /dev/null
+++ b/pipelines/bert_topic/infer.py
@@ -0,0 +1,131 @@
+"""Nornsight — BERTopic POC Inference Script.
+
+Loads the trained model and labels a small batch of posts,
+writing results to main.post_topic for inspection.
+
+POC: processes a single batch of 1k posts to validate the pipeline end-to-end.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from collections import Counter
+from pathlib import Path
+
+from bertopic import BERTopic
+from sqlalchemy import Engine, func, insert, select
+from sqlalchemy.orm import Session
+
+from pipelines.config import BertTopicInferConfig, get_bertopic_infer_config
+from pipelines.orm.common import get_postgres_engine
+from pipelines.orm.data_science_dev.posts import PostTopic, Posts
+from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
+from pipelines.pipelines.common import configure_logger
+
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    """Run BERTopic inference against a sample of posts."""
+    configure_logger()
+
+    config = get_bertopic_infer_config()
+    run_inference(config)
+    logger.info(
+        "POC inference complete. Check main.post_topic in DBeaver to inspect results."
+    )
+
+
+def run_inference(config: BertTopicInferConfig) -> None:
+    """Label one batch of posts with the saved model and store the results.
+
+    Loads the BERTopic model from ``config.model_save_path``, fetches a batch
+    of posts, assigns each post a topic, and inserts one ``PostTopic`` row per
+    post tagged with ``config.model_version``.
+    """
+    model_save_path = Path(config.model_save_path)
+
+    logger.info(f"Loading BERTopic model from {model_save_path}")
+    topic_model = BERTopic.load(str(model_save_path))
+
+    topic_info = topic_model.get_topic_info()
+    label_map: dict[int, str] = dict(zip(topic_info["Topic"], topic_info["Name"]))
+    logger.info(f"Model loaded with {len(label_map)} topics")
+
+    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
+
+    post_ids, texts = get_post_ids_and_test(engine, config)
+
+    logger.info(f"Fetched {len(texts)} posts")
+    if not texts:
+        # Nothing to label: skip the transform and the empty bulk insert.
+        return
+
+    logger.info("Running BERTopic transform")
+    start = time.perf_counter()
+    topics, _probabilities = topic_model.transform(texts)
+    elapsed = time.perf_counter() - start
+    logger.info(f"Transform complete in {elapsed:.1f}s")
+
+    # Write results to main.post_topic
+    records = [
+        {
+            "post_id": pid,
+            "topic_id": int(topic_id),
+            "topic_label": label_map.get(int(topic_id), "unknown"),
+            "model_version": config.model_version,
+        }
+        for pid, topic_id in zip(post_ids, topics)
+    ]
+    with Session(engine) as session:
+        session.execute(insert(PostTopic), records)
+        session.commit()
+
+    count_topics(records)
+    logger.info(f"Wrote {len(records)} topic labels to main.post_topic")
+
+
+def get_post_ids_and_test(
+    engine: Engine,
+    config: BertTopicInferConfig,
+) -> tuple[list[int], list[str]]:
+    """Fetch the ids and stripped texts of the posts to label.
+
+    Returns two parallel lists ``(post_ids, texts)``; both are empty when no
+    post matches the filters.
+    """
+    with Session(engine) as session:
+        logger.info(f"Fetching {config.poc_batch_size} posts for inference")
+        # NOTE: nothing here excludes posts that were sampled for training,
+        # so this batch may overlap with the training sample.
+        stmt = select(Posts).where(
+            Posts.text.is_not(None),
+            Posts.langs.in_(ENGLISH_LANGS),
+            func.length(Posts.text) > config.min_text_length,
+        )
+        if config.poc_batch_size > 0:
+            stmt = stmt.limit(config.poc_batch_size)
+
+        posts = session.scalars(stmt).all()
+        if not posts:
+            logger.warning("No posts were selected for inference")
+            return [], []
+
+        post_ids = [post.post_id for post in posts]
+        texts = [post.text.strip() for post in posts]
+
+        return post_ids, texts
+
+
+def count_topics(records: list[dict]) -> None:
+    """Log the ten most common topic labels among ``records``."""
+    topic_counts = Counter(record.get("topic_label", "unknown") for record in records)
+
+    logger.info("Topic distribution in this batch:")
+    for label, count in topic_counts.most_common(10):
+        logger.info(" %s: %d", label, count)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pipelines/bert_topic/train.py b/pipelines/bert_topic/train.py
new file mode 100644
index 0000000..848edd4
--- /dev/null
+++ b/pipelines/bert_topic/train.py
@@ -0,0 +1,130 @@
+"""Nornsight — BERTopic POC Training Script.
+
+Pulls a small random sample (~11.5k posts) from main.posts,
+trains BERTopic with MiniBatchKMeans on Jeeves, and saves the model locally.
+
+POC sample rate: random() < 0.00005 (~0.005% of 230M = ~11.5k posts)
+Full training rate will be: random() < 0.005 (~1.08M posts)
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from pathlib import Path
+
+from bertopic import BERTopic
+from sklearn.cluster import MiniBatchKMeans
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
+
+from pipelines.config import BertTopicTrainConfig, get_bertopic_train_config
+from pipelines.orm.common import get_postgres_engine
+from pipelines.orm.data_science_dev.posts import Posts
+from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
+from pipelines.pipelines.common import configure_logger
+
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    """Train and persist the BERTopic model."""
+    configure_logger()
+
+    config = get_bertopic_train_config()
+    docs = load_sample(config)
+    if not docs:
+        logger.warning("No training documents were selected")
+        return
+
+    train(docs, config)
+    logger.info(f"Done. Model saved as version {config.model_version}")
+    logger.info("Next: run infer.py to label a sample of posts in the database")
+
+
+def load_sample(config: BertTopicTrainConfig) -> list[str]:
+    """Pull a random sample of English post texts to train on.
+
+    Rows are kept with probability ``config.sample_rate``; the texts are then
+    stripped and deduplicated, and texts that are empty after stripping are
+    dropped.
+    """
+    logger.info("Connecting to PostgreSQL via SQLAlchemy")
+    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
+
+    logger.info(f"Pulling sample from main.posts (sample_rate={config.sample_rate})")
+    start = time.perf_counter()
+
+    with Session(engine) as session:
+        texts = session.scalars(
+            select(Posts.text).where(
+                Posts.text.is_not(None),
+                Posts.langs.in_(ENGLISH_LANGS),
+                func.length(Posts.text) > config.min_text_length,
+                func.random() < config.sample_rate,
+            )
+        ).all()
+
+    elapsed = time.perf_counter() - start
+    logger.info(f"Fetched {len(texts)} rows in {elapsed:.1f}s")
+
+    # Basic cleaning — strip whitespace, drop empties, and deduplicate
+    docs = list({text.strip() for text in texts} - {""})
+    logger.info(f"After cleaning and dedup: {len(docs)} posts")
+
+    return docs
+
+
+def train(docs: list[str], config: BertTopicTrainConfig) -> None:
+    """Fit BERTopic on ``docs`` and save the model to ``config.model_save_path``.
+
+    Clustering uses MiniBatchKMeans with ``config.n_topics`` clusters in place
+    of BERTopic's default HDBSCAN step.
+    """
+    logger.info(
+        f"Initialising BERTopic with MiniBatchKMeans (n_topics={config.n_topics})"
+    )
+
+    cluster_model = MiniBatchKMeans(
+        n_clusters=config.n_topics,
+        random_state=42,
+        batch_size=1024,
+        n_init=3,
+        verbose=1,
+    )
+
+    topic_model = BERTopic(
+        hdbscan_model=cluster_model,
+        language="english",
+        calculate_probabilities=False,  # saves memory
+        verbose=True,
+    )
+
+    logger.info(f"Starting fit_transform on {len(docs)} posts (CPU)")
+    start = time.perf_counter()
+
+    topic_model.fit_transform(docs)
+
+    elapsed = time.perf_counter() - start
+    logger.info(f"Training complete in {elapsed:.1f}s ({elapsed / 60:.1f} min)")
+
+    # Log topic summary for quick inspection
+    topic_info = topic_model.get_topic_info()
+    logger.info(f"Topics found: {len(topic_info)}")
+    logger.info(f"\n{topic_info.to_string()}")
+
+    model_save_path = Path(config.model_save_path)
+    model_save_path.mkdir(parents=True, exist_ok=True)
+    logger.info(f"Saving model to {model_save_path}")
+
+    topic_model.save(
+        str(model_save_path),
+        serialization="safetensors",
+        save_ctfidf=True,
+        save_embedding_model=True,
+    )
+    logger.info("Model saved")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pipelines/config.py b/pipelines/config.py
index 4bdaadb..3bfd32c 100644
--- a/pipelines/config.py
+++ b/pipelines/config.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from os import getenv
+from datetime import date
 from pathlib import Path
 import tomllib
 
@@ -101,6 +102,45 @@ class OpenAIConfig:
         )
 
 
+@dataclass
+class BertTopicTrainConfig:
+    """BERTopic training configuration loaded from TOML."""
+
+    sample_rate: float
+    min_text_length: int
+    n_topics: int
+    model_save_path: str
+    model_version: str | None = None
+
+    @classmethod
+    def from_toml(cls, config_path: Path) -> BertTopicTrainConfig:
+        """Load BERTopic training config from a TOML file."""
+        raw = tomllib.loads(config_path.read_text())["bertopic"]["train"]
+
+        today = date.today().isoformat()
+        if raw.get("model_version") is None:
+            raw["model_version"] = (
+                f"{today}-{raw['sample_rate']}-{raw['min_text_length']}-{raw['n_topics']}"
+            )
+        return cls(**raw)
+
+
+@dataclass
+class BertTopicInferConfig:
+    """BERTopic inference configuration loaded from TOML."""
+
+    min_text_length: int
+    poc_batch_size: int
+    model_version: str
+    model_save_path: str
+
+    @classmethod
+    def from_toml(cls, config_path: Path) -> BertTopicInferConfig:
+        """Load BERTopic inference config from a TOML file."""
+        raw = tomllib.loads(config_path.read_text())["bertopic"]["infer"]
+        return cls(**raw)
+
+
 def get_config_dir() -> Path:
     """Get the path to the config directory."""
     return Path(__file__).resolve().parents[2] / "config"
@@ -127,3 +167,21 @@ def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
     if config_path is None:
         config_path = default_config_path()
     return BenchmarkConfig.from_toml(config_path)
+
+
+def get_bertopic_train_config(
+    config_path: Path | None = None,
+) -> BertTopicTrainConfig:
+    """Load the BERTopic training config, using the default path if none is given."""
+    if config_path is None:
+        config_path = default_config_path()
+    return BertTopicTrainConfig.from_toml(config_path)
+
+
+def get_bertopic_infer_config(
+    config_path: Path | None = None,
+) -> BertTopicInferConfig:
+    """Load the BERTopic inference config, using the default path if none is given."""
+    if config_path is None:
+        config_path = default_config_path()
+    return BertTopicInferConfig.from_toml(config_path)