diff --git a/overlays/default.nix b/overlays/default.nix index b069d6a..1bcdcd0 100644 --- a/overlays/default.nix +++ b/overlays/default.nix @@ -25,6 +25,7 @@ fastapi-cli httpx mypy + orjson polars psycopg pydantic diff --git a/python/data_science/ingest_posts.py b/python/data_science/ingest_posts.py index e51f7e4..c43adf5 100644 --- a/python/data_science/ingest_posts.py +++ b/python/data_science/ingest_posts.py @@ -8,20 +8,17 @@ Usage: from __future__ import annotations -import json import logging from datetime import UTC, datetime -from pathlib import Path +from pathlib import Path # noqa: TC003 this is needed for typer from typing import TYPE_CHECKING, Annotated +import orjson +import psycopg import typer -from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session -from python.orm.common import get_postgres_engine -from python.orm.data_science_dev.posts.failed_ingestion import FailedIngestion -from python.orm.data_science_dev.posts.tables import Posts +from python.common import configure_logger +from python.orm.common import get_connection_info from python.parallelize import parallelize_process if TYPE_CHECKING: @@ -41,20 +38,20 @@ def main( pattern: Annotated[str, typer.Option(help="Glob pattern for JSONL files")] = "*.jsonl", ) -> None: """Ingest JSONL post files into the weekly-partitioned posts table.""" - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(message)s", - datefmt="%H:%M:%S", - ) + configure_logger(level="INFO") + logger.info("starting ingest-posts") + logger.info("path=%s batch_size=%d workers=%d pattern=%s", path, batch_size, workers, pattern) if path.is_file(): - ingest_file(str(path), batch_size=batch_size) + ingest_file(path, batch_size=batch_size) elif path.is_dir(): ingest_directory(path, batch_size=batch_size, max_workers=workers, pattern=pattern) else: typer.echo(f"Path does not exist: {path}", err=True) raise typer.Exit(code=1) + 
logger.info("ingest-posts done") + def ingest_directory( directory: Path, @@ -62,74 +59,129 @@ def ingest_directory( batch_size: int, max_workers: int, pattern: str = "*.jsonl", -) -> int: +) -> None: """Ingest all JSONL files in a directory using parallel workers.""" files = sorted(directory.glob(pattern)) if not files: logger.warning("No JSONL files found in %s", directory) - return 0 + return logger.info("Found %d JSONL files to ingest", len(files)) - file_paths = [str(file) for file in files] - total_rows = 0 - - kwargs_list = [{"file_path": fp, "batch_size": batch_size} for fp in file_paths] - executor_results = parallelize_process(ingest_file, kwargs_list, max_workers=max_workers) - total_rows = sum(executor_results.results) - - logger.info("Ingestion complete — %d total rows across %d files", total_rows, len(files)) - return total_rows + kwargs_list = [{"path": fp, "batch_size": batch_size} for fp in files] + parallelize_process(ingest_file, kwargs_list, max_workers=max_workers) -def ingest_file(file_path: str, *, batch_size: int) -> int: - """Ingest a single JSONL file into the posts table. 
Returns total rows inserted.""" - path = Path(file_path) - engine = get_postgres_engine(name="DATA_SCIENCE_DEV") - total_rows = 0 +SCHEMA = "main" - with Session(engine) as session: - for batch in read_jsonl_batches(path, batch_size): - inserted = _ingest_batch(session, batch) - total_rows += inserted - logger.info(" %s: inserted %d rows (total: %d)", path.name, inserted, total_rows) +COLUMNS = ( + "post_id", + "user_id", + "instance", + "date", + "text", + "langs", + "like_count", + "reply_count", + "repost_count", + "reply_to", + "replied_author", + "thread_root", + "thread_root_author", + "repost_from", + "reposted_author", + "quotes", + "quoted_author", + "labels", + "sent_label", + "sent_score", +) - logger.info("Finished %s — %d rows", path.name, total_rows) - return total_rows +INSERT_FROM_STAGING = f""" + INSERT INTO {SCHEMA}.posts ({", ".join(COLUMNS)}) + SELECT {", ".join(COLUMNS)} FROM pg_temp.staging + ON CONFLICT (post_id, date) DO NOTHING +""" # noqa: S608 + +FAILED_INSERT = f""" + INSERT INTO {SCHEMA}.failed_ingestion (raw_line, error) + VALUES (%(raw_line)s, %(error)s) +""" # noqa: S608 -def _ingest_batch(session: Session, batch: list[dict]) -> int: - """Try bulk insert; on failure, binary-split to isolate bad rows.""" +def get_psycopg_connection() -> psycopg.Connection: + """Create a raw psycopg3 connection from environment variables.""" + database, host, port, username, password = get_connection_info("DATA_SCIENCE_DEV") + return psycopg.connect( + dbname=database, + host=host, + port=int(port), + user=username, + password=password, + autocommit=False, + ) + + +def ingest_file(path: Path, *, batch_size: int) -> None: + """Ingest a single JSONL file into the posts table.""" + log_trigger = max(100_000 // batch_size, 1) + failed_lines: list[dict] = [] + + with get_psycopg_connection() as connection: + for index, batch in enumerate(read_jsonl_batches(path, batch_size, failed_lines), 1): + ingest_batch(connection, batch) + if index % log_trigger == 0: 
+ logger.info("Ingested %d batches (%d rows) from %s", index, index * batch_size, path) + + if failed_lines: + logger.warning("Recording %d malformed lines from %s", len(failed_lines), path.name) + with connection.cursor() as cursor: + cursor.executemany(FAILED_INSERT, failed_lines) + connection.commit() + + +def ingest_batch(connection: psycopg.Connection, batch: list[dict]) -> None: + """COPY batch into a temp staging table, then INSERT ... ON CONFLICT into posts.""" if not batch: - return 0 + return try: - statement = insert(Posts).values(batch).on_conflict_do_nothing(index_elements=["post_id"]) - result = session.execute(statement) - session.commit() - except (OSError, SQLAlchemyError) as error: - session.rollback() + with connection.cursor() as cursor: + cursor.execute(f""" + CREATE TEMP TABLE IF NOT EXISTS staging + (LIKE {SCHEMA}.posts INCLUDING DEFAULTS) + ON COMMIT DELETE ROWS + """) + cursor.execute("TRUNCATE pg_temp.staging") + + with cursor.copy(f"COPY pg_temp.staging ({', '.join(COLUMNS)}) FROM STDIN") as copy: + for row in batch: + copy.write_row(tuple(row.get(column) for column in COLUMNS)) + + cursor.execute(INSERT_FROM_STAGING) + connection.commit() + except (OSError, psycopg.Error) as error: + connection.rollback() if len(batch) == 1: logger.exception("Skipping bad row post_id=%s", batch[0].get("post_id")) - session.add( - FailedIngestion( - raw_line=json.dumps(batch[0], default=str), - error=str(error), + with connection.cursor() as cursor: + cursor.execute( + FAILED_INSERT, + { + "raw_line": orjson.dumps(batch[0], default=str).decode(), + "error": str(error), + }, ) - ) - session.commit() - return 0 + connection.commit() + return midpoint = len(batch) // 2 - left = _ingest_batch(session, batch[:midpoint]) - right = _ingest_batch(session, batch[midpoint:]) - return left + right - else: - return result.rowcount + ingest_batch(connection, batch[:midpoint]) + ingest_batch(connection, batch[midpoint:]) -def read_jsonl_batches(file_path: Path, 
batch_size: int) -> Iterator[list[dict]]: +def read_jsonl_batches(file_path: Path, batch_size: int, failed_lines: list[dict]) -> Iterator[list[dict]]: """Stream a JSONL file and yield batches of transformed rows.""" batch: list[dict] = [] with file_path.open("r", encoding="utf-8") as handle: @@ -137,8 +189,7 @@ def read_jsonl_batches(file_path: Path, batch_size: int) -> Iterator[list[dict]] line = raw_line.strip() if not line: continue - row = transform_row(json.loads(line)) - batch.append(row) + batch.extend(parse_line(line, file_path, failed_lines)) if len(batch) >= batch_size: yield batch batch = [] @@ -146,11 +197,34 @@ yield batch + +def parse_line(line: str, file_path: Path, failed_lines: list[dict]) -> Iterator[dict]: + """Parse a JSONL line, handling concatenated JSON objects.""" + try: + yield transform_row(orjson.loads(line)) + except orjson.JSONDecodeError: + if "}{" not in line: + logger.warning("Skipping malformed line in %s: %s", file_path.name, line[:120]) + failed_lines.append({"raw_line": line, "error": "malformed JSON"}) + return + fragments = line.replace("}{", "}\n{").split("\n") + for fragment in fragments: + try: + yield transform_row(orjson.loads(fragment)) + except (orjson.JSONDecodeError, KeyError, ValueError) as error: + logger.warning("Skipping malformed fragment in %s: %s", file_path.name, fragment[:120]) + failed_lines.append({"raw_line": fragment, "error": str(error)}) + except (KeyError, ValueError) as error: + logger.warning("Skipping bad row in %s: %s", file_path.name, line[:120]) + failed_lines.append({"raw_line": line, "error": str(error)}) + + def transform_row(raw: dict) -> dict: """Transform a raw JSONL row into a dict matching the Posts table columns.""" raw["date"] = parse_date(raw["date"]) if raw.get("langs") is not None: - raw["langs"] = json.dumps(raw["langs"]) + raw["langs"] = orjson.dumps(raw["langs"]).decode() + if raw.get("text") is not None: + 
raw["text"] = raw["text"].replace("\x00", "") return raw