converted default_config_path and get_config_dir
This commit is contained in:
@@ -69,8 +69,9 @@ class BenchmarkConfig:
|
||||
|
||||
|
||||
def get_config_dir() -> Path:
|
||||
"""Get the path to the config file."""
|
||||
return Path(__file__).resolve().parent.parent.parent / "config"
|
||||
"""Get the path to the config directory."""
|
||||
return Path(__file__).resolve().parents[2] / "config"
|
||||
|
||||
|
||||
def default_config_path() -> Path:
|
||||
"""Get the path to the config file."""
|
||||
|
||||
@@ -23,14 +23,10 @@ import httpx
|
||||
import typer
|
||||
from tiktoken import Encoding, get_encoding
|
||||
|
||||
from pipelines.config import get_config_dir
|
||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||
|
||||
_PROMPTS_PATH = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "config"
|
||||
/ "prompts"
|
||||
/ "summarization_prompts.toml"
|
||||
)
|
||||
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||
|
||||
@@ -24,14 +24,10 @@ from typing import Annotated
|
||||
import httpx
|
||||
import typer
|
||||
|
||||
from pipelines.config import get_config_dir
|
||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||
|
||||
_PROMPTS_PATH = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "config"
|
||||
/ "prompts"
|
||||
/ "summarization_prompts.toml"
|
||||
)
|
||||
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||
|
||||
@@ -25,6 +25,8 @@ from datasets import Dataset
|
||||
from transformers import TrainingArguments
|
||||
from trl import SFTTrainer
|
||||
|
||||
from pipelines.config import default_config_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -123,7 +125,7 @@ def main(
|
||||
config_path: Annotated[
|
||||
Path,
|
||||
typer.Option("--config", help="TOML config file"),
|
||||
] = Path(__file__).parent / "config.toml",
|
||||
] = default_config_path(),
|
||||
save_gguf: Annotated[
|
||||
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
||||
] = False,
|
||||
|
||||
Reference in New Issue
Block a user