Set up FinetuneConfig

This commit is contained in:
2026-04-10 21:40:17 -04:00
parent 1409e9c63e
commit 0d81f2d17b

View File

@@ -14,9 +14,11 @@ from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated
import tomllib
import typer
from unsloth import FastLanguageModel
from datasets import Dataset
@@ -25,32 +27,49 @@ from trl import SFTTrainer
logger = logging.getLogger(__name__)
BASE_MODEL = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"
# LoRA hyperparameters
LORA_RANK = 32
LORA_ALPHA = 32
LORA_DROPOUT = 0.0
LORA_TARGETS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
@dataclass
class LoraConfig:
"""LoRA adapter hyperparameters."""
# Training hyperparameters tuned for ~2k examples on a 3090
LEARNING_RATE = 2e-4
EPOCHS = 3
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 8 # effective batch = 16
MAX_SEQ_LENGTH = 4096
WARMUP_RATIO = 0.05
WEIGHT_DECAY = 0.01
LOGGING_STEPS = 10
SAVE_STEPS = 100
rank: int
alpha: int
dropout: float
targets: list[str]
@dataclass
class TrainingConfig:
"""Training loop hyperparameters."""
learning_rate: float
epochs: int
batch_size: int
gradient_accumulation: int
max_seq_length: int
warmup_ratio: float
weight_decay: float
logging_steps: int
save_steps: int
@dataclass
class FinetuneConfig:
"""Top-level finetune configuration."""
base_model: str
lora: LoraConfig
training: TrainingConfig
@classmethod
def from_toml(cls, config_path: Path) -> FinetuneConfig:
"""Load finetune config from a TOML file."""
raw = tomllib.loads(config_path.read_text())["finetune"]
return cls(
base_model=raw["base_model"],
lora=LoraConfig(**raw["lora"]),
training=TrainingConfig(**raw["training"]),
)
def _messages_to_chatml(messages: list[dict]) -> str:
@@ -95,12 +114,10 @@ def main(
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
"output/qwen-bill-summarizer",
),
base_model: Annotated[str, typer.Option("--base-model", help="Unsloth model ID")] = BASE_MODEL,
epochs: Annotated[int, typer.Option("--epochs", help="Training epochs")] = EPOCHS,
batch_size: Annotated[int, typer.Option("--batch-size", help="Per-device batch size")] = BATCH_SIZE,
learning_rate: Annotated[float, typer.Option("--lr", help="Learning rate")] = LEARNING_RATE,
lora_rank: Annotated[int, typer.Option("--lora-rank", help="LoRA rank")] = LORA_RANK,
max_seq_length: Annotated[int, typer.Option("--max-seq-length", help="Max sequence length")] = MAX_SEQ_LENGTH,
config_path: Annotated[
Path,
typer.Option("--config", help="TOML config file"),
] = Path(__file__).parent / "config.toml",
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
) -> None:
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
@@ -110,21 +127,23 @@ def main(
message = f"Dataset not found: {dataset_path}"
raise typer.BadParameter(message)
logger.info("Loading base model: %s", base_model)
config = FinetuneConfig.from_toml(config_path)
logger.info("Loading base model: %s", config.base_model)
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=base_model,
max_seq_length=max_seq_length,
model_name=config.base_model,
max_seq_length=config.training.max_seq_length,
load_in_4bit=True,
dtype=None,
)
logger.info("Applying LoRA (rank=%d, alpha=%d)", lora_rank, LORA_ALPHA)
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT,
target_modules=LORA_TARGETS,
r=config.lora.rank,
lora_alpha=config.lora.alpha,
lora_dropout=config.lora.dropout,
target_modules=config.lora.targets,
bias="none",
use_gradient_checkpointing="unsloth",
random_state=42,
@@ -137,18 +156,18 @@ def main(
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
training_args = TrainingArguments(
output_dir=str(output_dir / "checkpoints"),
num_train_epochs=epochs,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=GRADIENT_ACCUMULATION,
learning_rate=learning_rate,
warmup_ratio=WARMUP_RATIO,
weight_decay=WEIGHT_DECAY,
num_train_epochs=config.training.epochs,
per_device_train_batch_size=config.training.batch_size,
gradient_accumulation_steps=config.training.gradient_accumulation,
learning_rate=config.training.learning_rate,
warmup_ratio=config.training.warmup_ratio,
weight_decay=config.training.weight_decay,
lr_scheduler_type="cosine",
logging_steps=LOGGING_STEPS,
save_steps=SAVE_STEPS,
logging_steps=config.training.logging_steps,
save_steps=config.training.save_steps,
save_total_limit=3,
eval_strategy="steps",
eval_steps=SAVE_STEPS,
eval_steps=config.training.save_steps,
load_best_model_at_end=True,
bf16=True,
optim="adamw_8bit",
@@ -162,11 +181,16 @@ def main(
train_dataset=train_dataset,
eval_dataset=validation_dataset,
args=training_args,
max_seq_length=max_seq_length,
max_seq_length=config.training.max_seq_length,
packing=True,
)
logger.info("Starting training: %d train, %d val, %d epochs", len(train_dataset), len(validation_dataset), epochs)
logger.info(
"Starting training: %d train, %d val, %d epochs",
len(train_dataset),
len(validation_dataset),
config.training.epochs,
)
trainer.train()
merged_path = str(output_dir / "merged")