Add SFT training script and run Qwen3-0.6B-Base fine-tune
Train Qwen3-0.6B-Base (596M params) on 36K folksy proverb pairs using full SFT with HuggingFace TRL. 3 epochs, 11 min on RTX 4090. Results: train_loss=0.954, eval_loss=1.032, test_loss=1.031 Model checkpoint at folksy-model/final/ (not committed — 1.2 GB) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
9298c425bc
commit
02daa7bb97
4 changed files with 919 additions and 0 deletions
296
scripts/train_sft.py
Normal file
296
scripts/train_sft.py
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
#!/usr/bin/env python3
"""SFT fine-tune Qwen3-0.6B-Base on folksy proverb training pairs.

Usage:
    python scripts/train_sft.py

Expects corpus/training_pairs.jsonl in the project root.
Outputs model checkpoints and training logs to folksy-model/.
"""

import json
import os
import random
import sys
import time
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path

# Prevent CUDA fragmentation. NOTE: this must be set before torch is
# imported, which is why the assignment sits between the import groups.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTConfig, SFTTrainer

# === Configuration ===
MODEL_ID = "Qwen/Qwen3-0.6B-Base"
# Project layout: this script lives in scripts/, so the root is two levels up.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_FILE = PROJECT_ROOT / "corpus" / "training_pairs.jsonl"
OUTPUT_DIR = PROJECT_ROOT / "folksy-model"
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"
FINAL_MODEL_DIR = OUTPUT_DIR / "final"
# Structured JSONL run log appended to by log_event().
LOG_FILE = OUTPUT_DIR / "training_log.jsonl"

# ChatML template without Qwen3 thinking tags — clean input/output format.
# Jinja2 template consumed by tokenizer.apply_chat_template.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

# Hyperparameters for the SFT run; also echoed into the run log at
# training_start so each run is self-describing.
TRAINING_CONFIG = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 32,
    "learning_rate": 2e-5,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.05,
    "weight_decay": 0.01,
    # Short sequences: proverb pairs are small, so 128 tokens suffices.
    "max_seq_length": 128,
    "eval_steps": 100,
    # save_steps is a multiple of eval_steps, as required when
    # load_best_model_at_end=True is used in main().
    "save_steps": 500,
    "logging_steps": 10,
    "seed": 42,
}
||||
def log_event(event: str, **kwargs):
    """Record a structured event to the JSONL run log and echo it to stdout."""
    record = {"timestamp": datetime.now().isoformat(), "event": event, **kwargs}
    # Ensure the log directory exists before the first append.
    LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(LOG_FILE, "a") as log_fh:
        log_fh.write(json.dumps(record) + "\n")
    # Console echo: seconds-precision timestamp plus any extra fields.
    extras = {key: val for key, val in kwargs.items() if key != "timestamp"}
    print(f"[{record['timestamp'][:19]}] {event}", extras if extras else "")
