Set up FinetuneConfig

commit 0d81f2d17b
parent 1409e9c63e
Date: 2026-04-10 21:40:17 -04:00


@@ -14,9 +14,11 @@ from __future__ import annotations
 import json
 import logging
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Annotated
+import tomllib
 
 import typer
 from unsloth import FastLanguageModel
 from datasets import Dataset
@@ -25,32 +27,49 @@ from trl import SFTTrainer
 logger = logging.getLogger(__name__)
 
-BASE_MODEL = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"
-
-# LoRA hyperparameters
-LORA_RANK = 32
-LORA_ALPHA = 32
-LORA_DROPOUT = 0.0
-LORA_TARGETS = [
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-]
-
-# Training hyperparameters tuned for ~2k examples on a 3090
-LEARNING_RATE = 2e-4
-EPOCHS = 3
-BATCH_SIZE = 2
-GRADIENT_ACCUMULATION = 8  # effective batch = 16
-MAX_SEQ_LENGTH = 4096
-WARMUP_RATIO = 0.05
-WEIGHT_DECAY = 0.01
-LOGGING_STEPS = 10
-SAVE_STEPS = 100
+@dataclass
+class LoraConfig:
+    """LoRA adapter hyperparameters."""
+
+    rank: int
+    alpha: int
+    dropout: float
+    targets: list[str]
+
+
+@dataclass
+class TrainingConfig:
+    """Training loop hyperparameters."""
+
+    learning_rate: float
+    epochs: int
+    batch_size: int
+    gradient_accumulation: int
+    max_seq_length: int
+    warmup_ratio: float
+    weight_decay: float
+    logging_steps: int
+    save_steps: int
+
+
+@dataclass
+class FinetuneConfig:
+    """Top-level finetune configuration."""
+
+    base_model: str
+    lora: LoraConfig
+    training: TrainingConfig
+
+    @classmethod
+    def from_toml(cls, config_path: Path) -> FinetuneConfig:
+        """Load finetune config from a TOML file."""
+        raw = tomllib.loads(config_path.read_text())["finetune"]
+        return cls(
+            base_model=raw["base_model"],
+            lora=LoraConfig(**raw["lora"]),
+            training=TrainingConfig(**raw["training"]),
+        )
 
 
 def _messages_to_chatml(messages: list[dict]) -> str:
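
For reference, a config.toml matching the [finetune] schema that FinetuneConfig.from_toml expects might look like the sketch below. The values mirror the module-level constants this commit removes; the commit's actual config.toml is not shown in this view, so treat the file as illustrative:

    [finetune]
    base_model = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"

    [finetune.lora]
    rank = 32
    alpha = 32
    dropout = 0.0
    targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]

    [finetune.training]
    learning_rate = 2e-4
    epochs = 3
    batch_size = 2
    gradient_accumulation = 8  # effective batch = 2 * 8 = 16
    max_seq_length = 4096
    warmup_ratio = 0.05
    weight_decay = 0.01
    logging_steps = 10
    save_steps = 100
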
@@ -95,12 +114,10 @@ def main(
     output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
         "output/qwen-bill-summarizer",
     ),
-    base_model: Annotated[str, typer.Option("--base-model", help="Unsloth model ID")] = BASE_MODEL,
-    epochs: Annotated[int, typer.Option("--epochs", help="Training epochs")] = EPOCHS,
-    batch_size: Annotated[int, typer.Option("--batch-size", help="Per-device batch size")] = BATCH_SIZE,
-    learning_rate: Annotated[float, typer.Option("--lr", help="Learning rate")] = LEARNING_RATE,
-    lora_rank: Annotated[int, typer.Option("--lora-rank", help="LoRA rank")] = LORA_RANK,
-    max_seq_length: Annotated[int, typer.Option("--max-seq-length", help="Max sequence length")] = MAX_SEQ_LENGTH,
+    config_path: Annotated[
+        Path,
+        typer.Option("--config", help="TOML config file"),
+    ] = Path(__file__).parent / "config.toml",
     save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
 ) -> None:
     """Fine-tune Qwen3 4B on bill summarization with Unsloth + QLoRA."""
@@ -110,21 +127,23 @@ def main(
         message = f"Dataset not found: {dataset_path}"
         raise typer.BadParameter(message)
 
-    logger.info("Loading base model: %s", base_model)
+    config = FinetuneConfig.from_toml(config_path)
+
+    logger.info("Loading base model: %s", config.base_model)
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name=base_model,
-        max_seq_length=max_seq_length,
+        model_name=config.base_model,
+        max_seq_length=config.training.max_seq_length,
         load_in_4bit=True,
         dtype=None,
     )
 
-    logger.info("Applying LoRA (rank=%d, alpha=%d)", lora_rank, LORA_ALPHA)
+    logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
     model = FastLanguageModel.get_peft_model(
         model,
-        r=lora_rank,
-        lora_alpha=LORA_ALPHA,
-        lora_dropout=LORA_DROPOUT,
-        target_modules=LORA_TARGETS,
+        r=config.lora.rank,
+        lora_alpha=config.lora.alpha,
+        lora_dropout=config.lora.dropout,
+        target_modules=config.lora.targets,
         bias="none",
         use_gradient_checkpointing="unsloth",
         random_state=42,
@@ -137,18 +156,18 @@ def main(
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset)) logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir=str(output_dir / "checkpoints"), output_dir=str(output_dir / "checkpoints"),
num_train_epochs=epochs, num_train_epochs=config.training.epochs,
per_device_train_batch_size=batch_size, per_device_train_batch_size=config.training.batch_size,
gradient_accumulation_steps=GRADIENT_ACCUMULATION, gradient_accumulation_steps=config.training.gradient_accumulation,
learning_rate=learning_rate, learning_rate=config.training.learning_rate,
warmup_ratio=WARMUP_RATIO, warmup_ratio=config.training.warmup_ratio,
weight_decay=WEIGHT_DECAY, weight_decay=config.training.weight_decay,
lr_scheduler_type="cosine", lr_scheduler_type="cosine",
logging_steps=LOGGING_STEPS, logging_steps=config.training.logging_steps,
save_steps=SAVE_STEPS, save_steps=config.training.save_steps,
save_total_limit=3, save_total_limit=3,
eval_strategy="steps", eval_strategy="steps",
eval_steps=SAVE_STEPS, eval_steps=config.training.save_steps,
load_best_model_at_end=True, load_best_model_at_end=True,
bf16=True, bf16=True,
optim="adamw_8bit", optim="adamw_8bit",
@@ -162,11 +181,16 @@ def main(
         train_dataset=train_dataset,
         eval_dataset=validation_dataset,
         args=training_args,
-        max_seq_length=max_seq_length,
+        max_seq_length=config.training.max_seq_length,
         packing=True,
     )
 
-    logger.info("Starting training: %d train, %d val, %d epochs", len(train_dataset), len(validation_dataset), epochs)
+    logger.info(
+        "Starting training: %d train, %d val, %d epochs",
+        len(train_dataset),
+        len(validation_dataset),
+        config.training.epochs,
+    )
     trainer.train()
 
     merged_path = str(output_dir / "merged")
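
A minimal sketch of the new loading path, assuming the illustrative config.toml above sits next to the script (nothing here is part of the diff):

    from pathlib import Path

    config = FinetuneConfig.from_toml(Path("config.toml"))
    # Nested dataclasses give dotted access, as used throughout main().
    print(config.lora.rank)  # 32
    # Effective batch = per-device batch x accumulation steps: 2 * 8 = 16.
    print(config.training.batch_size * config.training.gradient_accumulation)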