Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 36718bbce0 | |||
| 2facb82bd4 | |||
| 8d5a6e202b | |||
| f32c895561 |
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from os import getenv
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
@@ -68,15 +69,54 @@ class BenchmarkConfig:
|
|||||||
return cls(**raw)
|
return cls(**raw)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OpenAIConfig:
|
||||||
|
"""OpenAI API configuration."""
|
||||||
|
|
||||||
|
api_key: str
|
||||||
|
openai_project_id: str
|
||||||
|
openai_chat_completions_url: str
|
||||||
|
model: str
|
||||||
|
timeout_seconds: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> OpenAIConfig:
|
||||||
|
"""Load OpenAI config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text()).get("openai", {})
|
||||||
|
api_key = getenv("CLOSEDAI_TOKEN")
|
||||||
|
if not api_key:
|
||||||
|
message = "CLOSEDAI_TOKEN is required"
|
||||||
|
raise KeyError(message)
|
||||||
|
return cls(
|
||||||
|
api_key=api_key,
|
||||||
|
openai_project_id=raw.get(
|
||||||
|
"openai_project_id", "proj_fQBPEXFgnS87Fk6wZwploFwE"
|
||||||
|
),
|
||||||
|
openai_chat_completions_url=raw.get(
|
||||||
|
"openai_chat_completions_url",
|
||||||
|
"https://api.openai.com/v1/chat/completions",
|
||||||
|
),
|
||||||
|
model=raw.get("model", "gpt-5.4-mini"),
|
||||||
|
timeout_seconds=raw.get("timeout_seconds", 60),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_config_dir() -> Path:
|
def get_config_dir() -> Path:
|
||||||
"""Get the path to the config file."""
|
"""Get the path to the config directory."""
|
||||||
return Path(__file__).resolve().parent.parent.parent / "config"
|
return Path(__file__).resolve().parents[2] / "config"
|
||||||
|
|
||||||
|
|
||||||
def default_config_path() -> Path:
|
def default_config_path() -> Path:
|
||||||
"""Get the path to the config file."""
|
"""Get the path to the config file."""
|
||||||
return get_config_dir() / "config.toml"
|
return get_config_dir() / "config.toml"
|
||||||
|
|
||||||
|
|
||||||
|
def get_openai_config(config_path: Path | None = None) -> OpenAIConfig:
|
||||||
|
if config_path is None:
|
||||||
|
config_path = default_config_path()
|
||||||
|
return OpenAIConfig.from_toml(config_path)
|
||||||
|
|
||||||
|
|
||||||
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
|
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = default_config_path()
|
config_path = default_config_path()
|
||||||
@@ -23,10 +23,14 @@ import httpx
|
|||||||
import typer
|
import typer
|
||||||
from tiktoken import Encoding, get_encoding
|
from tiktoken import Encoding, get_encoding
|
||||||
|
|
||||||
from pipelines.config import get_config_dir
|
|
||||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
_PROMPTS_PATH = (
|
||||||
|
Path(__file__).resolve().parents[2]
|
||||||
|
/ "config"
|
||||||
|
/ "prompts"
|
||||||
|
/ "summarization_prompts.toml"
|
||||||
|
)
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|||||||
@@ -24,10 +24,14 @@ from typing import Annotated
|
|||||||
import httpx
|
import httpx
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from pipelines.config import get_config_dir
|
|
||||||
from pipelines.tools.bill_token_compression import compress_bill_text
|
from pipelines.tools.bill_token_compression import compress_bill_text
|
||||||
|
|
||||||
_PROMPTS_PATH = get_config_dir() / "prompts" / "summarization_prompts.toml"
|
_PROMPTS_PATH = (
|
||||||
|
Path(__file__).resolve().parents[2]
|
||||||
|
/ "config"
|
||||||
|
/ "prompts"
|
||||||
|
/ "summarization_prompts.toml"
|
||||||
|
)
|
||||||
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
_PROMPTS = tomllib.loads(_PROMPTS_PATH.read_text())["summarization"]
|
||||||
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
SUMMARIZATION_SYSTEM_PROMPT: str = _PROMPTS["system_prompt"]
|
||||||
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
SUMMARIZATION_USER_TEMPLATE: str = _PROMPTS["user_template"]
|
||||||
|
|||||||
@@ -25,8 +25,6 @@ from datasets import Dataset
|
|||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import SFTTrainer
|
from trl import SFTTrainer
|
||||||
|
|
||||||
from pipelines.config import default_config_path
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -125,7 +123,7 @@ def main(
|
|||||||
config_path: Annotated[
|
config_path: Annotated[
|
||||||
Path,
|
Path,
|
||||||
typer.Option("--config", help="TOML config file"),
|
typer.Option("--config", help="TOML config file"),
|
||||||
] = default_config_path(),
|
] = Path(__file__).parent / "config.toml",
|
||||||
save_gguf: Annotated[
|
save_gguf: Annotated[
|
||||||
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")
|
||||||
] = False,
|
] = False,
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ from typing import Annotated
|
|||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from pipelines.containers.lib import check_gpu_free
|
from pipelines.tools.containers.lib import check_gpu_free
|
||||||
from pipelines.containers.vllm import start_vllm, stop_vllm
|
from pipelines.tools.containers.vllm import start_vllm, stop_vllm
|
||||||
from pipelines.tools.downloader import is_model_present
|
from pipelines.tools.downloader import is_model_present
|
||||||
from pipelines.tools.models import BenchmarkConfig
|
from pipelines.tools.models import BenchmarkConfig
|
||||||
from pipelines.tools.vllm_client import VLLMClient
|
from pipelines.tools.vllm_client import VLLMClient
|
||||||
|
|||||||
Reference in New Issue
Block a user