mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-17 13:08:19 -04:00
setup FinetuneConfig
This commit is contained in:
@@ -14,9 +14,11 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import tomllib
|
||||
import typer
|
||||
from unsloth import FastLanguageModel
|
||||
from datasets import Dataset
|
||||
@@ -25,32 +27,49 @@ from trl import SFTTrainer
|
||||
|
||||
# Module-level logger; configured by the application entry point.
logger = logging.getLogger(__name__)

# Default base model (4-bit quantized Qwen3 4B from Unsloth).
BASE_MODEL = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"

# LoRA hyperparameters
LORA_RANK = 32
LORA_ALPHA = 32
LORA_DROPOUT = 0.0
# Attach adapters to every attention and MLP projection.
LORA_TARGETS = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
|
||||
# Training hyperparameters tuned for ~2k examples on a 3090
LEARNING_RATE = 2e-4
EPOCHS = 3
BATCH_SIZE = 2
GRADIENT_ACCUMULATION = 8  # effective batch = 16
MAX_SEQ_LENGTH = 4096
WARMUP_RATIO = 0.05
WEIGHT_DECAY = 0.01
LOGGING_STEPS = 10
SAVE_STEPS = 100


@dataclass
class LoraConfig:
    """LoRA adapter hyperparameters."""

    # Adapter rank r (dimensionality of the low-rank update).
    rank: int
    # Scaling factor alpha applied to the adapter output.
    alpha: int
    # Dropout probability on the adapter path.
    dropout: float
    # Names of the modules to attach adapters to (e.g. "q_proj").
    targets: list[str]
|
||||
|
||||
|
||||
@dataclass
class TrainingConfig:
    """Training loop hyperparameters."""

    # Peak learning rate handed to the optimizer.
    learning_rate: float
    # Number of passes over the training split.
    epochs: int
    # Per-device batch size.
    batch_size: int
    # Gradient accumulation steps (effective batch = batch_size * this).
    gradient_accumulation: int
    # Maximum tokenized sequence length.
    max_seq_length: int
    # Fraction of total steps spent warming up the LR schedule.
    warmup_ratio: float
    # AdamW weight decay coefficient.
    weight_decay: float
    # Log training metrics every N optimizer steps.
    logging_steps: int
    # Write a checkpoint every N optimizer steps.
    save_steps: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuneConfig:
|
||||
"""Top-level finetune configuration."""
|
||||
|
||||
base_model: str
|
||||
lora: LoraConfig
|
||||
training: TrainingConfig
|
||||
|
||||
@classmethod
|
||||
def from_toml(cls, config_path: Path) -> FinetuneConfig:
|
||||
"""Load finetune config from a TOML file."""
|
||||
raw = tomllib.loads(config_path.read_text())["finetune"]
|
||||
return cls(
|
||||
base_model=raw["base_model"],
|
||||
lora=LoraConfig(**raw["lora"]),
|
||||
training=TrainingConfig(**raw["training"]),
|
||||
)
|
||||
|
||||
|
||||
def _messages_to_chatml(messages: list[dict]) -> str:
|
||||
@@ -95,12 +114,10 @@ def main(
|
||||
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
|
||||
"output/qwen-bill-summarizer",
|
||||
),
|
||||
base_model: Annotated[str, typer.Option("--base-model", help="Unsloth model ID")] = BASE_MODEL,
|
||||
epochs: Annotated[int, typer.Option("--epochs", help="Training epochs")] = EPOCHS,
|
||||
batch_size: Annotated[int, typer.Option("--batch-size", help="Per-device batch size")] = BATCH_SIZE,
|
||||
learning_rate: Annotated[float, typer.Option("--lr", help="Learning rate")] = LEARNING_RATE,
|
||||
lora_rank: Annotated[int, typer.Option("--lora-rank", help="LoRA rank")] = LORA_RANK,
|
||||
max_seq_length: Annotated[int, typer.Option("--max-seq-length", help="Max sequence length")] = MAX_SEQ_LENGTH,
|
||||
config_path: Annotated[
|
||||
Path,
|
||||
typer.Option("--config", help="TOML config file"),
|
||||
] = Path(__file__).parent / "config.toml",
|
||||
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
|
||||
) -> None:
|
||||
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
|
||||
@@ -110,21 +127,23 @@ def main(
|
||||
message = f"Dataset not found: {dataset_path}"
|
||||
raise typer.BadParameter(message)
|
||||
|
||||
logger.info("Loading base model: %s", base_model)
|
||||
config = FinetuneConfig.from_toml(config_path)
|
||||
|
||||
logger.info("Loading base model: %s", config.base_model)
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name=base_model,
|
||||
max_seq_length=max_seq_length,
|
||||
model_name=config.base_model,
|
||||
max_seq_length=config.training.max_seq_length,
|
||||
load_in_4bit=True,
|
||||
dtype=None,
|
||||
)
|
||||
|
||||
logger.info("Applying LoRA (rank=%d, alpha=%d)", lora_rank, LORA_ALPHA)
|
||||
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=lora_rank,
|
||||
lora_alpha=LORA_ALPHA,
|
||||
lora_dropout=LORA_DROPOUT,
|
||||
target_modules=LORA_TARGETS,
|
||||
r=config.lora.rank,
|
||||
lora_alpha=config.lora.alpha,
|
||||
lora_dropout=config.lora.dropout,
|
||||
target_modules=config.lora.targets,
|
||||
bias="none",
|
||||
use_gradient_checkpointing="unsloth",
|
||||
random_state=42,
|
||||
@@ -137,18 +156,18 @@ def main(
|
||||
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
|
||||
training_args = TrainingArguments(
|
||||
output_dir=str(output_dir / "checkpoints"),
|
||||
num_train_epochs=epochs,
|
||||
per_device_train_batch_size=batch_size,
|
||||
gradient_accumulation_steps=GRADIENT_ACCUMULATION,
|
||||
learning_rate=learning_rate,
|
||||
warmup_ratio=WARMUP_RATIO,
|
||||
weight_decay=WEIGHT_DECAY,
|
||||
num_train_epochs=config.training.epochs,
|
||||
per_device_train_batch_size=config.training.batch_size,
|
||||
gradient_accumulation_steps=config.training.gradient_accumulation,
|
||||
learning_rate=config.training.learning_rate,
|
||||
warmup_ratio=config.training.warmup_ratio,
|
||||
weight_decay=config.training.weight_decay,
|
||||
lr_scheduler_type="cosine",
|
||||
logging_steps=LOGGING_STEPS,
|
||||
save_steps=SAVE_STEPS,
|
||||
logging_steps=config.training.logging_steps,
|
||||
save_steps=config.training.save_steps,
|
||||
save_total_limit=3,
|
||||
eval_strategy="steps",
|
||||
eval_steps=SAVE_STEPS,
|
||||
eval_steps=config.training.save_steps,
|
||||
load_best_model_at_end=True,
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
@@ -162,11 +181,16 @@ def main(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
args=training_args,
|
||||
max_seq_length=max_seq_length,
|
||||
max_seq_length=config.training.max_seq_length,
|
||||
packing=True,
|
||||
)
|
||||
|
||||
logger.info("Starting training: %d train, %d val, %d epochs", len(train_dataset), len(validation_dataset), epochs)
|
||||
logger.info(
|
||||
"Starting training: %d train, %d val, %d epochs",
|
||||
len(train_dataset),
|
||||
len(validation_dataset),
|
||||
config.training.epochs,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
merged_path = str(output_dir / "merged")
|
||||
|
||||
Reference in New Issue
Block a user