|
||||
|
||||
|
||||
def load_and_split_data(data_file: Path, val_ratio=0.05, test_ratio=0.05, seed=42):
    """Load JSONL training pairs and create stratified splits by meta_template.

    Args:
        data_file: Path to a JSONL file; each record must carry "input",
            "output", and "meta_template" keys.
        val_ratio: Fraction of each template group held out for validation.
        test_ratio: Fraction of each template group held out for testing.
        seed: RNG seed for the per-group shuffles (new parameter; default
            matches the previously hard-coded value, so behavior is unchanged).

    Returns:
        (train_ds, val_ds, test_ds) as `datasets.Dataset` objects with
        "messages" and "meta_template" columns.
    """
    # Fix: explicit encoding so loading does not depend on the platform's
    # locale default (the corpus is UTF-8 JSONL).
    with open(data_file, encoding="utf-8") as f:
        records = [json.loads(line) for line in f]

    log_event("data_loaded", total_records=len(records))

    # Convert each record into the chat-messages format expected by TRL.
    for r in records:
        r["messages"] = [
            {"role": "user", "content": r["input"]},
            {"role": "assistant", "content": r["output"]},
        ]

    # Stratified split: shuffle indices within each meta_template group so
    # every template is represented in all three splits.
    random.seed(seed)
    groups = defaultdict(list)
    for i, r in enumerate(records):
        groups[r["meta_template"]].append(i)

    train_idx, val_idx, test_idx = [], [], []
    # sorted() keeps the split deterministic regardless of insertion order.
    for template, indices in sorted(groups.items()):
        random.shuffle(indices)
        n = len(indices)
        # max(1, ...) guarantees at least one example per split per template.
        n_test = max(1, round(n * test_ratio))
        n_val = max(1, round(n * val_ratio))
        test_idx.extend(indices[:n_test])
        val_idx.extend(indices[n_test : n_test + n_val])
        train_idx.extend(indices[n_test + n_val :])

    def make_dataset(indices):
        # Keep only the columns training/eval actually uses.
        return Dataset.from_list(
            [
                {
                    "messages": records[i]["messages"],
                    "meta_template": records[i]["meta_template"],
                }
                for i in indices
            ]
        )

    train_ds = make_dataset(train_idx)
    val_ds = make_dataset(val_idx)
    test_ds = make_dataset(test_idx)

    log_event(
        "data_split", train=len(train_ds), val=len(val_ds), test=len(test_ds)
    )

    # Print per-split template distribution as a quick sanity check.
    for name, ds in [("train", train_ds), ("val", val_ds), ("test", test_ds)]:
        dist = Counter(ds["meta_template"])
        print(f"\n{name} ({len(ds)} examples):")
        for t, c in sorted(dist.items()):
            print(f"  {t}: {c} ({c / len(ds) * 100:.1f}%)")

    return train_ds, val_ds, test_ds
