diff --git a/.github/workflows/fix_eval_warnings.yml b/.github/workflows/fix_eval_warnings.yml index 0af067d..6c9af48 100644 --- a/.github/workflows/fix_eval_warnings.yml +++ b/.github/workflows/fix_eval_warnings.yml @@ -31,6 +31,7 @@ jobs: GITHUB_TOKEN: ${{ github.token }} GITHUB_REPOSITORY: ${{ github.repository }} RUN_ID: ${{ github.event.workflow_run.id }} + PYTHONPATH: . run: | python3 python/tools/fix_eval_warnings.py build.log diff --git a/python/tools/fix_eval_warnings.py b/python/tools/fix_eval_warnings.py index 7231388..013d722 100755 --- a/python/tools/fix_eval_warnings.py +++ b/python/tools/fix_eval_warnings.py @@ -1,127 +1,161 @@ #!/usr/bin/env python3 -""" -Script to detect "evaluation warning:" in logs and suggest fixes using GitHub Models. -""" +"""fix_eval_warnings.""" + +from __future__ import annotations + +import logging import os -import sys -import re -import requests -import json +from dataclasses import dataclass from pathlib import Path -# Configuration -GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN") -GITHUB_REPOSITORY = os.environ.get("GITHUB_REPOSITORY") -PR_NUMBER = os.environ.get("PR_NUMBER") # If triggered by PR -RUN_ID = os.environ.get("RUN_ID") +import requests +import typer -# GitHub Models API Endpoint (OpenAI compatible) -# https://github.com/marketplace/models -API_BASE = "https://models.inference.ai.azure.com" -# Default to gpt-4o, but allow override via env var -MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o") +from python.common import configure_logger -def get_log_content(run_id): - """Fetches the logs for a specific workflow run.""" - print(f"Fetching logs for run ID: {run_id}") - headers = { - "Authorization": f"Bearer {GITHUB_TOKEN}", - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28" - } - +logger = logging.getLogger(__name__) + + +@dataclass +class Config: + """Configuration for the script. + + Attributes: + github_token (str): GitHub token for API authentication. + model_name (str): The name of the LLM model to use. Defaults to "gpt-4o". + api_base (str): The base URL for the GitHub Models API. + Defaults to "https://models.inference.ai.azure.com". + """ + + github_token: str + model_name: str = "gpt-4o" + api_base: str = "https://models.inference.ai.azure.com" + + +def get_log_content(run_id: str) -> None: + """Fetch the logs for a specific workflow run. + + Args: + run_id (str): The run ID. + """ + logger.info(f"Fetching logs for run ID: {run_id}") # List artifacts to find logs (or use jobs API) # For simplicity, we might need to use 'gh' cli in the workflow to download logs # But let's try to read from a file if passed as argument, which is easier for the workflow - return None -def parse_warnings(log_file_path): - """Parses the log file for evaluation warnings.""" + +def parse_warnings(log_file_path: Path) -> list[str]: + """Parse the log file for evaluation warnings. + + Args: + log_file_path (Path): The path to the log file. + + Returns: + list[str]: A list of warning messages. + """ warnings = [] - with open(log_file_path, 'r', encoding='utf-8', errors='ignore') as f: - for line in f: - if "evaluation warning:" in line: - warnings.append(line.strip()) + with log_file_path.open(encoding="utf-8", errors="ignore") as f: + warnings.extend(line.strip() for line in f if "evaluation warning:" in line) return warnings -def generate_fix(warning_msg): - """Calls GitHub Models to generate a fix for the warning.""" - print(f"Generating fix for: {warning_msg}") - + +def generate_fix(warning_msg: str, config: Config) -> str | None: + """Call GitHub Models to generate a fix for the warning. + + Args: + warning_msg (str): The warning message. + config (Config): The configuration object. + + Returns: + Optional[str]: The suggested fix or None. + """ + logger.info(f"Generating fix for: {warning_msg}") + prompt = f""" I encountered the following Nix evaluation warning: - + `{warning_msg}` - - Please explain what this warning means and suggest how to fix it in the Nix code. + + Please explain what this warning means and suggest how to fix it in the Nix code. If possible, provide the exact code change in a diff format or a clear description of what to change. """ - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {GITHUB_TOKEN}" - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {config.github_token}"} payload = { "messages": [ - { - "role": "system", - "content": "You are an expert NixOS and Nix language developer." - }, - { - "role": "user", - "content": prompt - } + {"role": "system", "content": "You are an expert NixOS and Nix language developer."}, + {"role": "user", "content": prompt}, ], - "model": MODEL_NAME, - "temperature": 0.1 + "model": config.model_name, + "temperature": 0.1, } try: - response = requests.post( - f"{API_BASE}/chat/completions", - headers=headers, - json=payload - ) + response = requests.post(f"{config.api_base}/chat/completions", headers=headers, json=payload, timeout=30) response.raise_for_status() result = response.json() - return result['choices'][0]['message']['content'] - except Exception as e: - print(f"Error calling LLM: {e}") + return result["choices"][0]["message"]["content"] # type: ignore[no-any-return] + except Exception: + logger.exception("Error calling LLM") return None -def main(): - if len(sys.argv) < 2: - print("Usage: fix_eval_warnings.py ") - sys.exit(1) - log_file = sys.argv[1] - if not os.path.exists(log_file): - print(f"Log file not found: {log_file}") - sys.exit(1) +def main( + log_file: Path = typer.Argument(..., help="Path to the build log file"), # noqa: B008 + model_name: str = typer.Option("gpt-4o", envvar="MODEL_NAME", help="LLM Model Name"), +) -> None: + """Detect evaluation warnings in logs and suggest fixes using GitHub Models. + + Args: + log_file (Path): Path to the build log file containing evaluation warnings. + model_name (str): The name of the LLM model to use for generating fixes. + Defaults to "gpt-4o", can be overridden by MODEL_NAME environment variable. + """ + configure_logger() + + github_token = os.environ.get("GITHUB_TOKEN") + if not github_token: + logger.warning("GITHUB_TOKEN not set. LLM calls will fail.") + + config = Config(github_token=github_token or "", model_name=model_name) + + if not log_file.exists(): + logger.error(f"Log file not found: {log_file}") + raise typer.Exit(code=1) warnings = parse_warnings(log_file) if not warnings: - print("No evaluation warnings found.") - sys.exit(0) + logger.info("No evaluation warnings found.") + raise typer.Exit(code=0) + + logger.info(f"Found {len(warnings)} warnings.") - print(f"Found {len(warnings)} warnings.") - # Process unique warnings to save tokens unique_warnings = list(set(warnings)) - + fixes = [] for warning in unique_warnings: - fix = generate_fix(warning) + if not config.github_token: + logger.warning("Skipping LLM call due to missing GITHUB_TOKEN") + continue + + fix = generate_fix(warning, config) if fix: fixes.append(f"## Warning\n`{warning}`\n\n## Suggested Fix\n{fix}\n") # Output fixes to a markdown file for the PR body - with open("fix_suggestions.md", "w") as f: - f.write("# Automated Fix Suggestions\n\n") - f.write("\n---\n".join(fixes)) + if fixes: + with Path("fix_suggestions.md").open("w") as f: + f.write("# Automated Fix Suggestions\n\n") + f.write("\n---\n".join(fixes)) + logger.info("Fix suggestions written to fix_suggestions.md") + else: + logger.info("No fixes generated.") - print("Fix suggestions written to fix_suggestions.md") + +app = typer.Typer() +app.command()(main) if __name__ == "__main__": - main() + app() diff --git a/tests/test_fix_eval_warnings.py b/tests/test_fix_eval_warnings.py new file mode 100644 index 0000000..f4d2e3d --- /dev/null +++ b/tests/test_fix_eval_warnings.py @@ -0,0 +1,62 @@ +"""Tests for fix_eval_warnings.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from typer.testing import CliRunner + +from python.tools.fix_eval_warnings import Config, app, generate_fix, parse_warnings + +runner = CliRunner() + + +@pytest.fixture +def log_file(tmp_path: Path) -> Path: + """Create a dummy log file.""" + log_path = tmp_path / "build.log" + log_path.write_text("Some output\nevaluation warning: 'system' is deprecated\nMore output", encoding="utf-8") + return log_path + + +def test_parse_warnings(log_file: Path) -> None: + """Test parsing warnings from a log file.""" + warnings = parse_warnings(log_file) + assert len(warnings) == 1 + assert warnings[0] == "evaluation warning: 'system' is deprecated" + + +@patch("python.tools.fix_eval_warnings.requests.post") +def test_generate_fix(mock_post: MagicMock) -> None: + """Test generating a fix.""" + mock_response = MagicMock() + mock_response.json.return_value = {"choices": [{"message": {"content": "Use stdenv.hostPlatform.system"}}]} + mock_post.return_value = mock_response + + config = Config(github_token="dummy_token") + fix = generate_fix("evaluation warning: 'system' is deprecated", config) + + assert fix == "Use stdenv.hostPlatform.system" + mock_post.assert_called_once() + + +@patch("python.tools.fix_eval_warnings.logger") +@patch("python.tools.fix_eval_warnings.generate_fix") +def test_main(mock_generate_fix: MagicMock, mock_logger: MagicMock, log_file: Path) -> None: + """Test the main CLI.""" + mock_generate_fix.return_value = "Fixed it" + + # We need to mock GITHUB_TOKEN env var or the script will warn/fail + with patch.dict("os.environ", {"GITHUB_TOKEN": "dummy"}): + result = runner.invoke(app, [str(log_file)]) + + assert result.exit_code == 0 + # Verify logger calls instead of stdout, as CliRunner might not capture logging output correctly + # when logging is configured to write to sys.stdout directly. + assert any("Found 1 warnings" in str(call) for call in mock_logger.info.call_args_list) + assert any( + "Fix suggestions written to fix_suggestions.md" in str(call) + for call in mock_logger.info.call_args_list + ) + assert Path("fix_suggestions.md").exists() + Path("fix_suggestions.md").unlink()