diff --git a/python/prompt_bench/finetune.py b/python/prompt_bench/finetune.py
index 30ae3d3..3bcea4a 100644
--- a/python/prompt_bench/finetune.py
+++ b/python/prompt_bench/finetune.py
@@ -14,9 +14,11 @@ from __future__ import annotations
 
 import json
 import logging
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Annotated
 
+import tomllib
 import typer
 from unsloth import FastLanguageModel
 from datasets import Dataset
@@ -25,32 +27,49 @@ from trl import SFTTrainer
 
 logger = logging.getLogger(__name__)
 
-BASE_MODEL = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"
-# LoRA hyperparameters
-LORA_RANK = 32
-LORA_ALPHA = 32
-LORA_DROPOUT = 0.0
-LORA_TARGETS = [
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-]
+@dataclass
+class LoraConfig:
+    """LoRA adapter hyperparameters."""
 
-# Training hyperparameters tuned for ~2k examples on a 3090
-LEARNING_RATE = 2e-4
-EPOCHS = 3
-BATCH_SIZE = 2
-GRADIENT_ACCUMULATION = 8  # effective batch = 16
-MAX_SEQ_LENGTH = 4096
-WARMUP_RATIO = 0.05
-WEIGHT_DECAY = 0.01
-LOGGING_STEPS = 10
-SAVE_STEPS = 100
+    rank: int
+    alpha: int
+    dropout: float
+    targets: list[str]
+
+
+@dataclass
+class TrainingConfig:
+    """Training loop hyperparameters."""
+
+    learning_rate: float
+    epochs: int
+    batch_size: int
+    gradient_accumulation: int
+    max_seq_length: int
+    warmup_ratio: float
+    weight_decay: float
+    logging_steps: int
+    save_steps: int
+
+
+@dataclass
+class FinetuneConfig:
+    """Top-level finetune configuration."""
+
+    base_model: str
+    lora: LoraConfig
+    training: TrainingConfig
+
+    @classmethod
+    def from_toml(cls, config_path: Path) -> FinetuneConfig:
+        """Load finetune config from a TOML file."""
+        raw = tomllib.loads(config_path.read_text())["finetune"]
+        return cls(
+            base_model=raw["base_model"],
+            lora=LoraConfig(**raw["lora"]),
+            training=TrainingConfig(**raw["training"]),
+        )
 
 
 def _messages_to_chatml(messages: list[dict]) -> str:
@@ -95,12 +114,10 @@ def main(
     output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
         "output/qwen-bill-summarizer",
     ),
-    base_model: Annotated[str, typer.Option("--base-model", help="Unsloth model ID")] = BASE_MODEL,
-    epochs: Annotated[int, typer.Option("--epochs", help="Training epochs")] = EPOCHS,
-    batch_size: Annotated[int, typer.Option("--batch-size", help="Per-device batch size")] = BATCH_SIZE,
-    learning_rate: Annotated[float, typer.Option("--lr", help="Learning rate")] = LEARNING_RATE,
-    lora_rank: Annotated[int, typer.Option("--lora-rank", help="LoRA rank")] = LORA_RANK,
-    max_seq_length: Annotated[int, typer.Option("--max-seq-length", help="Max sequence length")] = MAX_SEQ_LENGTH,
+    config_path: Annotated[
+        Path,
+        typer.Option("--config", help="TOML config file"),
+    ] = Path(__file__).parent / "config.toml",
     save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
 ) -> None:
     """Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
@@ -110,21 +127,23 @@ def main(
         message = f"Dataset not found: {dataset_path}"
         raise typer.BadParameter(message)
 
-    logger.info("Loading base model: %s", base_model)
+    config = FinetuneConfig.from_toml(config_path)
+
+    logger.info("Loading base model: %s", config.base_model)
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name=base_model,
-        max_seq_length=max_seq_length,
+        model_name=config.base_model,
+        max_seq_length=config.training.max_seq_length,
        load_in_4bit=True,
         dtype=None,
     )
 
-    logger.info("Applying LoRA (rank=%d, alpha=%d)", lora_rank, LORA_ALPHA)
+    logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
     model = FastLanguageModel.get_peft_model(
         model,
-        r=lora_rank,
-        lora_alpha=LORA_ALPHA,
-        lora_dropout=LORA_DROPOUT,
-        target_modules=LORA_TARGETS,
+        r=config.lora.rank,
+        lora_alpha=config.lora.alpha,
+        lora_dropout=config.lora.dropout,
+        target_modules=config.lora.targets,
         bias="none",
         use_gradient_checkpointing="unsloth",
         random_state=42,
@@ -137,18 +156,18 @@ def main(
     logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
     training_args = TrainingArguments(
         output_dir=str(output_dir / "checkpoints"),
-        num_train_epochs=epochs,
-        per_device_train_batch_size=batch_size,
-        gradient_accumulation_steps=GRADIENT_ACCUMULATION,
-        learning_rate=learning_rate,
-        warmup_ratio=WARMUP_RATIO,
-        weight_decay=WEIGHT_DECAY,
+        num_train_epochs=config.training.epochs,
+        per_device_train_batch_size=config.training.batch_size,
+        gradient_accumulation_steps=config.training.gradient_accumulation,
+        learning_rate=config.training.learning_rate,
+        warmup_ratio=config.training.warmup_ratio,
+        weight_decay=config.training.weight_decay,
         lr_scheduler_type="cosine",
-        logging_steps=LOGGING_STEPS,
-        save_steps=SAVE_STEPS,
+        logging_steps=config.training.logging_steps,
+        save_steps=config.training.save_steps,
         save_total_limit=3,
         eval_strategy="steps",
-        eval_steps=SAVE_STEPS,
+        eval_steps=config.training.save_steps,
         load_best_model_at_end=True,
         bf16=True,
         optim="adamw_8bit",
@@ -162,11 +181,16 @@ def main(
         train_dataset=train_dataset,
         eval_dataset=validation_dataset,
         args=training_args,
-        max_seq_length=max_seq_length,
+        max_seq_length=config.training.max_seq_length,
         packing=True,
     )
 
-    logger.info("Starting training: %d train, %d val, %d epochs", len(train_dataset), len(validation_dataset), epochs)
+    logger.info(
+        "Starting training: %d train, %d val, %d epochs",
+        len(train_dataset),
+        len(validation_dataset),
+        config.training.epochs,
+    )
     trainer.train()
 
     merged_path = str(output_dir / "merged")