|
||||
|
||||
|
||||
class MetricsLogger(TrainerCallback):
    """Mirror trainer metrics into the structured run log."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Forward every metrics dict the trainer emits, tagged with the step.
        if logs:
            log_event("train_metrics", step=state.global_step, **logs)

    def on_epoch_end(self, args, state, control, **kwargs):
        # Scan the history backwards for the most recent training loss;
        # stays None if no "loss" entry has been logged yet this run.
        latest_loss = None
        for record in reversed(state.log_history):
            if "loss" in record:
                latest_loss = record["loss"]
                break
        log_event("epoch_complete", epoch=int(state.epoch), train_loss=latest_loss)
|
||||
|
||||
|
||||
def main():
    """End-to-end SFT run: verify environment, load data, train, save, evaluate.

    Exits with status 1 if no CUDA GPU is present or the corpus file is missing.
    Writes checkpoints, the final model, splits, and logs under OUTPUT_DIR.
    """
    start_time = time.time()

    # Verify GPU — this script requires CUDA (bf16 full-parameter SFT).
    if not torch.cuda.is_available():
        print("ERROR: No CUDA GPU available")
        sys.exit(1)

    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is bytes; convert to GiB for the log.
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
    log_event("session_start", gpu=gpu_name, vram_gb=round(gpu_mem, 1))

    # Verify data exists before paying the model-download/load cost.
    if not DATA_FILE.exists():
        print(f"ERROR: Training data not found at {DATA_FILE}")
        sys.exit(1)

    # Create output directories.
    for d in [OUTPUT_DIR, CHECKPOINT_DIR]:
        d.mkdir(parents=True, exist_ok=True)

    # Load data and build stratified train/val/test splits.
    train_ds, val_ds, test_ds = load_and_split_data(DATA_FILE)

    # Save val/test splits for later (offline) evaluation.
    val_ds.to_json(OUTPUT_DIR / "val_split.jsonl")
    test_ds.to_json(OUTPUT_DIR / "test_split.jsonl")

    # Load model and tokenizer.
    log_event("model_loading", model=MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # NOTE(review): `dtype=` is the newer transformers kwarg (older releases
    # use `torch_dtype=`) — confirm against the pinned transformers version.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16
    )

    # Override chat template to remove Qwen3 thinking tags.
    tokenizer.chat_template = CHAT_TEMPLATE
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    param_count = sum(p.numel() for p in model.parameters())
    log_event("model_loaded", parameters=f"{param_count / 1e6:.0f}M")

    # Verify tokenization on a sample: encode once for a token count,
    # render once as text for visual inspection.
    sample_messages = train_ds[0]["messages"]
    sample_encoded = tokenizer.apply_chat_template(
        sample_messages, tokenize=True, return_dict=False
    )
    sample_text = tokenizer.apply_chat_template(
        sample_messages, tokenize=False
    )
    print(f"\nSample tokenization ({len(sample_encoded)} tokens):")
    print(sample_text)

    # Configure training. save_steps (500) is a multiple of eval_steps (100),
    # which load_best_model_at_end requires.
    training_args = SFTConfig(
        output_dir=str(CHECKPOINT_DIR),
        num_train_epochs=TRAINING_CONFIG["num_train_epochs"],
        per_device_train_batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
        gradient_accumulation_steps=1,
        learning_rate=TRAINING_CONFIG["learning_rate"],
        lr_scheduler_type=TRAINING_CONFIG["lr_scheduler_type"],
        warmup_ratio=TRAINING_CONFIG["warmup_ratio"],
        weight_decay=TRAINING_CONFIG["weight_decay"],
        bf16=True,
        # NOTE(review): config key is "max_seq_length" but the SFTConfig
        # field is `max_length` (newer TRL naming) — confirm TRL version.
        max_length=TRAINING_CONFIG["max_seq_length"],
        eval_strategy="steps",
        eval_steps=TRAINING_CONFIG["eval_steps"],
        save_strategy="steps",
        save_steps=TRAINING_CONFIG["save_steps"],
        logging_steps=TRAINING_CONFIG["logging_steps"],
        report_to="tensorboard",
        logging_dir=str(OUTPUT_DIR / "runs"),
        seed=TRAINING_CONFIG["seed"],
        dataloader_num_workers=2,
        optim="adamw_torch_fused",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=tokenizer,
        callbacks=[MetricsLogger()],
    )

    # Approximate step count (floor division drops any partial final batch);
    # logged for progress estimation only, not used by the trainer.
    steps_per_epoch = len(train_ds) // TRAINING_CONFIG["per_device_train_batch_size"]
    total_steps = steps_per_epoch * TRAINING_CONFIG["num_train_epochs"]
    log_event(
        "training_start",
        steps_per_epoch=steps_per_epoch,
        total_steps_approx=total_steps,
        config=TRAINING_CONFIG,
    )

    # Train.
    train_result = trainer.train()

    training_time = time.time() - start_time
    log_event(
        "training_complete",
        wall_time_seconds=round(training_time, 1),
        wall_time_minutes=round(training_time / 60, 1),
        train_loss=train_result.training_loss,
        train_runtime=train_result.metrics.get("train_runtime"),
        train_samples_per_second=train_result.metrics.get(
            "train_samples_per_second"
        ),
    )

    # Save final model (best checkpoint, since load_best_model_at_end=True)
    # along with the tokenizer so the directory is self-contained.
    FINAL_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(FINAL_MODEL_DIR))
    tokenizer.save_pretrained(str(FINAL_MODEL_DIR))

    # Save full training log history.
    with open(FINAL_MODEL_DIR / "trainer_state.json", "w") as f:
        json.dump(trainer.state.log_history, f, indent=2)

    log_event("final_model_saved", path=str(FINAL_MODEL_DIR))

    # Run final eval on test set (swap eval dataset temporarily, then
    # restore it so the trainer object stays consistent).
    original_eval_ds = trainer.eval_dataset
    trainer.eval_dataset = test_ds
    test_metrics = trainer.evaluate(metric_key_prefix="test")
    trainer.eval_dataset = original_eval_ds
    log_event("test_eval", **test_metrics)

    # Print summary.
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)
    print(f"Model: {MODEL_ID}")
    print(f"Training pairs: {len(train_ds)}")
    print(f"Val pairs: {len(val_ds)}")
    print(f"Test pairs: {len(test_ds)}")
    print(f"Final train loss: {train_result.training_loss:.4f}")
    print(f"Test loss: {test_metrics.get('test_loss', 'N/A')}")
    print(f"Wall time: {training_time / 60:.1f} minutes")
    print(f"Checkpoint: {FINAL_MODEL_DIR}")
    print(f"Training log: {LOG_FILE}")
    print(f"TensorBoard: {OUTPUT_DIR / 'runs'}")
    print("=" * 60)
|
||||
|
||||
# Script entry point: run the full SFT pipeline when invoked directly.
if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue