improved eval_warnings/main.py

This commit is contained in:
2026-02-18 15:58:41 -05:00
parent af828fc9c4
commit f038f248a1

View File

@@ -3,24 +3,24 @@
from __future__ import annotations
import hashlib
import io
import logging
import re
import subprocess
import zipfile
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Annotated
from zipfile import ZipFile
import httpx
import typer
from httpx import HTTPError, post
from python.common import configure_logger
logger = logging.getLogger(__name__)
@dataclass
@dataclass(frozen=True)
class EvalWarning:
"""A single Nix evaluation warning."""
@@ -37,13 +37,6 @@ class FileChange:
fixed: str
# Matches a Nix warning marker ("warning:" or "trace: warning:"). The leading
# timestamp-like group is optional, so when used with search() the marker may
# appear anywhere in the line.
WARNING_PATTERN = re.compile(r"(?:^[\d\-T:.Z]+ )?(warning:|trace: warning:)")
# Matches a line-leading run of digits and "-", "T", ":", ".", "Z" characters
# followed by a single space (presumably an ISO-8601 log timestamp; confirm
# against real build logs).
TIMESTAMP_PREFIX = re.compile(r"^[\d\-T:.Z]+ ")
# Matches a .nix file below a "/nix/store/<entry>-source/" directory; group 1 is
# the path relative to that directory.
NIX_STORE_PATH = re.compile(r"/nix/store/[^/]+-source/([^:]+\.nix)")
# Matches a .nix path rooted at systems/, common/, users/ or overlays/. The
# lookbehind rejects candidates that are preceded by "/" or a word character;
# group 1 is the top-level directory name.
REPO_RELATIVE_PATH = re.compile(r"(?<![/\w])(systems|common|users|overlays)/[^:\s]+\.nix")
# Matches a line beginning with "FILE:"; group 1 is the rest of the line after
# any whitespace that follows the colon.
CHANGE_FILE_PATTERN = re.compile(r"^FILE:\s*(.+)$")
def run_cmd(cmd: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]:
"""Run a subprocess command and return the result.
@@ -81,41 +74,43 @@ def download_logs(run_id: str, repo: str) -> dict[str, str]:
raise RuntimeError(msg)
logs: dict[str, str] = {}
with zipfile.ZipFile(io.BytesIO(result.stdout)) as zf:
for name in zf.namelist():
with ZipFile(BytesIO(result.stdout)) as zip_file:
for name in zip_file.namelist():
if name.startswith("build-") and name.endswith(".txt"):
logs[name] = zf.read(name).decode(errors="replace")
logs[name] = zip_file.read(name).decode(errors="replace")
return logs
def parse_warnings(logs: dict[str, str]) -> set[EvalWarning]:
    """Parse Nix evaluation warnings from build log contents.

    Every line that contains a warning marker is stripped of its leading
    timestamp-like token before anything else happens, so both the ignore
    rule and the deduplication operate on the bare message text.

    Args:
        logs: Dict mapping zip entry names (e.g. "build-bob/2_Build.txt") to their text.

    Returns:
        Deduplicated set of warnings. Relies on EvalWarning being hashable.
    """
    # Compiled once per call, outside the per-line loop.
    warning_pattern = re.compile(r"(?:^[\d\-T:.Z]+ )?(warning:|trace: warning:)")
    timestamp_prefix = re.compile(r"^[\d\-T:.Z]+ ")
    ignored_prefix = "warning: ignoring untrusted flake configuration setting"

    warnings: set[EvalWarning] = set()
    for name, content in sorted(logs.items()):
        # "build-bob/2_Build.txt" -> "bob"
        system = name.split("/")[0].removeprefix("build-")
        for line in content.splitlines():
            if not warning_pattern.search(line):
                continue
            message = timestamp_prefix.sub("", line).strip()
            # Compare against the timestamp-free message: checking the raw line
            # would miss this warning whenever the log line carries a timestamp.
            if message.startswith(ignored_prefix):
                logger.info("Ignoring untrusted flake configuration setting warning.")
                continue
            logger.debug("Found warning: %s", line)
            warnings.add(EvalWarning(system=system, message=message))

    logger.info("Found %d unique warnings", len(warnings))
    return warnings
def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]:
def extract_referenced_files(warnings: set[EvalWarning]) -> dict[str, str]:
"""Extract file paths referenced in warnings and read their contents.
Args:
@@ -127,9 +122,12 @@ def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]:
paths: set[str] = set()
warning_text = "\n".join(w.message for w in warnings)
for match in NIX_STORE_PATH.finditer(warning_text):
nix_store_path = re.compile(r"/nix/store/[^/]+-source/([^:]+\.nix)")
for match in nix_store_path.finditer(warning_text):
paths.add(match.group(1))
for match in REPO_RELATIVE_PATH.finditer(warning_text):
repo_relative_path = re.compile(r"(?<![/\w])(systems|common|users|overlays)/[^:\s]+\.nix")
for match in repo_relative_path.finditer(warning_text):
paths.add(match.group(0))
files: dict[str, str] = {}
@@ -145,7 +143,7 @@ def extract_referenced_files(warnings: list[EvalWarning]) -> dict[str, str]:
return files
def compute_warning_hash(warnings: list[EvalWarning]) -> str:
def compute_warning_hash(warnings: set[EvalWarning]) -> str:
"""Compute a short hash of the warning set for deduplication.
Args:
@@ -198,7 +196,7 @@ def check_duplicate_pr(warning_hash: str) -> bool:
def query_ollama(
warnings: list[EvalWarning],
warnings: set[EvalWarning],
files: dict[str, str],
ollama_url: str,
) -> str | None:
@@ -243,7 +241,7 @@ Analyze the following Nix evaluation warnings and suggest fixes.
say so in REASONING and do not suggest changes"""
try:
response = httpx.post(
response = post(
f"{ollama_url}/api/generate",
json={
"model": "qwen3-coder:30b",
@@ -254,7 +252,7 @@ say so in REASONING and do not suggest changes"""
timeout=300,
)
response.raise_for_status()
except httpx.HTTPError:
except HTTPError:
logger.exception("Ollama request failed")
return None
@@ -280,36 +278,27 @@ def parse_changes(response: str) -> list[FileChange]:
"""
changes: list[FileChange] = []
current_file = ""
in_original = False
in_fixed = False
section: str | None = None
original_lines: list[str] = []
fixed_lines: list[str] = []
for line in response.splitlines():
file_match = CHANGE_FILE_PATTERN.match(line)
if file_match:
current_file = file_match.group(1).strip()
elif line.strip() == "<<<<<<< ORIGINAL":
in_original = True
in_fixed = False
stripped = line.strip()
if stripped.startswith("FILE:"):
current_file = stripped.removeprefix("FILE:").strip()
elif stripped == "<<<<<<< ORIGINAL":
section = "original"
original_lines = []
elif line.strip() == "=======" and in_original:
in_original = False
in_fixed = True
elif stripped == "=======" and section == "original":
section = "fixed"
fixed_lines = []
elif line.strip() == ">>>>>>> FIXED" and in_fixed:
in_fixed = False
elif stripped == ">>>>>>> FIXED" and section == "fixed":
section = None
if current_file:
changes.append(
FileChange(
file_path=current_file,
original="\n".join(original_lines),
fixed="\n".join(fixed_lines),
)
)
elif in_original:
changes.append(FileChange(current_file, "\n".join(original_lines), "\n".join(fixed_lines)))
elif section == "original":
original_lines.append(line)
elif in_fixed:
elif section == "fixed":
fixed_lines.append(line)
logger.info("Parsed %d file changes", len(changes))
@@ -346,7 +335,7 @@ def apply_changes(changes: list[FileChange]) -> int:
def create_pr(
warning_hash: str,
warnings: list[EvalWarning],
warnings: set[EvalWarning],
llm_response: str,
run_url: str,
) -> None:
@@ -361,16 +350,12 @@ def create_pr(
branch = f"fix/eval-warning-{warning_hash}"
warning_text = "\n".join(f"[{w.system}] {w.message}" for w in warnings)
reasoning_lines: list[str] = []
capturing = False
for line in llm_response.splitlines():
if "**REASONING**" in line:
capturing = True
elif "**CHANGES**" in line:
break
elif capturing:
reasoning_lines.append(line)
reasoning = "\n".join(reasoning_lines[:50])
if "**REASONING**" not in llm_response:
logger.warning("LLM response missing **REASONING** section")
reasoning = ""
else:
_, after = llm_response.split("**REASONING**", 1)
reasoning = "\n".join(after.split("**CHANGES**", 1)[0].strip().splitlines()[:50])
run_cmd(["git", "config", "user.name", "github-actions[bot]"])
run_cmd(["git", "config", "user.email", "github-actions[bot]@users.noreply.github.com"])