diff --git a/python/eval_warnings/main.py b/python/eval_warnings/main.py index 45f597b..ce63391 100644 --- a/python/eval_warnings/main.py +++ b/python/eval_warnings/main.py @@ -3,24 +3,24 @@ from __future__ import annotations import hashlib -import io import logging import re import subprocess -import zipfile from dataclasses import dataclass +from io import BytesIO from pathlib import Path from typing import Annotated +from zipfile import ZipFile -import httpx import typer +from httpx import HTTPError, post from python.common import configure_logger logger = logging.getLogger(__name__) -@dataclass +@dataclass(frozen=True) class EvalWarning: """A single Nix evaluation warning.""" @@ -37,13 +37,6 @@ class FileChange: fixed: str -WARNING_PATTERN = re.compile(r"(?:^[\d\-T:.Z]+ )?(warning:|trace: warning:)") -TIMESTAMP_PREFIX = re.compile(r"^[\d\-T:.Z]+ ") -NIX_STORE_PATH = re.compile(r"/nix/store/[^/]+-source/([^:]+\.nix)") -REPO_RELATIVE_PATH = re.compile(r"(? subprocess.CompletedProcess[str]: """Run a subprocess command and return the result. @@ -81,41 +74,43 @@ def download_logs(run_id: str, repo: str) -> dict[str, str]: raise RuntimeError(msg) logs: dict[str, str] = {} - with zipfile.ZipFile(io.BytesIO(result.stdout)) as zf: - for name in zf.namelist(): + with ZipFile(BytesIO(result.stdout)) as zip_file: + for name in zip_file.namelist(): if name.startswith("build-") and name.endswith(".txt"): - logs[name] = zf.read(name).decode(errors="replace") + logs[name] = zip_file.read(name).decode(errors="replace") return logs -def parse_warnings(logs: dict[str, str]) -> list[EvalWarning]: +def parse_warnings(logs: dict[str, str]) -> set[EvalWarning]: """Parse Nix evaluation warnings from build log contents. Args: logs: Dict mapping zip entry names (e.g. "build-bob/2_Build.txt") to their text. Returns: - Deduplicated list of warnings. + Deduplicated set of warnings. """ - warnings: list[EvalWarning] = [] - seen: set[str] = set() + warnings: set[EvalWarning] = set() + warning_pattern = re.compile(r"(?:^[\d\-T:.Z]+ )?(warning:|trace: warning:)") + timestamp_prefix = re.compile(r"^[\d\-T:.Z]+ ") for name, content in sorted(logs.items()): system = name.split("/")[0].removeprefix("build-") for line in content.splitlines(): - if WARNING_PATTERN.search(line): - message = TIMESTAMP_PREFIX.sub("", line).strip() - key = f"{system}:{message}" - if key not in seen: - seen.add(key) - warnings.append(EvalWarning(system=system, message=message)) + if line.startswith("warning: ignoring untrusted flake configuration setting"): + logger.info("Ignoring untrusted flake configuration setting warning.") + continue + if warning_pattern.search(line): + logger.debug(f"Found warning: {line}") + message = timestamp_prefix.sub("", line).strip() + warnings.add(EvalWarning(system=system, message=message)) logger.info("Found %d unique warnings", len(warnings)) return warnings -def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]: +def extract_referenced_files(warnings: set[EvalWarning]) -> dict[str, str]: """Extract file paths referenced in warnings and read their contents. Args: @@ -127,9 +122,12 @@ def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]: paths: set[str] = set() warning_text = "\n".join(w.message for w in warnings) - for match in NIX_STORE_PATH.finditer(warning_text): + nix_store_path = re.compile(r"/nix/store/[^/]+-source/([^:]+\.nix)") + for match in nix_store_path.finditer(warning_text): paths.add(match.group(1)) - for match in REPO_RELATIVE_PATH.finditer(warning_text): + + repo_relative_path = re.compile(r"(? dict[str, str]: return files -def compute_warning_hash(warnings: list[EvalWarning]) -> str: +def compute_warning_hash(warnings: set[EvalWarning]) -> str: """Compute a short hash of the warning set for deduplication. Args: @@ -198,7 +196,7 @@ def check_duplicate_pr(warning_hash: str) -> bool: def query_ollama( - warnings: list[EvalWarning], + warnings: set[EvalWarning], files: dict[str, str], ollama_url: str, ) -> str | None: @@ -243,7 +241,7 @@ Analyze the following Nix evaluation warnings and suggest fixes. say so in REASONING and do not suggest changes""" try: - response = httpx.post( + response = post( f"{ollama_url}/api/generate", json={ "model": "qwen3-coder:30b", @@ -254,7 +252,7 @@ say so in REASONING and do not suggest changes""" timeout=300, ) response.raise_for_status() - except httpx.HTTPError: + except HTTPError: logger.exception("Ollama request failed") return None @@ -280,36 +278,27 @@ def parse_changes(response: str) -> list[FileChange]: """ changes: list[FileChange] = [] current_file = "" - in_original = False - in_fixed = False + section: str | None = None original_lines: list[str] = [] fixed_lines: list[str] = [] for line in response.splitlines(): - file_match = CHANGE_FILE_PATTERN.match(line) - if file_match: - current_file = file_match.group(1).strip() - elif line.strip() == "<<<<<<< ORIGINAL": - in_original = True - in_fixed = False + stripped = line.strip() + if stripped.startswith("FILE:"): + current_file = stripped.removeprefix("FILE:").strip() + elif stripped == "<<<<<<< ORIGINAL": + section = "original" original_lines = [] - elif line.strip() == "=======" and in_original: - in_original = False - in_fixed = True + elif stripped == "=======" and section == "original": + section = "fixed" fixed_lines = [] - elif line.strip() == ">>>>>>> FIXED" and in_fixed: - in_fixed = False + elif stripped == ">>>>>>> FIXED" and section == "fixed": + section = None if current_file: - changes.append( - FileChange( - file_path=current_file, - original="\n".join(original_lines), - fixed="\n".join(fixed_lines), - ) - ) - elif in_original: + changes.append(FileChange(current_file, "\n".join(original_lines), "\n".join(fixed_lines))) + elif section == "original": original_lines.append(line) - elif in_fixed: + elif section == "fixed": fixed_lines.append(line) logger.info("Parsed %d file changes", len(changes)) @@ -346,7 +335,7 @@ def apply_changes(changes: list[FileChange]) -> int: def create_pr( warning_hash: str, - warnings: list[EvalWarning], + warnings: set[EvalWarning], llm_response: str, run_url: str, ) -> None: @@ -361,16 +350,12 @@ def create_pr( branch = f"fix/eval-warning-{warning_hash}" warning_text = "\n".join(f"[{w.system}] {w.message}" for w in warnings) - reasoning_lines: list[str] = [] - capturing = False - for line in llm_response.splitlines(): - if "**REASONING**" in line: - capturing = True - elif "**CHANGES**" in line: - break - elif capturing: - reasoning_lines.append(line) - reasoning = "\n".join(reasoning_lines[:50]) + if "**REASONING**" not in llm_response: + logger.warning("LLM response missing **REASONING** section") + reasoning = "" + else: + _, after = llm_response.split("**REASONING**", 1) + reasoning = "\n".join(after.split("**CHANGES**", 1)[0].strip().splitlines()[:50]) run_cmd(["git", "config", "user.name", "github-actions[bot]"]) run_cmd(["git", "config", "user.email", "github-actions[bot]@users.noreply.github.com"])