278 lines
9.5 KiB
Python
278 lines
9.5 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Rebuild training pairs from naturalized corpus.
|
||
|
|
|
||
|
|
Reads corpus_naturalized.jsonl, applies a relaxed quality filter,
deduplicates, and formats training pairs. Replaces the separate
filter_corpus.py + format_training_pairs.py steps.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
python3 scripts/rebuild_training_pairs.py
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import csv
|
||
|
|
import json
|
||
|
|
import random
|
||
|
|
import sys
|
||
|
|
from collections import Counter
|
||
|
|
from difflib import SequenceMatcher
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Filesystem layout, derived from the location of this script file.
SCRIPT_DIR = Path(__file__).parent
PROJECT_DIR = SCRIPT_DIR.parent
CORPUS_DIR = PROJECT_DIR.joinpath("corpus")
DATA_DIR = PROJECT_DIR.joinpath("data")
EXAMPLES_DIR = PROJECT_DIR.joinpath("examples")

# Speakers used for the persona-seeded prompts.
PERSONAS = [
    "farmer",
    "grandmother",
    "old sailor",
    "blacksmith",
    "innkeeper",
    "shepherd",
]

# Prompts that do not mention any word, category, persona, or template.
OPEN_ENDED_PROMPTS = [
    "Tell me some folk wisdom.",
    "What do they say?",
    "Give me a proverb.",
    "Share some old-time wisdom.",
    "What's a good saying?",
]

# Maps each meta-template identifier to its human-readable name, which is the
# identifier with underscores replaced by spaces.
TEMPLATE_NAMES = {
    template_id: template_id.replace("_", " ")
    for template_id in (
        "deconstruction",
        "denial_of_consequences",
        "ironic_deficiency",
        "futile_preparation",
        "hypocritical_complaint",
        "tautological_wisdom",
        "false_equivalence",
    )
}
|
||
|
|
|
||
|
|
|
||
|
|
def is_near_duplicate(text_a, text_b, threshold=0.75):
    """Return True when two texts are near-duplicates of each other.

    Similarity is the case-insensitive ``difflib.SequenceMatcher`` ratio of
    the two texts; they count as near-duplicates when that ratio is strictly
    greater than *threshold*.

    Args:
        text_a: First text to compare.
        text_b: Second text to compare.
        threshold: Similarity ratio that must be exceeded (default 0.75).

    Returns:
        True if the similarity ratio is greater than *threshold*, else False.
    """
    matcher = SequenceMatcher(None, text_a.lower(), text_b.lower())
    # real_quick_ratio() and quick_ratio() are upper bounds on ratio(), so if
    # either cheap check fails the expensive ratio() cannot pass either. The
    # result is identical to testing ratio() alone; dissimilar pairs are just
    # rejected much faster.
    return (
        matcher.real_quick_ratio() > threshold
        and matcher.quick_ratio() > threshold
        and matcher.ratio() > threshold
    )
|
||
|
|
|
||
|
|
|
||
|
|
def deduplicate_within_family(entries):
    """Drop near-duplicate sayings, comparing only within a template family.

    Entries are grouped by their "meta_template" value ("unknown" when the key
    is absent). Inside each group the entries are visited in input order, and
    an entry is discarded when its "final_text" is a near-duplicate (as judged
    by is_near_duplicate) of an entry already retained for that group.

    Returns:
        A ``(kept, removed)`` tuple: the surviving entries, laid out group by
        group in the order each family was first seen, and the number of
        entries that were discarded.
    """
    groups = {}
    for item in entries:
        key = item.get("meta_template", "unknown")
        if key not in groups:
            groups[key] = []
        groups[key].append(item)

    survivors = []
    dropped = 0
    for members in groups.values():
        retained = []
        for candidate in members:
            candidate_text = candidate.get("final_text", "")
            # any() stops at the first match, like the explicit break it replaces.
            clashes = any(
                is_near_duplicate(candidate_text, earlier.get("final_text", ""))
                for earlier in retained
            )
            if clashes:
                dropped += 1
                continue
            retained.append(candidate)
        survivors += retained

    return survivors, dropped
