improved eval_warnings/main.py

This commit is contained in:
2026-02-18 15:58:41 -05:00
parent af828fc9c4
commit f038f248a1

View File

@@ -3,24 +3,24 @@
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import io
import logging import logging
import re import re
import subprocess import subprocess
import zipfile
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Annotated from typing import Annotated
from zipfile import ZipFile
import httpx
import typer import typer
from httpx import HTTPError, post
from python.common import configure_logger from python.common import configure_logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass @dataclass(frozen=True)
class EvalWarning: class EvalWarning:
"""A single Nix evaluation warning.""" """A single Nix evaluation warning."""
@@ -37,13 +37,6 @@ class FileChange:
fixed: str fixed: str
WARNING_PATTERN = re.compile(r"(?:^[\d\-T:.Z]+ )?(warning:|trace: warning:)")
TIMESTAMP_PREFIX = re.compile(r"^[\d\-T:.Z]+ ")
NIX_STORE_PATH = re.compile(r"/nix/store/[^/]+-source/([^:]+\.nix)")
REPO_RELATIVE_PATH = re.compile(r"(?<![/\w])(systems|common|users|overlays)/[^:\s]+\.nix")
CHANGE_FILE_PATTERN = re.compile(r"^FILE:\s*(.+)$")
def run_cmd(cmd: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]: def run_cmd(cmd: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]:
"""Run a subprocess command and return the result. """Run a subprocess command and return the result.
@@ -81,41 +74,43 @@ def download_logs(run_id: str, repo: str) -> dict[str, str]:
raise RuntimeError(msg) raise RuntimeError(msg)
logs: dict[str, str] = {} logs: dict[str, str] = {}
with zipfile.ZipFile(io.BytesIO(result.stdout)) as zf: with ZipFile(BytesIO(result.stdout)) as zip_file:
for name in zf.namelist(): for name in zip_file.namelist():
if name.startswith("build-") and name.endswith(".txt"): if name.startswith("build-") and name.endswith(".txt"):
logs[name] = zf.read(name).decode(errors="replace") logs[name] = zip_file.read(name).decode(errors="replace")
return logs return logs
def parse_warnings(logs: dict[str, str]) -> list[EvalWarning]: def parse_warnings(logs: dict[str, str]) -> set[EvalWarning]:
"""Parse Nix evaluation warnings from build log contents. """Parse Nix evaluation warnings from build log contents.
Args: Args:
logs: Dict mapping zip entry names (e.g. "build-bob/2_Build.txt") to their text. logs: Dict mapping zip entry names (e.g. "build-bob/2_Build.txt") to their text.
Returns: Returns:
Deduplicated list of warnings. Deduplicated set of warnings.
""" """
warnings: list[EvalWarning] = [] warnings: set[EvalWarning] = set()
seen: set[str] = set() warning_pattern = re.compile(r"(?:^[\d\-T:.Z]+ )?(warning:|trace: warning:)")
timestamp_prefix = re.compile(r"^[\d\-T:.Z]+ ")
for name, content in sorted(logs.items()): for name, content in sorted(logs.items()):
system = name.split("/")[0].removeprefix("build-") system = name.split("/")[0].removeprefix("build-")
for line in content.splitlines(): for line in content.splitlines():
if WARNING_PATTERN.search(line): if line.startswith("warning: ignoring untrusted flake configuration setting"):
message = TIMESTAMP_PREFIX.sub("", line).strip() logger.info("Ignoring untrusted flake configuration setting warning.")
key = f"{system}:{message}" continue
if key not in seen: if warning_pattern.search(line):
seen.add(key) logger.debug(f"Found warning: {line}")
warnings.append(EvalWarning(system=system, message=message)) message = timestamp_prefix.sub("", line).strip()
warnings.add(EvalWarning(system=system, message=message))
logger.info("Found %d unique warnings", len(warnings)) logger.info("Found %d unique warnings", len(warnings))
return warnings return warnings
def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]: def extract_referenced_files(warnings: set[EvalWarning]) -> dict[str, str]:
"""Extract file paths referenced in warnings and read their contents. """Extract file paths referenced in warnings and read their contents.
Args: Args:
@@ -127,9 +122,12 @@ def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]:
paths: set[str] = set() paths: set[str] = set()
warning_text = "\n".join(w.message for w in warnings) warning_text = "\n".join(w.message for w in warnings)
for match in NIX_STORE_PATH.finditer(warning_text): nix_store_path = re.compile(r"/nix/store/[^/]+-source/([^:]+\.nix)")
for match in nix_store_path.finditer(warning_text):
paths.add(match.group(1)) paths.add(match.group(1))
for match in REPO_RELATIVE_PATH.finditer(warning_text):
repo_relative_path = re.compile(r"(?<![/\w])(systems|common|users|overlays)/[^:\s]+\.nix")
for match in repo_relative_path.finditer(warning_text):
paths.add(match.group(0)) paths.add(match.group(0))
files: dict[str, str] = {} files: dict[str, str] = {}
@@ -145,7 +143,7 @@ def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]:
return files return files
def compute_warning_hash(warnings: list[EvalWarning]) -> str: def compute_warning_hash(warnings: set[EvalWarning]) -> str:
"""Compute a short hash of the warning set for deduplication. """Compute a short hash of the warning set for deduplication.
Args: Args:
@@ -198,7 +196,7 @@ def check_duplicate_pr(warning_hash: str) -> bool:
def query_ollama( def query_ollama(
warnings: list[EvalWarning], warnings: set[EvalWarning],
files: dict[str, str], files: dict[str, str],
ollama_url: str, ollama_url: str,
) -> str | None: ) -> str | None:
@@ -243,7 +241,7 @@ Analyze the following Nix evaluation warnings and suggest fixes.
say so in REASONING and do not suggest changes""" say so in REASONING and do not suggest changes"""
try: try:
response = httpx.post( response = post(
f"{ollama_url}/api/generate", f"{ollama_url}/api/generate",
json={ json={
"model": "qwen3-coder:30b", "model": "qwen3-coder:30b",
@@ -254,7 +252,7 @@ say so in REASONING and do not suggest changes"""
timeout=300, timeout=300,
) )
response.raise_for_status() response.raise_for_status()
except httpx.HTTPError: except HTTPError:
logger.exception("Ollama request failed") logger.exception("Ollama request failed")
return None return None
@@ -280,36 +278,27 @@ def parse_changes(response: str) -> list[FileChange]:
""" """
changes: list[FileChange] = [] changes: list[FileChange] = []
current_file = "" current_file = ""
in_original = False section: str | None = None
in_fixed = False
original_lines: list[str] = [] original_lines: list[str] = []
fixed_lines: list[str] = [] fixed_lines: list[str] = []
for line in response.splitlines(): for line in response.splitlines():
file_match = CHANGE_FILE_PATTERN.match(line) stripped = line.strip()
if file_match: if stripped.startswith("FILE:"):
current_file = file_match.group(1).strip() current_file = stripped.removeprefix("FILE:").strip()
elif line.strip() == "<<<<<<< ORIGINAL": elif stripped == "<<<<<<< ORIGINAL":
in_original = True section = "original"
in_fixed = False
original_lines = [] original_lines = []
elif line.strip() == "=======" and in_original: elif stripped == "=======" and section == "original":
in_original = False section = "fixed"
in_fixed = True
fixed_lines = [] fixed_lines = []
elif line.strip() == ">>>>>>> FIXED" and in_fixed: elif stripped == ">>>>>>> FIXED" and section == "fixed":
in_fixed = False section = None
if current_file: if current_file:
changes.append( changes.append(FileChange(current_file, "\n".join(original_lines), "\n".join(fixed_lines)))
FileChange( elif section == "original":
file_path=current_file,
original="\n".join(original_lines),
fixed="\n".join(fixed_lines),
)
)
elif in_original:
original_lines.append(line) original_lines.append(line)
elif in_fixed: elif section == "fixed":
fixed_lines.append(line) fixed_lines.append(line)
logger.info("Parsed %d file changes", len(changes)) logger.info("Parsed %d file changes", len(changes))
@@ -346,7 +335,7 @@ def apply_changes(changes: list[FileChange]) -> int:
def create_pr( def create_pr(
warning_hash: str, warning_hash: str,
warnings: list[EvalWarning], warnings: set[EvalWarning],
llm_response: str, llm_response: str,
run_url: str, run_url: str,
) -> None: ) -> None:
@@ -361,16 +350,12 @@ def create_pr(
branch = f"fix/eval-warning-{warning_hash}" branch = f"fix/eval-warning-{warning_hash}"
warning_text = "\n".join(f"[{w.system}] {w.message}" for w in warnings) warning_text = "\n".join(f"[{w.system}] {w.message}" for w in warnings)
reasoning_lines: list[str] = [] if "**REASONING**" not in llm_response:
capturing = False logger.warning("LLM response missing **REASONING** section")
for line in llm_response.splitlines(): reasoning = ""
if "**REASONING**" in line: else:
capturing = True _, after = llm_response.split("**REASONING**", 1)
elif "**CHANGES**" in line: reasoning = "\n".join(after.split("**CHANGES**", 1)[0].strip().splitlines()[:50])
break
elif capturing:
reasoning_lines.append(line)
reasoning = "\n".join(reasoning_lines[:50])
run_cmd(["git", "config", "user.name", "github-actions[bot]"]) run_cmd(["git", "config", "user.name", "github-actions[bot]"])
run_cmd(["git", "config", "user.email", "github-actions[bot]@users.noreply.github.com"]) run_cmd(["git", "config", "user.email", "github-actions[bot]@users.noreply.github.com"])