folksy_idioms/scripts/train_sft.py
john 02daa7bb97 Add SFT training script and run Qwen3-0.6B-Base fine-tune
Train Qwen3-0.6B-Base (596M params) on 36K folksy proverb pairs
using full SFT with HuggingFace TRL. 3 epochs, 11 min on RTX 4090.

Results: train_loss=0.954, eval_loss=1.032, test_loss=1.031
Model checkpoint at folksy-model/final/ (not committed — 1.2 GB)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-31 22:07:23 -04:00

296 lines
9.6 KiB
Python

#!/usr/bin/env python3
"""SFT fine-tune Qwen3-0.6B-Base on folksy proverb training pairs.
Usage:
python scripts/train_sft.py
Expects corpus/training_pairs.jsonl in the project root.
Outputs model checkpoints and training logs to folksy-model/.
"""
import json
import os
import random
import sys
import time
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path
# Prevent CUDA fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTConfig, SFTTrainer
# === Configuration ===
# Base model to fine-tune (596M-parameter pretrained checkpoint per header).
MODEL_ID = "Qwen/Qwen3-0.6B-Base"
# Project root is two levels above this script (scripts/ -> repo root).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_FILE = PROJECT_ROOT / "corpus" / "training_pairs.jsonl"  # input JSONL pairs
OUTPUT_DIR = PROJECT_ROOT / "folksy-model"    # all run artifacts land here
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"   # periodic mid-training checkpoints
FINAL_MODEL_DIR = OUTPUT_DIR / "final"        # final model + tokenizer export
LOG_FILE = OUTPUT_DIR / "training_log.jsonl"  # structured event log (see log_event)
# ChatML template without Qwen3 thinking tags — clean input/output format
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
# Hyperparameters used both to build SFTConfig and for run-log bookkeeping.
TRAINING_CONFIG = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 32,
    "learning_rate": 2e-5,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.05,
    "weight_decay": 0.01,
    "max_seq_length": 128,   # pairs are short proverbs; 128 tokens suffices
    "eval_steps": 100,
    "save_steps": 500,       # must stay a multiple of eval_steps for best-model loading
    "logging_steps": 10,
    "seed": 42,
}
def log_event(event: str, **kwargs):
    """Record a timestamped event in the JSONL run log and echo it to stdout."""
    stamp = datetime.now().isoformat()
    record = {"timestamp": stamp, "event": event, **kwargs}
    LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(LOG_FILE, "a") as sink:
        sink.write(json.dumps(record) + "\n")
    # Console echo: second-resolution timestamp plus any extra payload.
    extras = {key: val for key, val in kwargs.items() if key != "timestamp"}
    print(f"[{record['timestamp'][:19]}] {event}", extras if extras else "")
def load_and_split_data(data_file: Path, val_ratio=0.05, test_ratio=0.05):
    """Read JSONL pairs and build train/val/test splits stratified by meta_template.

    The split is deterministic (seed 42, templates visited in sorted order)
    and each template contributes at least one example to val and test.
    """
    with open(data_file) as fh:
        records = [json.loads(raw) for raw in fh]
    log_event("data_loaded", total_records=len(records))

    # Attach a two-turn chat transcript to every record.
    for rec in records:
        rec["messages"] = [
            {"role": "user", "content": rec["input"]},
            {"role": "assistant", "content": rec["output"]},
        ]

    # Group record positions by template so every split preserves the mix.
    random.seed(42)
    by_template = defaultdict(list)
    for pos, rec in enumerate(records):
        by_template[rec["meta_template"]].append(pos)

    train_idx, val_idx, test_idx = [], [], []
    for _template, positions in sorted(by_template.items()):
        random.shuffle(positions)
        count = len(positions)
        take_test = max(1, round(count * test_ratio))
        take_val = max(1, round(count * val_ratio))
        # Test slice first, then val, remainder goes to train.
        test_idx.extend(positions[:take_test])
        val_idx.extend(positions[take_test : take_test + take_val])
        train_idx.extend(positions[take_test + take_val :])

    def as_dataset(positions):
        rows = [
            {
                "messages": records[p]["messages"],
                "meta_template": records[p]["meta_template"],
            }
            for p in positions
        ]
        return Dataset.from_list(rows)

    train_ds = as_dataset(train_idx)
    val_ds = as_dataset(val_idx)
    test_ds = as_dataset(test_idx)
    log_event(
        "data_split", train=len(train_ds), val=len(val_ds), test=len(test_ds)
    )

    # Show the per-template breakdown of each split.
    for name, ds in (("train", train_ds), ("val", val_ds), ("test", test_ds)):
        tally = Counter(ds["meta_template"])
        print(f"\n{name} ({len(ds)} examples):")
        for t, c in sorted(tally.items()):
            print(f" {t}: {c} ({c / len(ds) * 100:.1f}%)")
    return train_ds, val_ds, test_ds
