mirror of
https://github.com/RichieCahill/dotfiles.git
synced 2026-04-18 05:18:20 -04:00
setup FinetuneConfig
This commit is contained in:
@@ -14,9 +14,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
import tomllib
|
||||||
import typer
|
import typer
|
||||||
from unsloth import FastLanguageModel
|
from unsloth import FastLanguageModel
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -25,32 +27,49 @@ from trl import SFTTrainer
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BASE_MODEL = "unsloth/Qwen3-4B-Base-unsloth-bnb-4bit"
|
|
||||||
|
|
||||||
# LoRA hyperparameters
|
@dataclass
|
||||||
LORA_RANK = 32
|
class LoraConfig:
|
||||||
LORA_ALPHA = 32
|
"""LoRA adapter hyperparameters."""
|
||||||
LORA_DROPOUT = 0.0
|
|
||||||
LORA_TARGETS = [
|
|
||||||
"q_proj",
|
|
||||||
"k_proj",
|
|
||||||
"v_proj",
|
|
||||||
"o_proj",
|
|
||||||
"gate_proj",
|
|
||||||
"up_proj",
|
|
||||||
"down_proj",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Training hyperparameters tuned for ~2k examples on a 3090
|
rank: int
|
||||||
LEARNING_RATE = 2e-4
|
alpha: int
|
||||||
EPOCHS = 3
|
dropout: float
|
||||||
BATCH_SIZE = 2
|
targets: list[str]
|
||||||
GRADIENT_ACCUMULATION = 8 # effective batch = 16
|
|
||||||
MAX_SEQ_LENGTH = 4096
|
|
||||||
WARMUP_RATIO = 0.05
|
@dataclass
|
||||||
WEIGHT_DECAY = 0.01
|
class TrainingConfig:
|
||||||
LOGGING_STEPS = 10
|
"""Training loop hyperparameters."""
|
||||||
SAVE_STEPS = 100
|
|
||||||
|
learning_rate: float
|
||||||
|
epochs: int
|
||||||
|
batch_size: int
|
||||||
|
gradient_accumulation: int
|
||||||
|
max_seq_length: int
|
||||||
|
warmup_ratio: float
|
||||||
|
weight_decay: float
|
||||||
|
logging_steps: int
|
||||||
|
save_steps: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FinetuneConfig:
|
||||||
|
"""Top-level finetune configuration."""
|
||||||
|
|
||||||
|
base_model: str
|
||||||
|
lora: LoraConfig
|
||||||
|
training: TrainingConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_toml(cls, config_path: Path) -> FinetuneConfig:
|
||||||
|
"""Load finetune config from a TOML file."""
|
||||||
|
raw = tomllib.loads(config_path.read_text())["finetune"]
|
||||||
|
return cls(
|
||||||
|
base_model=raw["base_model"],
|
||||||
|
lora=LoraConfig(**raw["lora"]),
|
||||||
|
training=TrainingConfig(**raw["training"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _messages_to_chatml(messages: list[dict]) -> str:
|
def _messages_to_chatml(messages: list[dict]) -> str:
|
||||||
@@ -95,12 +114,10 @@ def main(
|
|||||||
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
|
output_dir: Annotated[Path, typer.Option("--output-dir", help="Where to save the merged model")] = Path(
|
||||||
"output/qwen-bill-summarizer",
|
"output/qwen-bill-summarizer",
|
||||||
),
|
),
|
||||||
base_model: Annotated[str, typer.Option("--base-model", help="Unsloth model ID")] = BASE_MODEL,
|
config_path: Annotated[
|
||||||
epochs: Annotated[int, typer.Option("--epochs", help="Training epochs")] = EPOCHS,
|
Path,
|
||||||
batch_size: Annotated[int, typer.Option("--batch-size", help="Per-device batch size")] = BATCH_SIZE,
|
typer.Option("--config", help="TOML config file"),
|
||||||
learning_rate: Annotated[float, typer.Option("--lr", help="Learning rate")] = LEARNING_RATE,
|
] = Path(__file__).parent / "config.toml",
|
||||||
lora_rank: Annotated[int, typer.Option("--lora-rank", help="LoRA rank")] = LORA_RANK,
|
|
||||||
max_seq_length: Annotated[int, typer.Option("--max-seq-length", help="Max sequence length")] = MAX_SEQ_LENGTH,
|
|
||||||
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
|
save_gguf: Annotated[bool, typer.Option("--save-gguf/--no-save-gguf", help="Also save GGUF")] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
|
"""Fine-tune Qwen 3.5 4B on bill summarization with Unsloth + QLoRA."""
|
||||||
@@ -110,21 +127,23 @@ def main(
|
|||||||
message = f"Dataset not found: {dataset_path}"
|
message = f"Dataset not found: {dataset_path}"
|
||||||
raise typer.BadParameter(message)
|
raise typer.BadParameter(message)
|
||||||
|
|
||||||
logger.info("Loading base model: %s", base_model)
|
config = FinetuneConfig.from_toml(config_path)
|
||||||
|
|
||||||
|
logger.info("Loading base model: %s", config.base_model)
|
||||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||||
model_name=base_model,
|
model_name=config.base_model,
|
||||||
max_seq_length=max_seq_length,
|
max_seq_length=config.training.max_seq_length,
|
||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Applying LoRA (rank=%d, alpha=%d)", lora_rank, LORA_ALPHA)
|
logger.info("Applying LoRA (rank=%d, alpha=%d)", config.lora.rank, config.lora.alpha)
|
||||||
model = FastLanguageModel.get_peft_model(
|
model = FastLanguageModel.get_peft_model(
|
||||||
model,
|
model,
|
||||||
r=lora_rank,
|
r=config.lora.rank,
|
||||||
lora_alpha=LORA_ALPHA,
|
lora_alpha=config.lora.alpha,
|
||||||
lora_dropout=LORA_DROPOUT,
|
lora_dropout=config.lora.dropout,
|
||||||
target_modules=LORA_TARGETS,
|
target_modules=config.lora.targets,
|
||||||
bias="none",
|
bias="none",
|
||||||
use_gradient_checkpointing="unsloth",
|
use_gradient_checkpointing="unsloth",
|
||||||
random_state=42,
|
random_state=42,
|
||||||
@@ -137,18 +156,18 @@ def main(
|
|||||||
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
|
logger.info("Split: %d train, %d validation", len(train_dataset), len(validation_dataset))
|
||||||
training_args = TrainingArguments(
|
training_args = TrainingArguments(
|
||||||
output_dir=str(output_dir / "checkpoints"),
|
output_dir=str(output_dir / "checkpoints"),
|
||||||
num_train_epochs=epochs,
|
num_train_epochs=config.training.epochs,
|
||||||
per_device_train_batch_size=batch_size,
|
per_device_train_batch_size=config.training.batch_size,
|
||||||
gradient_accumulation_steps=GRADIENT_ACCUMULATION,
|
gradient_accumulation_steps=config.training.gradient_accumulation,
|
||||||
learning_rate=learning_rate,
|
learning_rate=config.training.learning_rate,
|
||||||
warmup_ratio=WARMUP_RATIO,
|
warmup_ratio=config.training.warmup_ratio,
|
||||||
weight_decay=WEIGHT_DECAY,
|
weight_decay=config.training.weight_decay,
|
||||||
lr_scheduler_type="cosine",
|
lr_scheduler_type="cosine",
|
||||||
logging_steps=LOGGING_STEPS,
|
logging_steps=config.training.logging_steps,
|
||||||
save_steps=SAVE_STEPS,
|
save_steps=config.training.save_steps,
|
||||||
save_total_limit=3,
|
save_total_limit=3,
|
||||||
eval_strategy="steps",
|
eval_strategy="steps",
|
||||||
eval_steps=SAVE_STEPS,
|
eval_steps=config.training.save_steps,
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
bf16=True,
|
bf16=True,
|
||||||
optim="adamw_8bit",
|
optim="adamw_8bit",
|
||||||
@@ -162,11 +181,16 @@ def main(
|
|||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset,
|
||||||
eval_dataset=validation_dataset,
|
eval_dataset=validation_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
max_seq_length=max_seq_length,
|
max_seq_length=config.training.max_seq_length,
|
||||||
packing=True,
|
packing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Starting training: %d train, %d val, %d epochs", len(train_dataset), len(validation_dataset), epochs)
|
logger.info(
|
||||||
|
"Starting training: %d train, %d val, %d epochs",
|
||||||
|
len(train_dataset),
|
||||||
|
len(validation_dataset),
|
||||||
|
config.training.epochs,
|
||||||
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
merged_path = str(output_dir / "merged")
|
merged_path = str(output_dir / "merged")
|
||||||
|
|||||||
Reference in New Issue
Block a user