from __future__ import annotations from dataclasses import dataclass from os import getenv from datetime import date from pathlib import Path import tomllib @dataclass class LoraConfig: """LoRA adapter hyperparameters.""" rank: int alpha: int dropout: float targets: list[str] @dataclass class TrainingConfig: """Training loop hyperparameters.""" learning_rate: float epochs: int batch_size: int gradient_accumulation: int max_seq_length: int warmup_ratio: float weight_decay: float logging_steps: int save_steps: int @dataclass class FinetuneConfig: """Top-level finetune configuration.""" base_model: str lora: LoraConfig training: TrainingConfig @classmethod def from_toml(cls, config_path: Path) -> FinetuneConfig: """Load finetune config from a TOML file.""" raw = tomllib.loads(config_path.read_text())["finetune"] return cls( base_model=raw["base_model"], lora=LoraConfig(**raw["lora"]), training=TrainingConfig(**raw["training"]), ) class BenchmarkConfig: """Top-level benchmark configuration loaded from TOML.""" models: list[str] model_dir: str port: int gpu_memory_utilization: float temperature: float timeout: int concurrency: int vllm_startup_timeout: int @classmethod def from_toml(cls, config_path: Path) -> BenchmarkConfig: """Load benchmark config from a TOML file.""" raw = tomllib.loads(config_path.read_text())["bench"] return cls(**raw) @dataclass class OpenAIConfig: """OpenAI API configuration.""" api_key: str openai_project_id: str openai_chat_completions_url: str model: str timeout_seconds: int @classmethod def from_toml(cls, config_path: Path) -> OpenAIConfig: """Load OpenAI config from a TOML file.""" raw = tomllib.loads(config_path.read_text()).get("openai", {}) api_key = getenv("CLOSEDAI_TOKEN") if not api_key: message = "CLOSEDAI_TOKEN is required" raise KeyError(message) return cls( api_key=api_key, openai_project_id=raw.get( "openai_project_id", "proj_fQBPEXFgnS87Fk6wZwploFwE" ), openai_chat_completions_url=raw.get( "openai_chat_completions_url", "https://api.openai.com/v1/chat/completions", ), model=raw.get("model", "gpt-5.4-mini"), timeout_seconds=raw.get("timeout_seconds", 60), ) class BertTopicTrainConfig: """BERTopic training configuration loaded from TOML.""" sample_rate: float min_text_length: int n_topics: int model_save_path: str model_version: str | None = None @classmethod def from_toml(cls, config_path: Path) -> BertTopicTrainConfig: """Load BERTopic training config from a TOML file.""" raw = tomllib.loads(config_path.read_text())["bertopic"]["train"] today = date.today().isoformat() if raw.get("model_version") is None: raw["model_version"] = ( f"{today}-{raw['sample_rate']}-{raw['min_text_length']}-{raw['n_topics']}" ) return cls(**raw) @dataclass class BertTopicInferConfig: """BERTopic inference configuration loaded from TOML.""" min_text_length: int poc_batch_size: int model_version: str model_save_path: str @classmethod def from_toml(cls, config_path: Path) -> BertTopicInferConfig: """Load BERTopic inference config from a TOML file.""" raw = tomllib.loads(config_path.read_text())["bertopic"]["infer"] return cls(**raw) def get_config_dir() -> Path: """Get the path to the config directory.""" return Path(__file__).resolve().parents[2] / "config" def default_config_path() -> Path: """Get the path to the config file.""" return get_config_dir() / "config.toml" def get_openai_config(config_path: Path | None = None) -> OpenAIConfig: if config_path is None: config_path = default_config_path() return OpenAIConfig.from_toml(config_path) def get_finetune_config(config_path: Path | None = None) -> FinetuneConfig: if config_path is None: config_path = default_config_path() return FinetuneConfig.from_toml(config_path) def get_benchmark_config(config_path: Path | None = None) -> BenchmarkConfig: if config_path is None: config_path = default_config_path() return BenchmarkConfig.from_toml(config_path) def get_bertopic_train_config( config_path: Path | None = None, ) -> BertTopicTrainConfig: if config_path is None: config_path = default_config_path() return BertTopicTrainConfig.from_toml(config_path) def get_bertopic_infer_config( config_path: Path | None = None, ) -> BertTopicInferConfig: if config_path is None: config_path = default_config_path() return BertTopicInferConfig.from_toml(config_path)