"""Nornsight — BERTopic POC Inference Script.

Loads the trained model and labels a small batch of posts, writing results
to main.post_topic for inspection.

POC: processes a single batch of posts (size taken from
``config.poc_batch_size``; a value of 0 or less means "no limit") to
validate the pipeline end-to-end.
"""

from __future__ import annotations

import logging
import time
from collections import Counter
from pathlib import Path

from bertopic import BERTopic
from sqlalchemy import Engine, func, insert, select
from sqlalchemy.orm import Session

from pipelines.config import BertTopicInferConfig, get_bertopic_infer_config
from pipelines.orm.common import get_postgres_engine
from pipelines.orm.data_science_dev.posts import PostTopic, Posts
from pipelines.orm.data_science_dev.posts.lang_filters import ENGLISH_LANGS
from pipelines.pipelines.common import configure_logger

logger = logging.getLogger(__name__)


def main() -> None:
    """Run BERTopic inference against a sample of posts."""
    configure_logger()
    config = get_bertopic_infer_config()
    run_inference(config)
    logger.info(
        "POC inference complete. Check main.post_topic in DBeaver to inspect results."
    )


def run_inference(config: BertTopicInferConfig) -> None:
    """Label one batch of posts with the saved BERTopic model and store them.

    Loads the model from ``config.model_save_path``, fetches posts via
    :func:`get_post_ids_and_texts`, runs ``BERTopic.transform`` on their
    texts and inserts one ``PostTopic`` row per post tagged with
    ``config.model_version``.

    Args:
        config: Inference settings. Reads ``model_save_path`` and
            ``model_version`` here, plus ``poc_batch_size`` and
            ``min_text_length`` inside the fetch helper.
    """
    model_save_path = Path(config.model_save_path)
    logger.info("Loading BERTopic model from %s", model_save_path)
    topic_model = BERTopic.load(str(model_save_path))

    topic_info = topic_model.get_topic_info()
    # DataFrame values arrive as numpy scalars; normalise the keys to plain
    # ints so they match the int(topic_id) lookups below exactly.
    label_map: dict[int, str] = {
        int(topic_id): name
        for topic_id, name in zip(topic_info["Topic"], topic_info["Name"])
    }
    logger.info("Model loaded with %d topics", len(label_map))

    engine = get_postgres_engine(name="DATA_SCIENCE_DEV")
    post_ids, texts = get_post_ids_and_texts(engine, config)
    logger.info("Fetched %d posts", len(texts))

    if not texts:
        # Transforming an empty corpus and inserting zero rows is pointless
        # at best; bail out before touching the model or the table.
        logger.info("Nothing to label; skipping transform and write")
        return

    logger.info("Running BERTopic transform")
    start = time.perf_counter()
    topics, _probabilities = topic_model.transform(texts)
    elapsed = time.perf_counter() - start
    logger.info("Transform complete in %.1fs", elapsed)

    # Write results to main.post_topic
    records = [
        {
            "post_id": post_id,
            "topic_id": int(topic_id),
            "topic_label": label_map.get(int(topic_id), "unknown"),
            "model_version": config.model_version,
        }
        for post_id, topic_id in zip(post_ids, topics)
    ]

    with Session(engine) as session:
        # Passing a list of dicts makes this an executemany-style bulk insert.
        session.execute(insert(PostTopic), records)
        session.commit()

    count_topics(records)
    logger.info("Wrote %d topic labels to main.post_topic", len(records))


def get_post_ids_and_texts(
    engine: Engine,
    config: BertTopicInferConfig,
) -> tuple[list[int], list[str]]:
    """Fetch ids and stripped texts of English posts eligible for inference.

    A post is eligible when its text is non-null, its ``langs`` value is in
    ``ENGLISH_LANGS`` and its text is longer than ``config.min_text_length``
    characters. At most ``config.poc_batch_size`` rows are returned when that
    value is positive; otherwise every eligible row is returned.

    Args:
        engine: Engine connected to the DATA_SCIENCE_DEV database.
        config: Inference settings (``poc_batch_size``, ``min_text_length``).

    Returns:
        ``(post_ids, texts)`` as parallel lists; both are empty when no post
        matches.
    """
    # NOTE(review): there is no ORDER BY and no exclusion of training or
    # already-labelled posts, so LIMIT returns an arbitrary subset that may
    # overlap the training sample, and re-running can insert duplicate
    # post_topic rows. TODO: add an ordering/exclusion once the intended
    # sampling strategy is confirmed.
    stmt = select(Posts.post_id, Posts.text).where(
        Posts.text.is_not(None),
        Posts.langs.in_(ENGLISH_LANGS),
        func.length(Posts.text) > config.min_text_length,
    )
    if config.poc_batch_size > 0:
        logger.info("Fetching %d posts for inference", config.poc_batch_size)
        stmt = stmt.limit(config.poc_batch_size)
    else:
        logger.info("Fetching all matching posts for inference")

    # Only the two needed columns are selected, so rows are plain tuples
    # rather than full ORM objects.
    with Session(engine) as session:
        rows = session.execute(stmt).all()

    if not rows:
        logger.warning("No posts were selected for inference")
        return [], []

    post_ids = [row.post_id for row in rows]
    texts = [row.text.strip() for row in rows]
    return post_ids, texts


# Backward-compatible alias for the original (misspelled) function name.
get_post_ids_and_test = get_post_ids_and_texts


def count_topics(records: list[dict], top_n: int = 10) -> None:
    """Log the ``top_n`` most frequent topic labels found in ``records``.

    Args:
        records: Row dicts as built by :func:`run_inference`; records lacking
            a ``"topic_label"`` key are counted as ``"unknown"``.
        top_n: How many of the most common labels to log (default 10).
    """
    topic_counts = Counter(record.get("topic_label", "unknown") for record in records)
    logger.info("Topic distribution in this batch:")
    for label, count in topic_counts.most_common(top_n):
        logger.info("  %s: %d", label, count)


if __name__ == "__main__":
    main()