Train Qwen3-0.6B-Base (596M params) on 36K folksy proverb pairs using full SFT with HuggingFace TRL. 3 epochs, 11 min on RTX 4090. Results: train_loss=0.954, eval_loss=1.032, test_loss=1.031 Model checkpoint at folksy-model/final/ (not committed — 1.2 GB) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
296 lines · 9.6 KiB · Python
#!/usr/bin/env python3
|
|
"""SFT fine-tune Qwen3-0.6B-Base on folksy proverb training pairs.
|
|
|
|
Usage:
|
|
python scripts/train_sft.py
|
|
|
|
Expects corpus/training_pairs.jsonl in the project root.
|
|
Outputs model checkpoints and training logs to folksy-model/.
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import random
|
|
import sys
|
|
import time
|
|
from collections import Counter, defaultdict
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
# Prevent CUDA fragmentation
|
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
import torch
|
|
from datasets import Dataset
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
|
|
from trl import SFTConfig, SFTTrainer
|
|
|
|
# === Configuration ===
MODEL_ID = "Qwen/Qwen3-0.6B-Base"
# All paths are resolved relative to the project root (parent of scripts/).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_FILE = PROJECT_ROOT / "corpus" / "training_pairs.jsonl"
OUTPUT_DIR = PROJECT_ROOT / "folksy-model"
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"  # intermediate step checkpoints
FINAL_MODEL_DIR = OUTPUT_DIR / "final"  # final model + tokenizer
LOG_FILE = OUTPUT_DIR / "training_log.jsonl"  # structured JSONL event log

# ChatML template without Qwen3 thinking tags — clean input/output format
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

# Hyperparameters consumed in main() when building SFTConfig.
TRAINING_CONFIG = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 32,
    "learning_rate": 2e-5,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.05,
    "weight_decay": 0.01,
    "max_seq_length": 128,  # passed to SFTConfig as max_length
    "eval_steps": 100,
    # NOTE(review): with load_best_model_at_end, save_steps should stay a
    # multiple of eval_steps — 500 / 100 satisfies this.
    "save_steps": 500,
    "logging_steps": 10,
    "seed": 42,
}
|
|
|
|
|
|
def log_event(event: str, **kwargs):
    """Record a structured event to the JSONL run log and echo it to stdout."""
    stamp = datetime.now().isoformat()
    entry = {"timestamp": stamp, "event": event}
    entry.update(kwargs)
    LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(LOG_FILE, "a") as log_fh:
        log_fh.write(json.dumps(entry) + "\n")
    # Console echo: show everything except the timestamp we already print.
    detail = {key: val for key, val in kwargs.items() if key != "timestamp"}
    print(f"[{stamp[:19]}] {event}", detail if detail else "")
|
|
|
|
|
|
def load_and_split_data(data_file: Path, val_ratio=0.05, test_ratio=0.05, seed=42):
    """Load JSONL training pairs and create stratified splits by meta_template.

    Each record must contain "input", "output", and "meta_template" keys.
    Records are converted to chat-format "messages" in place, then split
    into train/val/test with per-template stratification so every template
    is represented in each split (at least one example each in val and test).

    Args:
        data_file: Path to the JSONL file of training pairs.
        val_ratio: Fraction of each template group held out for validation.
        test_ratio: Fraction of each template group held out for testing.
        seed: RNG seed for the per-group shuffles (default 42, matching the
            previous hard-coded value, so existing splits are reproduced).

    Returns:
        (train_ds, val_ds, test_ds) HuggingFace Datasets with "messages"
        and "meta_template" columns.
    """
    with open(data_file) as f:
        records = [json.loads(line) for line in f]

    log_event("data_loaded", total_records=len(records))

    # Convert to the chat messages format expected by the SFT trainer.
    for r in records:
        r["messages"] = [
            {"role": "user", "content": r["input"]},
            {"role": "assistant", "content": r["output"]},
        ]

    # Stratified split by meta_template. A dedicated Random instance keeps
    # the splits deterministic without mutating global random state (the
    # original seeded the module-level RNG as a hidden side effect).
    rng = random.Random(seed)
    groups = defaultdict(list)
    for i, r in enumerate(records):
        groups[r["meta_template"]].append(i)

    train_idx, val_idx, test_idx = [], [], []
    # Iterate in sorted order so the split is independent of insertion order.
    for template, indices in sorted(groups.items()):
        rng.shuffle(indices)
        n = len(indices)
        # Guarantee at least one val and one test example per template.
        n_test = max(1, round(n * test_ratio))
        n_val = max(1, round(n * val_ratio))
        test_idx.extend(indices[:n_test])
        val_idx.extend(indices[n_test : n_test + n_val])
        train_idx.extend(indices[n_test + n_val :])

    def make_dataset(indices):
        # Keep only the columns the trainer and later evaluation need.
        return Dataset.from_list(
            [
                {
                    "messages": records[i]["messages"],
                    "meta_template": records[i]["meta_template"],
                }
                for i in indices
            ]
        )

    train_ds = make_dataset(train_idx)
    val_ds = make_dataset(val_idx)
    test_ds = make_dataset(test_idx)

    log_event(
        "data_split", train=len(train_ds), val=len(val_ds), test=len(test_ds)
    )

    # Print per-split template distribution as a quick sanity check.
    for name, ds in [("train", train_ds), ("val", val_ds), ("test", test_ds)]:
        dist = Counter(ds["meta_template"])
        print(f"\n{name} ({len(ds)} examples):")
        for t, c in sorted(dist.items()):
            print(f"  {t}: {c} ({c / len(ds) * 100:.1f}%)")

    return train_ds, val_ds, test_ds
|
|
|
|
|
|
class MetricsLogger(TrainerCallback):
    """Forward trainer metrics into the structured run log."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Mirror every metrics dict the trainer emits into our JSONL log.
        if logs:
            log_event("train_metrics", step=state.global_step, **logs)

    def on_epoch_end(self, args, state, control, **kwargs):
        # Pull the most recent training loss out of the trainer's history
        # (None when no loss has been logged yet this run).
        latest_loss = next(
            (rec["loss"] for rec in reversed(state.log_history) if "loss" in rec),
            None,
        )
        log_event("epoch_complete", epoch=int(state.epoch), train_loss=latest_loss)
