converted default_config_path and get_config_dir

This commit is contained in:
2026-04-21 22:51:10 -04:00
parent 51d6240690
commit 87a7f5312f
4 changed files with 10 additions and 15 deletions

View File

@@ -69,8 +69,9 @@ class BenchmarkConfig:
def get_config_dir() -> Path: def get_config_dir() -> Path:
"""Get the path to the config file.""" """Get the path to the config directory."""
return Path(__file__).resolve().parent.parent.parent / "config" return Path(__file__).resolve().parents[2] / "config"
def default_config_path() -> Path: def default_config_path() -> Path:
"""Get the path to the config file.""" """Get the path to the config file."""

View File

@@ -23,14 +23,10 @@ import httpx
import typer import typer
from tiktoken import Encoding, get_encoding from tiktoken import Encoding, get_encoding
from pipelines.config import get_config_dir
from pipelines.tools.bill_token_compression import compress_bill_text from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = ( _PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]

View File

@@ -24,14 +24,10 @@ from typing import Annotated
import httpx import httpx
import typer import typer
from pipelines.config import get_config_dir
from pipelines.tools.bill_token_compression import compress_bill_text from pipelines.tools.bill_token_compression import compress_bill_text
_PROMPTS_PATH = ( _PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
Path(__file__).resolve().parents[2]
/ "config"
/ "prompts"
/ "summarization_prompts.toml"
)
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]

View File

@@ -25,6 +25,8 @@ from datasets import Dataset
from transformers import TrainingArguments from transformers import TrainingArguments
from trl import SFTTrainer from trl import SFTTrainer
from pipelines.config import default_config_path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -123,7 +125,7 @@ def main(
config_path: Annotated[ config_path: Annotated[
Path, Path,
typer.Option("--config", help="TOML config file"), typer.Option("--config", help="TOML config file"),
] = Path(__file__).parent / "config.toml", ] = default_config_path(),
save_gguf: Annotated[ save_gguf: Annotated[
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF") bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
] = False, ] = False,