class MetricsLogger(TrainerCallback):
    """Mirror HF Trainer metrics into the structured run log."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Forward every metrics dict the trainer emits, tagged with the step.
        if not logs:
            return
        log_event("train_metrics", step=state.global_step, **logs)

    def on_epoch_end(self, args, state, control, **kwargs):
        # Most recent training loss in the trainer's history, if any was logged.
        latest_loss = next(
            (item["loss"] for item in reversed(state.log_history) if "loss" in item),
            None,
        )
        log_event("epoch_complete", epoch=int(state.epoch), train_loss=latest_loss)
def main():
    """Run the full SFT pipeline: check environment, load data, train,
    evaluate on the test split, and save all artifacts.

    Exits with status 1 if no CUDA GPU is available or the training data
    file is missing. Writes checkpoints, the final model + tokenizer,
    val/test split files, TensorBoard runs, and a JSONL event log under
    OUTPUT_DIR.
    """
    start_time = time.time()

    # --- Environment checks ------------------------------------------------
    if not torch.cuda.is_available():
        print("ERROR: No CUDA GPU available")
        sys.exit(1)
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
    log_event("session_start", gpu=gpu_name, vram_gb=round(gpu_mem, 1))

    if not DATA_FILE.exists():
        print(f"ERROR: Training data not found at {DATA_FILE}")
        sys.exit(1)

    for d in [OUTPUT_DIR, CHECKPOINT_DIR]:
        d.mkdir(parents=True, exist_ok=True)

    # --- Data --------------------------------------------------------------
    train_ds, val_ds, test_ds = load_and_split_data(DATA_FILE)
    # Persist the held-out splits so later evaluation uses identical data.
    val_ds.to_json(OUTPUT_DIR / "val_split.jsonl")
    test_ds.to_json(OUTPUT_DIR / "test_split.jsonl")

    # --- Model / tokenizer -------------------------------------------------
    log_event("model_loading", model=MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16
    )
    # Override chat template to remove Qwen3 thinking tags so SFT targets
    # are a clean user/assistant exchange.
    tokenizer.chat_template = CHAT_TEMPLATE
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    param_count = sum(p.numel() for p in model.parameters())
    log_event("model_loaded", parameters=f"{param_count / 1e6:.0f}M")

    # Sanity-check the template on one example before committing to a run.
    sample_messages = train_ds[0]["messages"]
    sample_encoded = tokenizer.apply_chat_template(
        sample_messages, tokenize=True, return_dict=False
    )
    sample_text = tokenizer.apply_chat_template(
        sample_messages, tokenize=False
    )
    print(f"\nSample tokenization ({len(sample_encoded)} tokens):")
    print(sample_text)

    # --- Training configuration --------------------------------------------
    training_args = SFTConfig(
        output_dir=str(CHECKPOINT_DIR),
        num_train_epochs=TRAINING_CONFIG["num_train_epochs"],
        per_device_train_batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
        gradient_accumulation_steps=1,
        learning_rate=TRAINING_CONFIG["learning_rate"],
        lr_scheduler_type=TRAINING_CONFIG["lr_scheduler_type"],
        warmup_ratio=TRAINING_CONFIG["warmup_ratio"],
        weight_decay=TRAINING_CONFIG["weight_decay"],
        bf16=True,
        max_length=TRAINING_CONFIG["max_seq_length"],
        eval_strategy="steps",
        eval_steps=TRAINING_CONFIG["eval_steps"],
        save_strategy="steps",
        save_steps=TRAINING_CONFIG["save_steps"],
        logging_steps=TRAINING_CONFIG["logging_steps"],
        report_to="tensorboard",
        logging_dir=str(OUTPUT_DIR / "runs"),
        seed=TRAINING_CONFIG["seed"],
        dataloader_num_workers=2,
        optim="adamw_torch_fused",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=tokenizer,
        callbacks=[MetricsLogger()],
    )

    # Ceiling division: the Trainer counts the final partial batch as a full
    # optimizer step, so floor division would under-report the step counts.
    batch_size = TRAINING_CONFIG["per_device_train_batch_size"]
    steps_per_epoch = (len(train_ds) + batch_size - 1) // batch_size
    total_steps = steps_per_epoch * TRAINING_CONFIG["num_train_epochs"]
    log_event(
        "training_start",
        steps_per_epoch=steps_per_epoch,
        total_steps_approx=total_steps,
        config=TRAINING_CONFIG,
    )

    # --- Train -------------------------------------------------------------
    train_result = trainer.train()
    training_time = time.time() - start_time
    log_event(
        "training_complete",
        wall_time_seconds=round(training_time, 1),
        wall_time_minutes=round(training_time / 60, 1),
        train_loss=train_result.training_loss,
        train_runtime=train_result.metrics.get("train_runtime"),
        train_samples_per_second=train_result.metrics.get(
            "train_samples_per_second"
        ),
    )

    # --- Save artifacts ----------------------------------------------------
    FINAL_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(FINAL_MODEL_DIR))
    tokenizer.save_pretrained(str(FINAL_MODEL_DIR))
    # Keep the full metric history alongside the model for later inspection.
    with open(FINAL_MODEL_DIR / "trainer_state.json", "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)
    log_event("final_model_saved", path=str(FINAL_MODEL_DIR))

    # --- Final eval on the test set (swap eval dataset temporarily) --------
    original_eval_ds = trainer.eval_dataset
    trainer.eval_dataset = test_ds
    test_metrics = trainer.evaluate(metric_key_prefix="test")
    trainer.eval_dataset = original_eval_ds
    log_event("test_eval", **test_metrics)

    # --- Summary -----------------------------------------------------------
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)
    print(f"Model: {MODEL_ID}")
    print(f"Training pairs: {len(train_ds)}")
    print(f"Val pairs: {len(val_ds)}")
    print(f"Test pairs: {len(test_ds)}")
    print(f"Final train loss: {train_result.training_loss:.4f}")
    print(f"Test loss: {test_metrics.get('test_loss', 'N/A')}")
    print(f"Wall time: {training_time / 60:.1f} minutes")
    print(f"Checkpoint: {FINAL_MODEL_DIR}")
    print(f"Training log: {LOG_FILE}")
    print(f"TensorBoard: {OUTPUT_DIR / 'runs'}")
    print("=" * 60)


if __name__ == "__main__":
    main()