diff --git a/pipelines/config.py b/pipelines/config.py index b70df9b..24bbf15 100644 --- a/pipelines/config.py +++ b/pipelines/config.py @@ -69,8 +69,9 @@ class BenchmarkConfig: def get_config_dir() -> Path: - """Get the path to the config file.""" - return Path(__file__).resolve().parent.parent.parent / "config" + """Get the path to the config directory.""" + return Path(__file__).resolve().parents[2] / "config" + def default_config_path() -> Path: """Get the path to the config file.""" diff --git a/pipelines/tools/batch_bill_summarizer.py b/pipelines/tools/batch_bill_summarizer.py index 33c5e6d..ee1a1c5 100644 --- a/pipelines/tools/batch_bill_summarizer.py +++ b/pipelines/tools/batch_bill_summarizer.py @@ -23,14 +23,10 @@ import httpx import typer from tiktoken import Encoding, get_encoding +from pipelines.config import get_config_dir from pipelines.tools.bill_token_compression import compress_bill_text -_PROMPTS_PATH = ( - Path(__file__).resolve().parents[2] - / "config" - / "prompts" - / "summarization_prompts.toml" -) +_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml" _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] diff --git a/pipelines/tools/compresion_test.py b/pipelines/tools/compresion_test.py index c84c298..3814a20 100644 --- a/pipelines/tools/compresion_test.py +++ b/pipelines/tools/compresion_test.py @@ -24,14 +24,10 @@ from typing import Annotated import httpx import typer +from pipelines.config import get_config_dir from pipelines.tools.bill_token_compression import compress_bill_text -_PROMPTS_PATH = ( - Path(__file__).resolve().parents[2] - / "config" - / "prompts" - / "summarization_prompts.toml" -) +_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml" _PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"] SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"] SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"] diff --git a/pipelines/tools/finetune.py b/pipelines/tools/finetune.py index 3fccf57..33df584 100644 --- a/pipelines/tools/finetune.py +++ b/pipelines/tools/finetune.py @@ -25,6 +25,8 @@ from datasets import Dataset from transformers import TrainingArguments from trl import SFTTrainer +from pipelines.config import default_config_path + logger = logging.getLogger(__name__) @@ -123,7 +125,7 @@ def main( config_path: Annotated[ Path, typer.Option("--config", help="TOML config file"), - ] = Path(__file__).parent / "config.toml", + ] = default_config_path(), save_gguf: Annotated[ bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF") ] = False,