|
||
|
|
|
||
|
|
|
||
|
|
def load_vocab_categories(vocab_path=None):
    """Load the word -> categories mapping from the vocabulary CSV.

    The CSV must have a "word" column and a "categories" column, where the
    latter holds a comma-separated list of category names.

    Args:
        vocab_path: Path to the vocabulary CSV. Defaults to
            ``DATA_DIR / "folksy_vocab.csv"`` when omitted.

    Returns:
        A dict mapping each "word" value to its list of stripped, non-empty
        category names. The dict is empty when the file does not exist.

    Raises:
        KeyError: If the CSV has no "word" column.
    """
    if vocab_path is None:
        vocab_path = DATA_DIR / "folksy_vocab.csv"
    vocab_path = Path(vocab_path)

    word_cats = {}
    if not vocab_path.exists():
        return word_cats

    # "utf-8-sig" reads plain UTF-8 too, but also strips a leading BOM so the
    # first header is "word" rather than "\ufeffword".
    with open(vocab_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            word = row["word"]
            # DictReader fills columns missing from a short row with None.
            raw_categories = row.get("categories") or ""
            word_cats[word] = [c.strip() for c in raw_categories.split(",") if c.strip()]
    return word_cats
|
||
|
|
|
||
|
|
|
||
|
|
def _with_article(phrase):
    """Return *phrase* prefixed with "a" or "an", chosen by its first letter."""
    # A first-letter check is enough for the personas and template names in
    # this file; it does not handle sound-based exceptions such as "honest".
    article = "an" if phrase[:1].lower() in ("a", "e", "i", "o", "u") else "a"
    return f"{article} {phrase}"


def generate_training_pairs(entry, word_cats):
    """Build randomized prompt/response training pairs for one corpus entry.

    Every pair shares the entry's "final_text" as its "output" and differs only
    in the "input" prompt. Up to five prompt kinds are produced:

    * word-seeded ("Tell me something about {word}.") when the entry has
      usable slot words;
    * category-seeded ("Tell me a saying about {cat}.") when any slot word
      has categories in *word_cats*;
    * persona-seeded ("What would a/an {persona} say about {word}?") when the
      entry has usable slot words;
    * template-seeded ("Give me a/an {template} proverb.") with probability
      0.7, if the template name is non-empty;
    * open-ended (one of OPEN_ENDED_PROMPTS) with probability 0.3.

    Args:
        entry: Corpus entry dict; reads "final_text", "slots" and
            "meta_template".
        word_cats: Mapping of vocabulary word -> list of category names.

    Returns:
        A list of dicts with keys "output", "meta_template", "source_words"
        and "input".
    """
    text = entry.get("final_text", "")
    slots = entry.get("slots") or {}
    meta_template = entry.get("meta_template", "")

    # Usable slot words: strings longer than one character that are not
    # article-prefixed phrases. Non-string slot values are ignored, not fatal.
    source_words = [
        value
        for value in slots.values()
        if isinstance(value, str)
        and len(value) > 1
        and not value.startswith(("a ", "an "))
    ]

    slot_categories = set()
    for word in source_words:
        # Vocabulary lookups use lower-case, underscore-joined keys.
        slot_categories.update(word_cats.get(word.lower().replace(" ", "_"), ()))

    pairs = []
    base = {
        "output": text,
        "meta_template": meta_template,
        "source_words": source_words,
    }

    if source_words:
        word = random.choice(source_words)
        pairs.append({**base, "input": f"Tell me something about {word}."})

    if slot_categories:
        # Sort first: set iteration order for strings varies between
        # interpreter runs, which would defeat a fixed random seed.
        cat = random.choice(sorted(slot_categories))
        pairs.append({**base, "input": f"Tell me a saying about {cat}."})

    # Drawn unconditionally so the sequence of random draws does not depend on
    # whether the entry has source words.
    persona = random.choice(PERSONAS)
    if source_words:
        word = random.choice(source_words)
        pairs.append(
            {**base, "input": f"What would {_with_article(persona)} say about {word}?"}
        )

    if random.random() < 0.7:
        template_name = TEMPLATE_NAMES.get(meta_template, meta_template)
        # Skip entries without a template name instead of emitting a
        # malformed "Give me a  proverb." prompt.
        if template_name:
            pairs.append(
                {**base, "input": f"Give me {_with_article(template_name)} proverb."}
            )

    if random.random() < 0.3:
        prompt = random.choice(OPEN_ENDED_PROMPTS)
        pairs.append({**base, "input": prompt})

    return pairs
|
||
|
|
|
||
|
|
|
||
|
|
def _load_usable_entries(input_path):
    """Read the naturalized corpus and select the entries usable for training.

    Returns a ``(usable, total_loaded, status_counts, malformed)`` tuple where
    *malformed* counts non-blank lines that were not a JSON object.
    """
    usable = []
    total_loaded = 0
    malformed = 0
    status_counts = Counter()

    with open(input_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                malformed += 1
                continue
            if not isinstance(entry, dict):
                malformed += 1
                continue

            total_loaded += 1
            nat_status = entry.get("naturalize_status", "")
            status_counts[nat_status] += 1
            if nat_status not in ("naturalized", "unchanged"):
                continue

            # Use naturalized_text if available, else polished_text. "or" also
            # covers a naturalized_text key that is present but empty or None.
            final = entry.get("naturalized_text") or entry.get("polished_text") or ""
            if final:
                entry["final_text"] = final
                usable.append(entry)

    return usable, total_loaded, status_counts, malformed


def _write_jsonl(path, records):
    """Write *records* to *path* as JSON Lines, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def _classify_input(prompt):
    """Return the input-type label used in the stats for a training prompt."""
    # Exact open-ended prompts are checked first: "Give me a proverb." would
    # otherwise match the template-seeded pattern below.
    if prompt in OPEN_ENDED_PROMPTS:
        return "open_ended"
    if prompt.startswith("Tell me something about"):
        return "word_seeded"
    if prompt.startswith("Tell me a saying about"):
        return "category_seeded"
    if prompt.startswith("What would a"):
        return "persona_seeded"
    if prompt.startswith("Give me a") and "proverb" in prompt:
        return "template_seeded"
    return "open_ended"


def main():
    """Rebuild the filtered corpus, training pairs, and stats files.

    Steps: load the naturalized corpus, keep entries whose naturalize_status
    is "naturalized" or "unchanged", remove near-duplicates within each
    template family, write the filtered corpus, generate and write training
    pairs, then write and print summary statistics.

    Exits with status 1 if the input file does not exist.
    """
    parser = argparse.ArgumentParser(description="Rebuild training pairs from naturalized corpus.")
    parser.add_argument("--input", default=str(CORPUS_DIR / "corpus_naturalized.jsonl"))
    parser.add_argument("--output", default=str(CORPUS_DIR / "training_pairs.jsonl"))
    parser.add_argument("--filtered-output", default=str(CORPUS_DIR / "corpus_filtered.jsonl"))
    parser.add_argument("--stats-output", default=str(CORPUS_DIR / "corpus_stats.json"))
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Seed for Python's random module (default: unseeded).",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    output_path = Path(args.output)
    filtered_path = Path(args.filtered_output)
    stats_path = Path(args.stats_output)

    if not input_path.exists():
        print(f"Error: {input_path} not found.", file=sys.stderr)
        sys.exit(1)

    if args.seed is not None:
        random.seed(args.seed)

    # Load naturalized entries
    usable, total_loaded, status_counts, malformed = _load_usable_entries(input_path)
    if malformed:
        print(f"Warning: skipped {malformed} malformed line(s) in {input_path}", file=sys.stderr)

    print(f"Loaded {total_loaded} entries from {input_path}")
    print(f"Status breakdown: {dict(status_counts)}")
    print(f"Usable (naturalized + unchanged): {len(usable)}")

    # Deduplicate
    kept, dup_count = deduplicate_within_family(usable)
    print(f"Near-duplicate removal: {dup_count} removed, {len(kept)} remaining")

    # Write filtered corpus, storing final_text as polished_text for compatibility
    _write_jsonl(
        filtered_path,
        (
            {**{k: v for k, v in entry.items() if k != "final_text"},
             "polished_text": entry["final_text"]}
            for entry in kept
        ),
    )
    print(f"Filtered corpus: {len(kept)} entries -> {filtered_path}")

    # Generate training pairs
    word_cats = load_vocab_categories()
    all_pairs = []
    for entry in kept:
        all_pairs.extend(generate_training_pairs(entry, word_cats))
    _write_jsonl(output_path, all_pairs)

    # Stats
    template_counts = Counter(e.get("meta_template", "unknown") for e in kept)
    input_type_counts = Counter(_classify_input(pair["input"]) for pair in all_pairs)

    # Vocab coverage: the mapping's keys are the vocabulary words, so the CSV
    # does not need to be read a second time.
    vocab_words = set(word_cats)
    used_words = set()
    for entry in kept:
        for value in (entry.get("slots") or {}).values():
            if not isinstance(value, str):
                continue
            word = value.lower().replace(" ", "_")
            if word in vocab_words:
                used_words.add(word)

    lengths = [len(e["final_text"].split()) for e in kept if e.get("final_text")]

    stats = {
        "naturalization_input": total_loaded,
        "naturalization_status": dict(status_counts),
        "usable_before_dedup": len(usable),
        "duplicates_removed": dup_count,
        "final_filtered": len(kept),
        "training_pairs": len(all_pairs),
        "by_template": dict(sorted(template_counts.items())),
        "by_input_type": dict(sorted(input_type_counts.items())),
        "vocab_coverage": f"{len(used_words)}/{len(vocab_words)} ({len(used_words)/len(vocab_words)*100:.1f}%)" if vocab_words else "N/A",
        "avg_length_words": round(sum(lengths) / len(lengths), 1) if lengths else 0,
    }

    stats_path.parent.mkdir(parents=True, exist_ok=True)
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*50}")
    print("FINAL CORPUS STATS")
    print(f"{'='*50}")
    print(f"Unique sayings: {len(kept)}")
    print(f"Training pairs: {len(all_pairs)}")
    print(f"Avg length: {stats['avg_length_words']} words")
    print(f"Vocab coverage: {stats['vocab_coverage']}")
    print("\nBy template:")
    for t, c in sorted(template_counts.items()):
        # template_counts is built from kept, so kept is non-empty here.
        pct = c / len(kept) * 100
        flag = " <-- below 10%" if pct < 10 else ""
        print(f" {t:30s} {c:5d} ({pct:5.1f}%){flag}")
    print("\nBy input type:")
    for t, c in sorted(input_type_counts.items()):
        print(f" {t:20s} {c:5d}")
    print("\nOutputs:")
    print(f" {filtered_path}")
    print(f" {output_path}")
    print(f" {stats_path}")
|
||
|
|
|
||
|
|
|
||
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|