Files
pipelines/config.py
2026-04-13 15:43:01 -04:00

90 lines
2.2 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import tomllib
@dataclass()
class LoraConfig:
    """Hyperparameters for the LoRA adapter.

    Instances are built by ``FinetuneConfig.from_toml`` from the
    ``[finetune.lora]`` TOML table, whose keys are passed straight through as
    keyword arguments, so the table keys must match these field names.
    """

    rank: int
    alpha: int
    dropout: float
    targets: list[str]  # presumably names of the modules to adapt — confirm
@dataclass()
class TrainingConfig:
    """Hyperparameters for the training loop.

    Built by ``FinetuneConfig.from_toml`` from the ``[finetune.training]``
    TOML table via keyword expansion: every key in that table must be one of
    the fields below, and every field must be present.
    """

    learning_rate: float
    epochs: int
    batch_size: int
    gradient_accumulation: int
    max_seq_length: int
    warmup_ratio: float
    weight_decay: float
    logging_steps: int
    save_steps: int
@dataclass
class FinetuneConfig:
    """Top-level finetune configuration.

    Mirrors the ``[finetune]`` table of the project TOML file. The nested
    ``lora`` and ``training`` tables are parsed into ``LoraConfig`` and
    ``TrainingConfig`` respectively.
    """

    base_model: str
    lora: LoraConfig
    training: TrainingConfig

    @classmethod
    def from_toml(cls, config_path: Path | str) -> FinetuneConfig:
        """Load finetune config from a TOML file.

        Args:
            config_path: Path to the TOML file (a ``str`` path is also accepted).

        Returns:
            A FinetuneConfig built from the file's ``[finetune]`` table.

        Raises:
            OSError: If the file cannot be opened.
            tomllib.TOMLDecodeError: If the file is not valid TOML.
            KeyError: If ``[finetune]`` or one of ``base_model``, ``lora``,
                ``training`` is missing.
            TypeError: If the ``lora``/``training`` tables have missing or
                unexpected keys for their dataclasses.
        """
        # TOML is UTF-8 by spec and tomllib expects a binary file; reading bytes
        # avoids Path.read_text()'s locale-dependent default encoding.
        with Path(config_path).open("rb") as fh:
            raw = tomllib.load(fh)["finetune"]
        return cls(
            base_model=raw["base_model"],
            lora=LoraConfig(**raw["lora"]),
            training=TrainingConfig(**raw["training"]),
        )
class BenchmarkConfig:
"""Top-level benchmark configuration loaded from TOML."""
models: list[str]
model_dir: str
port: int
gpu_memory_utilization: float
temperature: float
timeout: int
concurrency: int
vllm_startup_timeout: int
@classmethod
def from_toml(cls, config_path: Path) -> BenchmarkConfig:
"""Load benchmark config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["bench"]
return cls(**raw)
def get_config_dir() -> Path:
    """Return the directory that holds the project's config files.

    The location is derived from this module's resolved path: it is the
    ``config`` subdirectory of the directory reached by taking ``.parent``
    three times from this file.

    Returns:
        Absolute path to the ``config`` directory; existence is not checked.
    """
    # NOTE(review): assumes this module sits two directories below the project
    # root (e.g. <root>/src/pipelines/config.py). If it actually lives at
    # <root>/pipelines/config.py, this resolves outside the repo — confirm.
    return Path(__file__).resolve().parent.parent.parent / "config"
def default_config_path() -> Path:
    """Return the default location of the project config file.

    This is ``config.toml`` inside the directory returned by ``get_config_dir()``.
    """
    config_dir = get_config_dir()
    return config_dir.joinpath("config.toml")
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
    """Load the finetune section of the project config.

    Args:
        config_path: TOML file to read; when ``None``, ``default_config_path()``
            is used.

    Returns:
        The FinetuneConfig parsed by ``FinetuneConfig.from_toml``.
    """
    path = default_config_path() if config_path is None else config_path
    return FinetuneConfig.from_toml(path)
def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
    """Load the benchmark section of the project config.

    Args:
        config_path: TOML file to read; when ``None``, ``default_config_path()``
            is used.

    Returns:
        The BenchmarkConfig parsed by ``BenchmarkConfig.from_toml``.
    """
    path = default_config_path() if config_path is None else config_path
    return BenchmarkConfig.from_toml(path)