147 lines
3.8 KiB
Python
147 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import date
|
|
from pathlib import Path
|
|
import tomllib
|
|
|
|
|
|
@dataclass
class LoraConfig:
    """Hyperparameters for a LoRA adapter.

    A plain record type: the generated ``__init__`` accepts the fields in the
    order declared below (``rank``, ``alpha``, ``dropout``, ``targets``).
    """

    # Integer-valued adapter settings.
    rank: int
    alpha: int

    # Dropout value stored as a float (presumably a probability -- confirm
    # against the training code that consumes it).
    dropout: float

    # String identifiers the adapter is applied to (presumably module names --
    # confirm against the training code that consumes it).
    targets: list[str]
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyperparameters for the training loop.

    A plain record type: the generated ``__init__`` accepts the fields in the
    order declared below, so positional construction follows this listing.
    """

    # Optimisation schedule.
    learning_rate: float
    epochs: int

    # Batching and sequence sizing.
    batch_size: int
    gradient_accumulation: int
    max_seq_length: int

    # Float-valued optimiser settings.
    warmup_ratio: float
    weight_decay: float

    # Step intervals (presumably "every N steps" for logging and
    # checkpointing -- confirm against the training loop).
    logging_steps: int
    save_steps: int
|
|
|
|
|
|
@dataclass
|
|
class FinetuneConfig:
|
|
"""Top-level finetune configuration."""
|
|
|
|
base_model: str
|
|
lora: LoraConfig
|
|
training: TrainingConfig
|
|
|
|
@classmethod
|
|
def from_toml(cls, config_path: Path) -> FinetuneConfig:
|
|
"""Load finetune config from a TOML file."""
|
|
raw = tomllib.loads(config_path.read_text())["finetune"]
|
|
return cls(
|
|
base_model=raw["base_model"],
|
|
lora=LoraConfig(**raw["lora"]),
|
|
training=TrainingConfig(**raw["training"]),
|
|
)
|
|
|
|
|
|
class BenchmarkConfig:
|
|
"""Top-level benchmark configuration loaded from TOML."""
|
|
|
|
models: list[str]
|
|
model_dir: str
|
|
port: int
|
|
gpu_memory_utilization: float
|
|
temperature: float
|
|
timeout: int
|
|
concurrency: int
|
|
vllm_startup_timeout: int
|
|
|
|
@classmethod
|
|
def from_toml(cls, config_path: Path) -> BenchmarkConfig:
|
|
"""Load benchmark config from a TOML file."""
|
|
raw = tomllib.loads(config_path.read_text())["bench"]
|
|
return cls(**raw)
|
|
|
|
|
|
@dataclass
|
|
class BertTopicTrainConfig:
|
|
"""BERTopic training configuration loaded from TOML."""
|
|
|
|
sample_rate: float
|
|
min_text_length: int
|
|
n_topics: int
|
|
model_save_path: str
|
|
model_version: str | None = None
|
|
|
|
@classmethod
|
|
def from_toml(cls, config_path: Path) -> BertTopicTrainConfig:
|
|
"""Load BERTopic training config from a TOML file."""
|
|
raw = tomllib.loads(config_path.read_text())["bertopic"]["train"]
|
|
|
|
today = date.today().isoformat()
|
|
if raw.get("model_version") is None:
|
|
raw["model_version"] = (
|
|
f"{today}-{raw['sample_rate']}-{raw['min_text_length']}-{raw['n_topics']}"
|
|
)
|
|
return cls(**raw)
|
|
|
|
|
|
@dataclass
|
|
class BertTopicInferConfig:
|
|
"""BERTopic inference configuration loaded from TOML."""
|
|
|
|
min_text_length: int
|
|
poc_batch_size: int
|
|
model_version: str
|
|
model_save_path: str
|
|
|
|
@classmethod
|
|
def from_toml(cls, config_path: Path) -> BertTopicInferConfig:
|
|
"""Load BERTopic inference config from a TOML file."""
|
|
raw = tomllib.loads(config_path.read_text())["bertopic"]["infer"]
|
|
return cls(**raw)
|
|
|
|
|
|
def get_config_dir() -> Path:
    """Return the ``config`` directory three levels above this module."""
    module_file = Path(__file__).resolve()
    # Climb from this file to its directory, then two directories further up.
    base_dir = module_file.parent.parent.parent
    return base_dir.joinpath("config")
|
|
|
|
|
|
def default_config_path() -> Path:
    """Return the path of ``config.toml`` inside the config directory."""
    config_dir = get_config_dir()
    return config_dir.joinpath("config.toml")
|
|
|
|
|
|
def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig:
    """Load the finetune config, using the default config file when no path is given."""
    chosen_path = default_config_path() if config_path is None else config_path
    return FinetuneConfig.from_toml(chosen_path)
|
|
|
|
|
|
def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig:
    """Load the benchmark config, using the default config file when no path is given."""
    chosen_path = default_config_path() if config_path is None else config_path
    return BenchmarkConfig.from_toml(chosen_path)
|
|
|
|
|
|
def get_bertopic_train_config(
    config_path: Path | None = None,
) -> BertTopicTrainConfig:
    """Load the BERTopic training config, using the default config file when no path is given."""
    chosen_path = default_config_path() if config_path is None else config_path
    return BertTopicTrainConfig.from_toml(chosen_path)
|
|
|
|
|
|
def get_bertopic_infer_config(
    config_path: Path | None = None,
) -> BertTopicInferConfig:
    """Load the BERTopic inference config, using the default config file when no path is given."""
    chosen_path = default_config_path() if config_path is None else config_path
    return BertTopicInferConfig.from_toml(chosen_path)
|