#!/usr/bin/env python3
"""SFT fine-tune Qwen3-0.6B-Base on folksy proverb training pairs.

Usage: python scripts/train_sft.py

Expects corpus/training_pairs.jsonl in the project root. Outputs model
checkpoints and training logs to folksy-model/.
"""
import json
import os
import random
import sys
import time
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path

# Prevent CUDA fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
from trl import SFTConfig, SFTTrainer

# === Configuration ===
MODEL_ID = "Qwen/Qwen3-0.6B-Base"
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_FILE = PROJECT_ROOT / "corpus" / "training_pairs.jsonl"
OUTPUT_DIR = PROJECT_ROOT / "folksy-model"
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"
FINAL_MODEL_DIR = OUTPUT_DIR / "final"
LOG_FILE = OUTPUT_DIR / "training_log.jsonl"

# ChatML template without Qwen3 thinking tags — clean input/output format
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

TRAINING_CONFIG = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 32,
    "learning_rate": 2e-5,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.05,
    "weight_decay": 0.01,
    "max_seq_length": 128,
    "eval_steps": 100,
    "save_steps": 500,
    "logging_steps": 10,
    "seed": 42,
}


def log_event(event: str, **kwargs):
    """Append a structured log event to LOG_FILE and echo it to stdout.

    Each entry is one JSON line carrying an ISO timestamp, the event name,
    and any extra keyword fields.
    """
    entry = {"timestamp": datetime.now().isoformat(), "event": event, **kwargs}
    LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the log is portable regardless of locale defaults.
    with open(LOG_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")
    detail = {k: v for k, v in kwargs.items() if k != "timestamp"}
    # Only pass the detail dict when non-empty — the original printed a
    # stray trailing "" argument for detail-less events.
    if detail:
        print(f"[{entry['timestamp'][:19]}] {event}", detail)
    else:
        print(f"[{entry['timestamp'][:19]}] {event}")


def load_and_split_data(data_file: Path, val_ratio=0.05, test_ratio=0.05, seed=42):
    """Load JSONL training pairs and create stratified splits by meta_template.

    Each record must provide "input", "output", and "meta_template" keys; a
    chat-style "messages" list is attached in place. Splits are stratified so
    every template contributes at least one example to val and test.

    Args:
        data_file: Path to the JSONL file of training pairs.
        val_ratio: Fraction of each template group held out for validation.
        test_ratio: Fraction of each template group held out for test.
        seed: Shuffle seed (default 42, matching the original behavior).

    Returns:
        (train_ds, val_ds, test_ds) as ``datasets.Dataset`` objects.
    """
    with open(data_file, encoding="utf-8") as f:
        records = [json.loads(line) for line in f]
    log_event("data_loaded", total_records=len(records))

    # Convert to chat messages format
    for r in records:
        r["messages"] = [
            {"role": "user", "content": r["input"]},
            {"role": "assistant", "content": r["output"]},
        ]

    # Stratified split by meta_template. A local Random instance gives the
    # exact same shuffles as random.seed(seed) would, without mutating the
    # process-wide RNG state as a side effect of loading data.
    rng = random.Random(seed)
    groups = defaultdict(list)
    for i, r in enumerate(records):
        groups[r["meta_template"]].append(i)

    train_idx, val_idx, test_idx = [], [], []
    # sorted() makes group iteration order deterministic across runs.
    for template, indices in sorted(groups.items()):
        rng.shuffle(indices)
        n = len(indices)
        # max(1, ...) guarantees every template appears in val and test;
        # NOTE(review): groups with n <= 2 therefore contribute nothing to train.
        n_test = max(1, round(n * test_ratio))
        n_val = max(1, round(n * val_ratio))
        test_idx.extend(indices[:n_test])
        val_idx.extend(indices[n_test : n_test + n_val])
        train_idx.extend(indices[n_test + n_val :])

    def make_dataset(indices):
        # Keep only the fields training/eval need.
        return Dataset.from_list(
            [
                {
                    "messages": records[i]["messages"],
                    "meta_template": records[i]["meta_template"],
                }
                for i in indices
            ]
        )

    train_ds = make_dataset(train_idx)
    val_ds = make_dataset(val_idx)
    test_ds = make_dataset(test_idx)
    log_event(
        "data_split", train=len(train_ds), val=len(val_ds), test=len(test_ds)
    )

    # Print per-template distribution for each split.
    for name, ds in [("train", train_ds), ("val", val_ds), ("test", test_ds)]:
        dist = Counter(ds["meta_template"])
        print(f"\n{name} ({len(ds)} examples):")
        for t, c in sorted(dist.items()):
            print(f"  {t}: {c} ({c / len(ds) * 100:.1f}%)")

    return train_ds, val_ds, test_ds


class MetricsLogger(TrainerCallback):
    """Log training metrics to the run log file."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Mirror every Trainer log dict into the structured JSONL log.
        if logs:
            log_event("train_metrics", step=state.global_step, **logs)

    def on_epoch_end(self, args, state, control, **kwargs):
        epoch_num = int(state.epoch)
        # Walk log_history backwards to find the most recent training loss.
        train_loss = None
        for entry in reversed(state.log_history):
            if "loss" in entry:
                train_loss = entry["loss"]
                break
        log_event("epoch_complete", epoch=epoch_num, train_loss=train_loss)
def main():
    """Run the full SFT pipeline: verify environment, load and split data,
    fine-tune, save the final model, and evaluate on the held-out test set.
    """
    start_time = time.time()

    # Verify GPU — bf16 training below requires CUDA.
    if not torch.cuda.is_available():
        print("ERROR: No CUDA GPU available")
        sys.exit(1)
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
    log_event("session_start", gpu=gpu_name, vram_gb=round(gpu_mem, 1))

    # Verify data
    if not DATA_FILE.exists():
        print(f"ERROR: Training data not found at {DATA_FILE}")
        sys.exit(1)

    # Create output directories
    for d in [OUTPUT_DIR, CHECKPOINT_DIR]:
        d.mkdir(parents=True, exist_ok=True)

    # Load data
    train_ds, val_ds, test_ds = load_and_split_data(DATA_FILE)

    # Save val/test splits for later evaluation
    val_ds.to_json(OUTPUT_DIR / "val_split.jsonl")
    test_ds.to_json(OUTPUT_DIR / "test_split.jsonl")

    # Load model and tokenizer
    log_event("model_loading", model=MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16
    )

    # Override chat template to remove Qwen3 thinking tags
    tokenizer.chat_template = CHAT_TEMPLATE
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    param_count = sum(p.numel() for p in model.parameters())
    log_event("model_loaded", parameters=f"{param_count / 1e6:.0f}M")

    # Verify tokenization on a sample before committing to a full run.
    sample_messages = train_ds[0]["messages"]
    sample_encoded = tokenizer.apply_chat_template(
        sample_messages, tokenize=True, return_dict=False
    )
    sample_text = tokenizer.apply_chat_template(
        sample_messages, tokenize=False
    )
    print(f"\nSample tokenization ({len(sample_encoded)} tokens):")
    print(sample_text)

    # Configure training
    training_args = SFTConfig(
        output_dir=str(CHECKPOINT_DIR),
        num_train_epochs=TRAINING_CONFIG["num_train_epochs"],
        per_device_train_batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
        gradient_accumulation_steps=1,
        learning_rate=TRAINING_CONFIG["learning_rate"],
        lr_scheduler_type=TRAINING_CONFIG["lr_scheduler_type"],
        warmup_ratio=TRAINING_CONFIG["warmup_ratio"],
        weight_decay=TRAINING_CONFIG["weight_decay"],
        bf16=True,
        max_length=TRAINING_CONFIG["max_seq_length"],
        eval_strategy="steps",
        eval_steps=TRAINING_CONFIG["eval_steps"],
        save_strategy="steps",
        # save_steps must stay a multiple of eval_steps for
        # load_best_model_at_end to work (500 = 5 * 100).
        save_steps=TRAINING_CONFIG["save_steps"],
        logging_steps=TRAINING_CONFIG["logging_steps"],
        report_to="tensorboard",
        logging_dir=str(OUTPUT_DIR / "runs"),
        seed=TRAINING_CONFIG["seed"],
        dataloader_num_workers=2,
        optim="adamw_torch_fused",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=tokenizer,
        callbacks=[MetricsLogger()],
    )

    batch_size = TRAINING_CONFIG["per_device_train_batch_size"]
    # Ceiling division: the dataloader keeps the final partial batch
    # (drop_last=False), so floor division under-counts by one step per epoch.
    steps_per_epoch = -(-len(train_ds) // batch_size)
    total_steps = steps_per_epoch * TRAINING_CONFIG["num_train_epochs"]
    log_event(
        "training_start",
        steps_per_epoch=steps_per_epoch,
        total_steps_approx=total_steps,
        config=TRAINING_CONFIG,
    )

    # Train
    train_result = trainer.train()

    training_time = time.time() - start_time
    log_event(
        "training_complete",
        wall_time_seconds=round(training_time, 1),
        wall_time_minutes=round(training_time / 60, 1),
        train_loss=train_result.training_loss,
        train_runtime=train_result.metrics.get("train_runtime"),
        train_samples_per_second=train_result.metrics.get(
            "train_samples_per_second"
        ),
    )

    # Save final model (best checkpoint, since load_best_model_at_end=True).
    FINAL_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(FINAL_MODEL_DIR))
    tokenizer.save_pretrained(str(FINAL_MODEL_DIR))

    # Save full training log history
    with open(FINAL_MODEL_DIR / "trainer_state.json", "w", encoding="utf-8") as f:
        json.dump(trainer.state.log_history, f, indent=2)
    log_event("final_model_saved", path=str(FINAL_MODEL_DIR))

    # Run final eval on test set (swap eval dataset temporarily)
    original_eval_ds = trainer.eval_dataset
    trainer.eval_dataset = test_ds
    test_metrics = trainer.evaluate(metric_key_prefix="test")
    trainer.eval_dataset = original_eval_ds
    log_event("test_eval", **test_metrics)

    # Print summary
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE")
    print("=" * 60)
    print(f"Model: {MODEL_ID}")
    print(f"Training pairs: {len(train_ds)}")
    print(f"Val pairs: {len(val_ds)}")
    print(f"Test pairs: {len(test_ds)}")
    print(f"Final train loss: {train_result.training_loss:.4f}")
    print(f"Test loss: {test_metrics.get('test_loss', 'N/A')}")
    print(f"Wall time: {training_time / 60:.1f} minutes")
    print(f"Checkpoint: {FINAL_MODEL_DIR}")
    print(f"Training log: {LOG_FILE}")
    print(f"TensorBoard: {OUTPUT_DIR / 'runs'}")
    print("=" * 60)


if __name__ == "__main__":
    main()