folksy_idioms/scripts/rebuild_training_pairs.py

278 lines
9.5 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Rebuild training pairs from naturalized corpus.
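Reads corpus_naturalized.jsonl, applies a relaxed quality filter (keeps
only entries whose naturalize_status is "naturalized" or "unchanged"),
removes near-duplicates within each meta-template family, and formats
training pairs. Replaces the separate filter_corpus.py +
format_training_pairs.py steps.

Usage:
    python3 scripts/rebuild_training_pairs.py [--input PATH] [--output PATH]
        [--filtered-output PATH] [--stats-output PATH]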
"""
import argparse
import csv
import json
import random
import sys
from collections import Counter
from difflib import SequenceMatcher
from pathlib import Path
# Directory layout, resolved relative to this script's location.
SCRIPT_DIR = Path(__file__).parent
PROJECT_DIR = SCRIPT_DIR.parent
CORPUS_DIR = PROJECT_DIR / "corpus"
DATA_DIR = PROJECT_DIR / "data"
EXAMPLES_DIR = PROJECT_DIR / "examples"

# Speaker roles interpolated into persona-seeded prompts.
PERSONAS = [
    "farmer",
    "grandmother",
    "old sailor",
    "blacksmith",
    "innkeeper",
    "shepherd",
]

# Prompts that carry no seed word, category, persona, or template.
OPEN_ENDED_PROMPTS = [
    "Tell me some folk wisdom.",
    "What do they say?",
    "Give me a proverb.",
    "Share some old-time wisdom.",
    "What's a good saying?",
]

# meta_template id -> human-readable name (underscores become spaces).
TEMPLATE_NAMES = {
    template_id: template_id.replace("_", " ")
    for template_id in (
        "deconstruction",
        "denial_of_consequences",
        "ironic_deficiency",
        "futile_preparation",
        "hypocritical_complaint",
        "tautological_wisdom",
        "false_equivalence",
    )
}
def is_near_duplicate(text_a, text_b, threshold=0.75):
    """Return True if two texts are more than ``threshold`` similar.

    Similarity is difflib's SequenceMatcher ratio (2*M/T, in [0, 1]) on the
    lowercased texts, so the comparison is case-insensitive.

    real_quick_ratio() and quick_ratio() are documented upper bounds on
    ratio(), so they are checked first. Clearly different pairs are rejected
    without the more expensive full matching, and the result is the same as
    testing ratio() alone. This matters because deduplication compares each
    entry against every entry already kept.

    Args:
        text_a: First text.
        text_b: Second text.
        threshold: Strict lower bound on the similarity ratio for a match.

    Returns:
        True when ratio() > threshold.
    """
    matcher = SequenceMatcher(None, text_a.lower(), text_b.lower())
    return (
        matcher.real_quick_ratio() > threshold
        and matcher.quick_ratio() > threshold
        and matcher.ratio() > threshold
    )
def deduplicate_within_family(entries):
    """Drop near-duplicate sayings, comparing only within a meta-template family.

    Entries are grouped by their "meta_template" value ("unknown" when the key
    is absent). Within a group, an entry is kept only if its "final_text" is
    not a near duplicate (per is_near_duplicate) of an entry already kept from
    that group. The first occurrence wins.

    Args:
        entries: Iterable of corpus entry dicts.

    Returns:
        A ``(kept, removed)`` tuple. ``kept`` holds the surviving entries,
        grouped by family in first-seen family order and in input order within
        each family. ``removed`` is the number of entries dropped.
    """
    groups = {}
    for item in entries:
        groups.setdefault(item.get("meta_template", "unknown"), []).append(item)

    survivors = []
    dropped = 0
    for members in groups.values():
        accepted = []
        for candidate in members:
            candidate_text = candidate.get("final_text", "")
            # any() stops at the first match, like the explicit break it replaces.
            if any(
                is_near_duplicate(candidate_text, prior.get("final_text", ""))
                for prior in accepted
            ):
                dropped += 1
                continue
            accepted.append(candidate)
        survivors += accepted
    return survivors, dropped
def load_vocab_categories(vocab_path=None):
    """Map each vocabulary word to its list of categories.

    Reads a CSV with "word" and "categories" columns, where "categories" is a
    comma-separated list. Items are stripped and blank items are dropped. A
    word repeated in the CSV keeps the categories from its last row.

    Args:
        vocab_path: CSV file to read (str or Path). Defaults to
            DATA_DIR / "folksy_vocab.csv".

    Returns:
        Dict mapping each word, exactly as written in the CSV, to a list of
        category strings. Empty if the file does not exist.
    """
    if vocab_path is None:
        vocab_path = DATA_DIR / "folksy_vocab.csv"
    vocab_path = Path(vocab_path)

    word_cats = {}
    if not vocab_path.exists():
        return word_cats
    with open(vocab_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            # csv.DictReader fills columns missing from a short row with None;
            # treat that as "no categories" instead of crashing on None.split.
            raw_categories = row.get("categories") or ""
            word_cats[row["word"]] = [
                c.strip() for c in raw_categories.split(",") if c.strip()
            ]
    return word_cats
def _with_article(phrase):
    """Prefix ``phrase`` with "a" or "an", chosen by its first letter.

    Uses a simple vowel-letter rule, which is enough for the fixed personas
    and template names in this script. It does not handle exceptions such as
    "an hour" or "a unicorn".
    """
    article = "an" if phrase[:1].lower() in ("a", "e", "i", "o", "u") else "a"
    return f"{article} {phrase}"


def generate_training_pairs(entry, word_cats):
    """Build prompt/response training pairs for one corpus entry.

    Every pair shares the entry's "final_text" as "output", plus its
    "meta_template" and the slot words it was built from. Pairs are added as
    follows:
      * word-seeded: when the entry has usable slot words;
      * category-seeded: when any slot word has vocabulary categories;
      * persona-seeded: when the entry has usable slot words;
      * template-seeded: with probability 0.7, if "meta_template" is set;
      * open-ended: with probability 0.3.

    Args:
        entry: Corpus entry dict with "final_text", and optionally "slots"
            and "meta_template".
        word_cats: Mapping from vocabulary word to its category list, as
            returned by load_vocab_categories.

    Returns:
        List of pair dicts with keys "input", "output", "meta_template" and
        "source_words".
    """
    text = entry.get("final_text", "")
    slots = entry.get("slots", {})
    meta_template = entry.get("meta_template", "")

    # Use slot values that look like bare words. Skip phrases that already
    # carry an indefinite article ("a ...", "an ...") and single characters.
    source_words = [
        v for v in slots.values()
        if v and not v.startswith(("a ", "an ")) and len(v) > 1
    ]

    slot_categories = set()
    for word in source_words:
        # Vocabulary words use underscores in place of spaces.
        key = word.lower().replace(" ", "_")
        if key in word_cats:
            slot_categories.update(word_cats[key])

    base = {
        "output": text,
        "meta_template": meta_template,
        "source_words": source_words,
    }
    pairs = []

    if source_words:
        word = random.choice(source_words)
        pairs.append({**base, "input": f"Tell me something about {word}."})

    if slot_categories:
        # sorted() makes the pick independent of set iteration order, which
        # changes between runs under string hash randomization.
        cat = random.choice(sorted(slot_categories))
        pairs.append({**base, "input": f"Tell me a saying about {cat}."})

    persona = random.choice(PERSONAS)
    if source_words:
        word = random.choice(source_words)
        pairs.append(
            {**base, "input": f"What would {_with_article(persona)} say about {word}?"}
        )

    # Draw from the RNG before checking meta_template, so the random stream is
    # consumed the same way whether or not the template is set. Without a
    # template the prompt would read "Give me a  proverb.".
    if random.random() < 0.7 and meta_template:
        template_name = TEMPLATE_NAMES.get(meta_template, meta_template)
        pairs.append(
            {**base, "input": f"Give me {_with_article(template_name)} proverb."}
        )

    if random.random() < 0.3:
        prompt = random.choice(OPEN_ENDED_PROMPTS)
        pairs.append({**base, "input": prompt})

    return pairs
def _load_usable_entries(input_path):
    """Read the naturalized JSONL corpus and keep entries usable for training.

    An entry is usable when its "naturalize_status" is "naturalized" or
    "unchanged" and it has non-empty text. The text comes from
    "naturalized_text", falling back to "polished_text" when the former is
    missing, null or empty. It is stored on the entry as "final_text".

    Returns:
        A ``(usable, total_loaded, status_counts, bad_lines)`` tuple:
          * usable: the usable entries;
          * total_loaded: number of parsed entries;
          * status_counts: Counter of "naturalize_status" values;
          * bad_lines: number of non-blank lines that were not a JSON object.
    """
    usable = []
    total_loaded = 0
    bad_lines = 0
    status_counts = Counter()
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                bad_lines += 1
                continue
            if not isinstance(entry, dict):
                bad_lines += 1
                continue
            total_loaded += 1
            nat_status = entry.get("naturalize_status", "")
            status_counts[nat_status] += 1
            if nat_status not in ("naturalized", "unchanged"):
                continue
            # Use `or`, not a .get() default, so a present-but-empty/null
            # naturalized_text still falls back to polished_text.
            final = entry.get("naturalized_text") or entry.get("polished_text", "")
            if final:
                entry["final_text"] = final
                usable.append(entry)
    return usable, total_loaded, status_counts, bad_lines


def _to_filtered_record(entry):
    """Copy ``entry`` for the filtered corpus, storing final_text as polished_text.

    The internal "final_text" key is dropped and its value is written under
    "polished_text", for compatibility with readers of that field.
    """
    record = {k: v for k, v in entry.items() if k != "final_text"}
    record["polished_text"] = entry["final_text"]
    return record


def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as UTF-8 JSON Lines, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def _classify_input(prompt):
    """Return the stats bucket name for a training-pair prompt."""
    # Check the fixed open-ended prompts first. "Give me a proverb." would
    # otherwise match the template-seeded test below.
    if prompt in OPEN_ENDED_PROMPTS:
        return "open_ended"
    if prompt.startswith("Tell me something about"):
        return "word_seeded"
    if prompt.startswith("Tell me a saying about"):
        return "category_seeded"
    # These prefixes also match the "an" forms ("What would an ...").
    if prompt.startswith("What would a"):
        return "persona_seeded"
    if prompt.startswith("Give me a") and "proverb" in prompt:
        return "template_seeded"
    return "open_ended"


def _used_vocab_words(entries, vocab_words):
    """Return the vocabulary words that appear as slot values in ``entries``.

    Slot values are lowercased and spaces become underscores before lookup.
    Empty and non-string slot values are ignored.
    """
    used = set()
    for entry in entries:
        for value in entry.get("slots", {}).values():
            if not isinstance(value, str) or not value:
                continue
            word = value.lower().replace(" ", "_")
            if word in vocab_words:
                used.add(word)
    return used


def _print_summary(stats, template_counts, input_type_counts, output_paths):
    """Print the human-readable end-of-run report to stdout."""
    kept_count = stats["final_filtered"]
    print(f"\n{'='*50}")
    print("FINAL CORPUS STATS")
    print(f"{'='*50}")
    print(f"Unique sayings: {kept_count}")
    print(f"Training pairs: {stats['training_pairs']}")
    print(f"Avg length: {stats['avg_length_words']} words")
    print(f"Vocab coverage: {stats['vocab_coverage']}")
    print("\nBy template:")
    for t, c in sorted(template_counts.items()):
        pct = c / kept_count * 100
        flag = " <-- below 10%" if pct < 10 else ""
        print(f" {t:30s} {c:5d} ({pct:5.1f}%){flag}")
    print("\nBy input type:")
    for t, c in sorted(input_type_counts.items()):
        print(f" {t:20s} {c:5d}")
    print("\nOutputs:")
    for path in output_paths:
        print(f" {path}")


def main():
    """Command-line entry point: filter, deduplicate and format training pairs.

    Reads the naturalized corpus (--input) and keeps usable entries. It then
    removes near-duplicates within each meta-template family and writes:
      * the filtered corpus (--filtered-output), with the chosen text stored
        as "polished_text";
      * the training pairs (--output);
      * a JSON stats summary (--stats-output).
    Output directories are created as needed. Exits with status 1 if the
    input file does not exist.
    """
    parser = argparse.ArgumentParser(description="Rebuild training pairs from naturalized corpus.")
    parser.add_argument("--input", default=str(CORPUS_DIR / "corpus_naturalized.jsonl"))
    parser.add_argument("--output", default=str(CORPUS_DIR / "training_pairs.jsonl"))
    parser.add_argument("--filtered-output", default=str(CORPUS_DIR / "corpus_filtered.jsonl"))
    parser.add_argument("--stats-output", default=str(CORPUS_DIR / "corpus_stats.json"))
    args = parser.parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    filtered_path = Path(args.filtered_output)
    stats_path = Path(args.stats_output)

    if not input_path.exists():
        print(f"Error: {input_path} not found.", file=sys.stderr)
        sys.exit(1)

    # Load naturalized entries: use naturalized_text if available, else polished_text.
    usable, total_loaded, status_counts, bad_lines = _load_usable_entries(input_path)
    print(f"Loaded {total_loaded} entries from {input_path}")
    if bad_lines:
        print(f"Warning: skipped {bad_lines} unparseable line(s) in {input_path}", file=sys.stderr)
    print(f"Status breakdown: {dict(status_counts)}")
    print(f"Usable (naturalized + unchanged): {len(usable)}")

    # Deduplicate.
    kept, dup_count = deduplicate_within_family(usable)
    print(f"Near-duplicate removal: {dup_count} removed, {len(kept)} remaining")

    # Write the filtered corpus.
    _write_jsonl(filtered_path, (_to_filtered_record(e) for e in kept))
    print(f"Filtered corpus: {len(kept)} entries -> {filtered_path}")

    # Generate training pairs. Entries are processed in order, so the RNG is
    # consumed in the same sequence as before.
    word_cats = load_vocab_categories()
    all_pairs = [
        pair for entry in kept for pair in generate_training_pairs(entry, word_cats)
    ]
    _write_jsonl(output_path, all_pairs)

    # Stats.
    template_counts = Counter(e.get("meta_template", "unknown") for e in kept)
    input_type_counts = Counter(_classify_input(p["input"]) for p in all_pairs)
    # load_vocab_categories keys its result by every CSV "word", so its keys
    # are the vocabulary. There is no need to read the CSV a second time.
    vocab_words = set(word_cats)
    used_words = _used_vocab_words(kept, vocab_words)
    lengths = [len(e["final_text"].split()) for e in kept if e.get("final_text")]
    stats = {
        "naturalization_input": total_loaded,
        "naturalization_status": dict(status_counts),
        "usable_before_dedup": len(usable),
        "duplicates_removed": dup_count,
        "final_filtered": len(kept),
        "training_pairs": len(all_pairs),
        "by_template": dict(sorted(template_counts.items())),
        "by_input_type": dict(sorted(input_type_counts.items())),
        "vocab_coverage": f"{len(used_words)}/{len(vocab_words)} ({len(used_words)/len(vocab_words)*100:.1f}%)" if vocab_words else "N/A",
        "avg_length_words": round(sum(lengths) / len(lengths), 1) if lengths else 0,
    }
    stats_path.parent.mkdir(parents=True, exist_ok=True)
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)

    _print_summary(
        stats,
        template_counts,
        input_type_counts,
        (filtered_path, output_path, stats_path),
    )
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()