|
|
|
|
|
|
def main():
    """Run the full SFT pipeline: verify environment, load data, train,
    save the final model, and evaluate on the held-out test split."""
    start_time = time.time()

    # Verify GPU — this script requires CUDA; bail out early otherwise.
    if not torch.cuda.is_available():
        print("ERROR: No CUDA GPU available")
        sys.exit(1)

    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
    log_event("session_start", gpu=gpu_name, vram_gb=round(gpu_mem, 1))

    # Verify data
    if not DATA_FILE.exists():
        print(f"ERROR: Training data not found at {DATA_FILE}")
        sys.exit(1)

    # Create output directories
    for d in [OUTPUT_DIR, CHECKPOINT_DIR]:
        d.mkdir(parents=True, exist_ok=True)

    # Load data (stratified train/val/test split by meta_template)
    train_ds, val_ds, test_ds = load_and_split_data(DATA_FILE)

    # Save val/test splits for later evaluation
    val_ds.to_json(OUTPUT_DIR / "val_split.jsonl")
    test_ds.to_json(OUTPUT_DIR / "test_split.jsonl")

    # Load model and tokenizer
    log_event("model_loading", model=MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16
    )

    # Override chat template to remove Qwen3 thinking tags
    tokenizer.chat_template = CHAT_TEMPLATE
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    param_count = sum(p.numel() for p in model.parameters())
    log_event("model_loaded", parameters=f"{param_count / 1e6:.0f}M")

    # Verify tokenization on a sample — print the rendered ChatML so a bad
    # template is visible before any training time is spent.
    sample_messages = train_ds[0]["messages"]
    sample_encoded = tokenizer.apply_chat_template(
        sample_messages, tokenize=True, return_dict=False
    )
    sample_text = tokenizer.apply_chat_template(
        sample_messages, tokenize=False
    )
    print(f"\nSample tokenization ({len(sample_encoded)} tokens):")
    print(sample_text)

    # Configure training — hyperparameters come from TRAINING_CONFIG above.
    training_args = SFTConfig(
        output_dir=str(CHECKPOINT_DIR),
        num_train_epochs=TRAINING_CONFIG["num_train_epochs"],
        per_device_train_batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
        gradient_accumulation_steps=1,
        learning_rate=TRAINING_CONFIG["learning_rate"],
        lr_scheduler_type=TRAINING_CONFIG["lr_scheduler_type"],
        warmup_ratio=TRAINING_CONFIG["warmup_ratio"],
        weight_decay=TRAINING_CONFIG["weight_decay"],
        bf16=True,
        max_length=TRAINING_CONFIG["max_seq_length"],
        eval_strategy="steps",
        eval_steps=TRAINING_CONFIG["eval_steps"],
        save_strategy="steps",
        save_steps=TRAINING_CONFIG["save_steps"],
        logging_steps=TRAINING_CONFIG["logging_steps"],
        report_to="tensorboard",
        logging_dir=str(OUTPUT_DIR / "runs"),
        seed=TRAINING_CONFIG["seed"],
        dataloader_num_workers=2,
        optim="adamw_torch_fused",
        # Reload the checkpoint with the lowest eval_loss before saving final.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=tokenizer,
        callbacks=[MetricsLogger()],
    )

    # Floor division undercounts a partial final batch — logged as "approx".
    steps_per_epoch = len(train_ds) // TRAINING_CONFIG["per_device_train_batch_size"]
    total_steps = steps_per_epoch * TRAINING_CONFIG["num_train_epochs"]
    log_event(
        "training_start",
        steps_per_epoch=steps_per_epoch,
        total_steps_approx=total_steps,
        config=TRAINING_CONFIG,
    )

    # Train
    train_result = trainer.train()

    training_time = time.time() - start_time
    log_event(
        "training_complete",
        wall_time_seconds=round(training_time, 1),
        wall_time_minutes=round(training_time / 60, 1),
        train_loss=train_result.training_loss,
        train_runtime=train_result.metrics.get("train_runtime"),
        train_samples_per_second=train_result.metrics.get(
            "train_samples_per_second"
        ),
    )

    # Save final model (best checkpoint, per load_best_model_at_end)
    FINAL_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(FINAL_MODEL_DIR))
    tokenizer.save_pretrained(str(FINAL_MODEL_DIR))

    # Save full training log history (note: this file holds log_history
    # only, not the trainer's complete state despite the filename)
    with open(FINAL_MODEL_DIR / "trainer_state.json", "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)

    log_event("final_model_saved", path=str(FINAL_MODEL_DIR))

    # Run final eval on test set (swap eval dataset temporarily)
    original_eval_ds = trainer.eval_dataset
    trainer.eval_dataset = test_ds
    test_metrics = trainer.evaluate(metric_key_prefix="test")
    trainer.eval_dataset = original_eval_ds
    log_event("test_eval", **test_metrics)

    # Print summary
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)
    print(f"Model: {MODEL_ID}")
    print(f"Training pairs: {len(train_ds)}")
    print(f"Val pairs: {len(val_ds)}")
    print(f"Test pairs: {len(test_ds)}")
    print(f"Final train loss: {train_result.training_loss:.4f}")
    print(f"Test loss: {test_metrics.get('test_loss', 'N/A')}")
    print(f"Wall time: {training_time / 60:.1f} minutes")
    print(f"Checkpoint: {FINAL_MODEL_DIR}")
    print(f"Training log: {LOG_FILE}")
    print(f"TensorBoard: {OUTPUT_DIR / 'runs'}")
    print("=" * 60)
|
|
|
|
|
|
# Script entry point — run the training pipeline when executed directly.
if __name__ == "__main__":
    main()
|