# folksy_idioms/scripts/enhance_graph.py
#!/usr/bin/env python3
"""LLM-augmented graph enhancement for the folksy subgraph.
Three phases:
Phase 1: Per-word relationship expansion
Phase 2: Cross-word bridge discovery
Phase 3: Property enrichment for false_equivalence templates
Usage:
python scripts/enhance_graph.py --phase 1 # Run phase 1 only
python scripts/enhance_graph.py --phase 2 # Run phase 2 only
python scripts/enhance_graph.py --phase 3 # Run phase 3 only
python scripts/enhance_graph.py --all # Run all phases
python scripts/enhance_graph.py --phase 1 --dry-run # Print prompts without calling LLM
"""
import argparse
import csv
import os
import random
import re
import sys
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
# Paths
SCRIPT_DIR = Path(__file__).parent
PROJECT_DIR = SCRIPT_DIR.parent
DATA_DIR = PROJECT_DIR / "data"
# OpenAI-compatible chat-completions endpoint used for every LLM call.
# NOTE(review): the "v1d" path segment looks unusual (vs the usual "/v1/") —
# confirm against the serving endpoint.
LLM_ENDPOINT = "http://192.168.1.100:8853/v1d/chat/completions"
LLM_MODEL = "THUDM-GLM4-32B"
# ConceptNet-style relation types accepted from LLM output; anything outside
# this set is silently dropped by the parsers below.
VALID_RELATIONS = {
    "AtLocation", "MadeOf", "PartOf", "UsedFor", "HasA", "HasProperty",
    "Causes", "HasPrerequisite", "CapableOf", "ReceivesAction", "Desires",
    "CausesDesire", "LocatedNear", "CreatedBy", "MotivatedByGoal", "HasSubevent",
}
# Output files: accepted edges, out-of-vocabulary target suggestions, and the
# append-only run log used for checkpoint/resume.
AUGMENTED_CSV = DATA_DIR / "folksy_relations_augmented.csv"
CANDIDATE_CSV = DATA_DIR / "candidate_additions.csv"
LOG_CSV = DATA_DIR / "enhancement_log.csv"
# ---------------------------------------------------------------------------
# Infrastructure
# ---------------------------------------------------------------------------
def llm_chat_completion(messages, max_retries=3):
    """Call the chat-completions endpoint with exponential-backoff retries.

    Args:
        messages: list of {"role", "content"} dicts in OpenAI chat format.
        max_retries: total number of attempts before giving up.

    Returns:
        The assistant message content (str), or None if every attempt failed
        (or max_retries <= 0).
    """
    # Imported lazily so the module can be imported (e.g. for --dry-run)
    # without requests installed.
    import requests

    for attempt in range(max_retries):
        try:
            resp = requests.post(
                LLM_ENDPOINT,
                json={"model": LLM_MODEL, "messages": messages},
                timeout=120,
            )
            resp.raise_for_status()
            data = resp.json()
            return data["choices"][0]["message"]["content"]
        except Exception as e:
            # Broad catch is deliberate: connection errors, HTTP errors,
            # invalid JSON, and unexpected response shapes are all retried
            # the same way.
            wait = 2 ** attempt
            print(f" LLM call failed (attempt {attempt+1}/{max_retries}): {e}", file=sys.stderr)
            if attempt < max_retries - 1:
                print(f" Retrying in {wait}s...", file=sys.stderr)
                time.sleep(wait)
            else:
                print(" Giving up on this word.", file=sys.stderr)
    # BUGFIX(minor): make the all-attempts-failed result explicit instead of
    # falling off the end of the function; also covers max_retries <= 0.
    return None
def load_vocab():
    """Load the folksy vocabulary from data/folksy_vocab.csv.

    Returns:
        dict mapping word -> {"categories": [str], "tangibility": float,
        "edge_count": int}. Missing or empty numeric cells default to 0.
    """
    vocab = {}
    with open(DATA_DIR / "folksy_vocab.csv", newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            word = row["word"]
            cats = [c.strip() for c in row["categories"].split(",") if c.strip()]
            vocab[word] = {
                "categories": cats,
                # BUGFIX: `or 0` guards against empty cells — the previous
                # float(row.get(..., 0)) still raised ValueError on "".
                "tangibility": float(row.get("tangibility_score") or 0),
                "edge_count": int(row.get("conceptnet_edge_count") or 0),
            }
    return vocab
def load_relations():
    """Load existing relations (ConceptNet + any existing augmented)."""
    edges = defaultdict(list)  # (start, relation) -> [(end, weight, surface)]
    existing_triples = set()   # (start, end, relation) for dedup
    sources = (DATA_DIR / "folksy_relations.csv", AUGMENTED_CSV)
    for path in sources:
        if not path.exists():
            continue
        with open(path, newline="", encoding="utf-8") as fh:
            for record in csv.DictReader(fh):
                start = record["start_word"]
                end = record["end_word"]
                relation = record["relation"]
                # Rows with a blank weight are treated as corrupt and skipped.
                if not record["weight"]:
                    continue
                edges[(start, relation)].append(
                    (end, float(record["weight"]), record.get("surface_text", "")))
                existing_triples.add((start, end, relation))
    return edges, existing_triples
def load_checkpoint():
    """Load enhancement log to determine what's already been processed.

    Returns:
        set of (source_word, phase) string pairs; empty when no log exists.
    """
    if not LOG_CSV.exists():
        return set()
    with open(LOG_CSV, newline="", encoding="utf-8") as fh:
        return {(entry["source_word"], entry["phase"])
                for entry in csv.DictReader(fh)}
def append_log(word, phase, edges_generated, edges_accepted, edges_duplicate, edges_oov):
    """Append a row to the enhancement log (writing the header on first use)."""
    header_needed = not LOG_CSV.exists()
    with open(LOG_CSV, "a", newline="", encoding="utf-8") as fh:
        out = csv.writer(fh)
        if header_needed:
            out.writerow([
                "source_word", "phase", "timestamp",
                "edges_generated", "edges_accepted", "edges_duplicate", "edges_oov",
            ])
        out.writerow([
            word, phase, datetime.now().isoformat(),
            edges_generated, edges_accepted, edges_duplicate, edges_oov,
        ])
def append_augmented_edges(edges):
    """Append edge dicts to the augmented relations CSV (header on first use)."""
    columns = ["start_word", "end_word", "relation", "weight", "surface_text", "source"]
    header_needed = not AUGMENTED_CSV.exists()
    with open(AUGMENTED_CSV, "a", newline="", encoding="utf-8") as fh:
        out = csv.writer(fh)
        if header_needed:
            out.writerow(columns)
        for edge in edges:
            out.writerow([edge[col] for col in columns])
def append_candidates(candidates):
    """Append candidate-word dicts to the candidate additions CSV."""
    columns = ["word", "suggested_by", "relation_context", "frequency"]
    header_needed = not CANDIDATE_CSV.exists()
    with open(CANDIDATE_CSV, "a", newline="", encoding="utf-8") as fh:
        out = csv.writer(fh)
        if header_needed:
            out.writerow(columns)
        for cand in candidates:
            out.writerow([cand[col] for col in columns])
# ---------------------------------------------------------------------------
# Parsing
# ---------------------------------------------------------------------------
def parse_llm_relations(response_text, source_word):
    """Parse structured LLM output into edge dicts.

    Expected line format: ``RELATION_TYPE: target word(s) | surface text``.
    Handles bullets, numbering, extra whitespace, and multi-word targets.

    Args:
        response_text: raw LLM completion text (may be None/empty).
        source_word: vocab word the prompt was about; becomes start_word.

    Returns:
        List of edge dicts with weight 0.8 and source "llm_augmented".
    """
    edges = []
    if not response_text:
        return edges
    for line in response_text.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        # Strip leading bullets/numbers: "- ", "1. ", "* ", etc.
        line = re.sub(r"^[\d]+[.)]\s*", "", line)
        line = re.sub(r"^[-*•]\s*", "", line)
        line = line.strip()
        if not line:
            continue
        # Match: RELATION_TYPE: target_word(s) | surface text
        # (this also rejects "RELATION: NONE" lines, which have no "|" part)
        match = re.match(r"^(\w+):\s*(.+?)\s*\|\s*(.+)$", line)
        if not match:
            continue
        relation, target_raw, surface = match.groups()
        relation = relation.strip()
        if relation not in VALID_RELATIONS:
            continue
        # Normalize target: lowercase, replace spaces with underscores for multi-word
        target = re.sub(r"\s+", "_", target_raw.strip().lower())
        # BUGFIX: previously any line *containing* the substring "NONE"
        # anywhere (a target like "nonedible", a surface mentioning "none")
        # was discarded wholesale. Skip only when the target itself is the
        # NONE marker.
        if target == "none":
            continue
        # Skip self-loops
        if target == source_word:
            continue
        edges.append({
            "start_word": source_word,
            "end_word": target,
            "relation": relation,
            "weight": 0.8,
            "surface_text": surface.strip(),
            "source": "llm_augmented",
        })
    return edges
def parse_bridge_response(response_text, word_a, word_b):
    """Parse bridge discovery LLM output.

    Expected line format:
    ``BRIDGE_WORD | relation_to_first: TYPE | relation_to_second: TYPE | explanation``

    Each accepted bridge yields two edges: word_a -> bridge and
    bridge -> word_b, both weight 0.8 with source "llm_bridge".
    """
    if not response_text:
        return []
    edges = []
    for raw in response_text.strip().split("\n"):
        text = raw.strip()
        if not text:
            continue
        # Strip common prefixes: numbering, bullets, a literal "BRIDGE:" tag.
        text = re.sub(r"^[\d]+[.)]\s*", "", text)
        text = re.sub(r"^[-*•]\s*", "", text)
        text = re.sub(r"^BRIDGE:\s*", "", text, flags=re.IGNORECASE)
        text = text.strip()
        if not text:
            continue
        fields = [field.strip() for field in text.split("|")]
        if len(fields) < 3:
            continue
        bridge_word = fields[0].strip().lower().replace(" ", "_")
        # Labeled format first ("relation_to_first: TYPE" / "first: TYPE").
        first = re.search(r"(?:relation_to_first|first):\s*(\w+)", fields[1], re.IGNORECASE)
        second = re.search(r"(?:relation_to_second|second):\s*(\w+)", fields[2], re.IGNORECASE)
        if not first or not second:
            # Fall back to the bare format: just the relation type (possibly
            # after an arbitrary "label:" prefix).
            first = re.match(r"(\w+)", fields[1].split(":")[-1].strip())
            second = re.match(r"(\w+)", fields[2].split(":")[-1].strip())
            if not first or not second:
                continue
        rel_to_first = first.group(1)
        rel_to_second = second.group(1)
        if rel_to_first not in VALID_RELATIONS or rel_to_second not in VALID_RELATIONS:
            continue
        explanation = fields[3].strip() if len(fields) > 3 else ""
        shared = {"weight": 0.8, "surface_text": explanation, "source": "llm_bridge"}
        edges.append({"start_word": word_a, "end_word": bridge_word,
                      "relation": rel_to_first, **shared})
        edges.append({"start_word": bridge_word, "end_word": word_b,
                      "relation": rel_to_second, **shared})
    return edges
def parse_property_response(response_text, word):
    """Parse property enrichment LLM output.

    Expected line format: ``PROPERTY | brief explanation``; the explanation
    is optional and defaults to "{word} is {prop}".

    Returns HasProperty edge dicts with weight 0.8 and source "llm_property".
    """
    if not response_text:
        return []
    edges = []
    for raw in response_text.strip().split("\n"):
        text = raw.strip()
        if not text:
            continue
        # Drop leading numbering and bullets.
        text = re.sub(r"^[\d]+[.)]\s*", "", text)
        text = re.sub(r"^[-*•]\s*", "", text)
        text = text.strip()
        if not text:
            continue
        fields = [field.strip() for field in text.split("|")]
        prop = fields[0].lower().replace(" ", "_")
        explanation = fields[1] if len(fields) > 1 else f"{word} is {prop}"
        # Skip empty properties and trivial self-descriptions.
        if not prop or prop == word:
            continue
        edges.append({
            "start_word": word,
            "end_word": prop,
            "relation": "HasProperty",
            "weight": 0.8,
            "surface_text": explanation,
            "source": "llm_property",
        })
    return edges
# ---------------------------------------------------------------------------
# Phase 1: Per-Word Expansion
# ---------------------------------------------------------------------------
# Phase 1 system prompt: sets the annotator role, acceptance rules, and the
# "RELATION_TYPE: target | phrasing" output format parsed by
# parse_llm_relations().
PHASE1_SYSTEM = """You are a commonsense knowledge annotator. You will be given a concrete noun and its known relationships. Your job is to generate ADDITIONAL commonsense relationships that are missing.
Rules:
- Only generate relationships involving concrete, tangible things (animals, foods, tools, plants, buildings, weather, landscape, household objects)
- Every relationship must be something a typical adult would agree is true
- Do not repeat any relationship already listed as "known"
- Target words should be common English words (top 3000 frequency preferred)
- Output ONLY the structured format shown below, one relationship per line
- If you cannot think of good relationships for a given type, output NONE for that type
- Aim for 3-5 relationships per type where possible
Output format (one per line):
RELATION_TYPE: target_word | short natural phrasing
Example output:
AtLocation: barn | you find a horse in a barn
UsedFor: riding | a horse is used for riding
HasA: mane | a horse has a mane
CapableOf: gallop | a horse can gallop
MadeOf: NONE
PartOf: herd | a horse is part of a herd"""
# Phase 1 user template; filled by run_phase1() with the word, its
# categories, and the formatted list of already-known edges.
PHASE1_USER = """Word: {word}
Categories: {categories}
Known relationships:
{existing_edges}
Generate additional relationships for these types:
- AtLocation (where is it found?)
- UsedFor (what is it used for?)
- HasA (what does it have / contain?)
- PartOf (what is it part of?)
- CapableOf (what can it do?)
- MadeOf (what is it made of?)
- HasPrerequisite (what do you need before you can have/use it?)
- Causes (what does it cause or lead to?)
- HasProperty (what adjectives describe it? — limit to physical/sensory properties)"""
def format_existing_edges(edges_dict, word):
    """Format existing edges for a word grouped by relation type.

    Emits one line per Phase-1 relation type: either up to ten known targets
    with their weights, or "(none in database)".
    """
    relation_order = ("AtLocation", "UsedFor", "HasA", "PartOf", "CapableOf",
                      "MadeOf", "HasPrerequisite", "Causes", "HasProperty")
    rendered = []
    for relation in relation_order:
        known = edges_dict.get((word, relation), [])
        if not known:
            rendered.append(f"{relation}: (none in database)")
            continue
        listing = ", ".join(
            f"{end} (weight {weight:.1f})" for end, weight, *_ in known[:10])
        rendered.append(f"{relation}: {listing}")
    return "\n".join(rendered)
def run_phase1(vocab, edges, existing_triples, checkpoint, dry_run=False):
    """Phase 1: Per-word relationship expansion.

    For each vocab word not already checkpointed, prompts the LLM for extra
    relationships, dedupes against existing triples, appends in-vocab edges
    to the augmented CSV and out-of-vocab targets to the candidate CSV, and
    logs per-word stats so the run can resume.

    Args:
        vocab: word -> metadata dict from load_vocab().
        edges: (start, relation) -> [(end, weight, surface)]; mutated in
            place so later prompts see newly accepted edges as "known".
        existing_triples: set of (start, end, relation); mutated in place.
        checkpoint: set of (word, phase) pairs already processed.
        dry_run: if True, print sample prompts; no LLM calls or file writes.
    """
    words = sorted(vocab.keys())
    total = len(words)
    total_accepted = 0
    total_skipped = 0
    print(f"Phase 1: Processing {total} words...")
    for i, word in enumerate(words):
        # Resume support: skip words already logged for phase 1.
        if (word, "1") in checkpoint:
            total_skipped += 1
            continue
        categories = ", ".join(vocab[word]["categories"])
        existing = format_existing_edges(edges, word)
        user_prompt = PHASE1_USER.format(
            word=word, categories=categories, existing_edges=existing
        )
        messages = [
            {"role": "system", "content": PHASE1_SYSTEM},
            {"role": "user", "content": user_prompt},
        ]
        if dry_run:
            if i < 3:  # Show first 3 prompts
                print(f"\n--- Prompt for '{word}' ---")
                print(f"System: {PHASE1_SYSTEM[:200]}...")
                print(f"User:\n{user_prompt}")
            elif i == 3:
                print(f"\n... ({total - 3} more words) ...")
            continue
        response = llm_chat_completion(messages)
        parsed = parse_llm_relations(response, word) if response else []
        # Classify edges: duplicate, accepted (target in vocab), or candidate
        # (target out of vocab).
        accepted = []
        candidates = []
        duplicates = 0
        for edge in parsed:
            triple = (edge["start_word"], edge["end_word"], edge["relation"])
            if triple in existing_triples:
                duplicates += 1
                continue
            existing_triples.add(triple)
            if edge["end_word"] in vocab:
                accepted.append(edge)
            else:
                # Out-of-vocab target: recorded as a candidate addition
                # instead of an edge.
                candidates.append({
                    "word": edge["end_word"],
                    "suggested_by": word,
                    "relation_context": f"{edge['relation']}: {edge['surface_text']}",
                    "frequency": 1,
                })
        if accepted:
            append_augmented_edges(accepted)
            # Also update in-memory edges for subsequent words
            for e in accepted:
                edges[(e["start_word"], e["relation"])].append(
                    (e["end_word"], e["weight"], e["surface_text"]))
        if candidates:
            append_candidates(candidates)
        total_accepted += len(accepted)
        append_log(word, "1", len(parsed), len(accepted), duplicates, len(candidates))
        if (i + 1) % 50 == 0:
            print(f" [{i+1}/{total}] {total_accepted} edges accepted so far")
        # Brief pause between words to avoid hammering the LLM server.
        time.sleep(0.1)
    if dry_run:
        print(f"\nDry run complete. Would process {total - total_skipped} words.")
    else:
        print(f"\nPhase 1 complete: {total_accepted} new edges accepted.")
# ---------------------------------------------------------------------------
# Phase 2: Cross-Word Bridge Discovery
# ---------------------------------------------------------------------------
# Phase 2 system prompt: asks for bridge words in the pipe-delimited format
# parsed by parse_bridge_response().
# NOTE(review): the relation list here omits MotivatedByGoal and HasSubevent,
# which ARE in VALID_RELATIONS — confirm whether that is intentional.
PHASE2_SYSTEM = """You are a commonsense knowledge annotator. You will be given two concrete nouns. Your job is to identify a BRIDGE word that connects them — something that relates to both.
Rules:
- The bridge word must be a common, concrete noun
- State the relationship type for each connection
- Valid relationship types: AtLocation, UsedFor, HasA, PartOf, CapableOf, MadeOf, HasPrerequisite, Causes, HasProperty, ReceivesAction, Desires, CausesDesire, LocatedNear, CreatedBy
- Output format: BRIDGE_WORD | relation_to_first: TYPE | relation_to_second: TYPE | explanation
Example:
Words: "cow" and "butter"
milk | relation_to_first: CapableOf | relation_to_second: MadeOf | milk connects production to product"""
# Phase 2 user template; filled by run_phase2() with the word pair and their
# category lists.
PHASE2_USER = """Words: "{word_a}" and "{word_b}"
Categories: {word_a} is {categories_a}, {word_b} is {categories_b}
Find 1-3 bridge words that connect them."""
def build_reachability(vocab, edges):
    """Build 2-hop reachability from vocab words to other vocab words.

    A vocab word B is reachable from A when B is a direct edge target of A,
    or a direct target of one of A's *in-vocab* direct targets — the 2-hop
    step only expands through intermediate nodes that are themselves in
    vocab, matching the original scan-based implementation.

    Args:
        vocab: word -> metadata dict (only the keys are used here).
        edges: (start, relation) -> [(end, weight, surface)].

    Returns:
        defaultdict(set): word -> set of reachable vocab words (never
        containing the word itself).
    """
    vocab_set = set(vocab.keys())
    # PERF: precompute start -> {targets} once. The previous version rescanned
    # the entire edge dict for every word AND again for every neighbor,
    # i.e. roughly O(V * E^2); this is O(E + V * deg).
    successors = defaultdict(set)
    for (start, _rel), targets in edges.items():
        for end, _weight, _surface in targets:
            successors[start].add(end)
    reachable = defaultdict(set)
    for word in vocab:
        for neighbor in successors.get(word, ()):
            # Direct (1-hop) neighbors must be in vocab to count (and to be
            # expanded for the second hop).
            if neighbor not in vocab_set or neighbor == word:
                continue
            reachable[word].add(neighbor)
            # 2-hop targets, expanded only through the in-vocab neighbor.
            for second in successors.get(neighbor, ()):
                if second in vocab_set and second != word:
                    reachable[word].add(second)
    return reachable
def run_phase2(vocab, edges, existing_triples, checkpoint, dry_run=False):
    """Phase 2: Cross-word bridge discovery.

    Finds vocab words that can reach fewer than 10 other vocab words within
    2 hops, pairs each with sampled same-category peers it cannot reach, and
    asks the LLM for "bridge" words connecting each pair. Accepted bridge
    edges are appended to the augmented CSV; one log row per word provides
    checkpointing.

    Args:
        vocab: word -> metadata dict from load_vocab().
        edges: (start, relation) -> [(end, weight, surface)]; mutated in place.
        existing_triples: set of (start, end, relation); mutated in place.
        checkpoint: set of (word, phase) pairs already processed.
        dry_run: if True, print sample prompts; no LLM calls or file writes.
    """
    print("Phase 2: Building reachability matrix...")
    reachable = build_reachability(vocab, edges)
    # Find low-connectivity words
    vocab_set = set(vocab.keys())
    low_connectivity = []
    for word in vocab:
        reach_count = len(reachable.get(word, set()))
        if reach_count < 10:
            low_connectivity.append((word, reach_count))
    # Most isolated words first.
    low_connectivity.sort(key=lambda x: x[1])
    print(f" {len(low_connectivity)} words with <10 reachable vocab words")
    # Build category index
    by_category = defaultdict(list)
    for word, info in vocab.items():
        for cat in info["categories"]:
            by_category[cat].append(word)
    total_accepted = 0
    pairs_processed = 0
    total_skipped = 0
    for word, reach_count in low_connectivity:
        # Resume support: skip words already logged for phase 2.
        if (word, "2") in checkpoint:
            total_skipped += 1
            continue
        word_cats = vocab[word]["categories"]
        word_reachable = reachable.get(word, set())
        # Find same-category words that are unreachable
        unreachable = []
        for cat in word_cats:
            for peer in by_category.get(cat, []):
                if peer != word and peer not in word_reachable:
                    unreachable.append(peer)
        if not unreachable:
            # Nothing to bridge; log a zero row so the word is checkpointed.
            append_log(word, "2", 0, 0, 0, 0)
            continue
        # Sample 5-10 unreachable peers
        sample = random.sample(unreachable, min(10, len(unreachable)))
        accepted_for_word = 0
        for peer in sample:
            pair_key = f"{word}:{peer}"
            # NOTE(review): append_log only ever records (word, phase) keys,
            # so this pair-level checkpoint lookup appears to never match —
            # confirm whether pair keys were meant to be logged as well.
            if (pair_key, "2") in checkpoint:
                continue
            categories_a = ", ".join(vocab[word]["categories"])
            categories_b = ", ".join(vocab[peer]["categories"])
            user_prompt = PHASE2_USER.format(
                word_a=word, word_b=peer,
                categories_a=categories_a, categories_b=categories_b,
            )
            messages = [
                {"role": "system", "content": PHASE2_SYSTEM},
                {"role": "user", "content": user_prompt},
            ]
            if dry_run:
                if pairs_processed < 3:
                    print(f"\n--- Bridge prompt: '{word}' <-> '{peer}' ---")
                    print(f"User:\n{user_prompt}")
                elif pairs_processed == 3:
                    print(f"\n... (more pairs) ...")
                pairs_processed += 1
                continue
            response = llm_chat_completion(messages)
            parsed = parse_bridge_response(response, word, peer) if response else []
            accepted = []
            duplicates = 0
            oov = 0
            for edge in parsed:
                triple = (edge["start_word"], edge["end_word"], edge["relation"])
                if triple in existing_triples:
                    duplicates += 1
                    continue
                existing_triples.add(triple)
                # For bridge edges, both endpoints should ideally be in vocab
                if edge["start_word"] in vocab_set and edge["end_word"] in vocab_set:
                    accepted.append(edge)
                elif edge["start_word"] in vocab_set or edge["end_word"] in vocab_set:
                    # At least one end in vocab — still useful
                    accepted.append(edge)
                else:
                    oov += 1
            if accepted:
                append_augmented_edges(accepted)
                for e in accepted:
                    edges[(e["start_word"], e["relation"])].append(
                        (e["end_word"], e["weight"], e["surface_text"]))
            accepted_for_word += len(accepted)
            pairs_processed += 1
            # Brief pause between pairs to avoid hammering the LLM server.
            time.sleep(0.1)
        total_accepted += accepted_for_word
        append_log(word, "2", 0, accepted_for_word, 0, 0)
        if (pairs_processed) % 20 == 0:
            print(f" {pairs_processed} pairs processed, {total_accepted} edges accepted")
    if dry_run:
        print(f"\nDry run complete. Would process {pairs_processed} word pairs.")
    else:
        print(f"\nPhase 2 complete: {total_accepted} bridge edges accepted from {pairs_processed} pairs.")
# ---------------------------------------------------------------------------
# Phase 3: Property Enrichment
# ---------------------------------------------------------------------------
# Phase 3 system prompt: asks for distinguishing physical/sensory properties
# in the "PROPERTY | explanation" format parsed by parse_property_response().
PHASE3_SYSTEM = """You are a commonsense knowledge annotator. Given a concrete noun, list its most distinctive physical or sensory properties — things you could see, touch, hear, smell, or taste. Also list behavioral properties for animals.
Rules:
- Only physical/sensory/behavioral properties, not abstract qualities
- Properties should DISTINGUISH this thing from similar things in its category
- Output one property per line as: PROPERTY | brief explanation
- Aim for 5-8 properties"""
# Phase 3 user template; filled by run_phase3() with the word, its
# categories, and a sample of same-category peers to contrast against.
PHASE3_USER = """Word: {word}
Category: {categories}
Other words in same category: {peers}
What properties distinguish {word} from the others listed?"""
def run_phase3(vocab, edges, existing_triples, checkpoint, dry_run=False):
    """Phase 3: Property enrichment for false_equivalence templates.

    For each vocab word, samples up to 10 same-category peers and asks the
    LLM for properties that distinguish the word from those peers. Parsed
    properties become HasProperty edges appended to the augmented CSV.

    Args:
        vocab: word -> metadata dict from load_vocab().
        edges: (start, relation) -> [(end, weight, surface)]; mutated in place.
        existing_triples: set of (start, end, relation); mutated in place.
        checkpoint: set of (word, phase) pairs already processed.
        dry_run: if True, print sample prompts; no LLM calls or file writes.
    """
    # Index vocab words by category so peers can be sampled per word.
    by_category = defaultdict(list)
    for word, info in vocab.items():
        for cat in info["categories"]:
            by_category[cat].append(word)
    words = sorted(vocab.keys())
    total = len(words)
    total_accepted = 0
    total_skipped = 0
    print(f"Phase 3: Property enrichment for {total} words...")
    for i, word in enumerate(words):
        # Resume support: skip words already logged for phase 3.
        if (word, "3") in checkpoint:
            total_skipped += 1
            continue
        word_cats = vocab[word]["categories"]
        categories = ", ".join(word_cats)
        # Gather same-category peers (sample of 10)
        peers = set()
        for cat in word_cats:
            for peer in by_category.get(cat, []):
                if peer != word:
                    peers.add(peer)
        peer_sample = random.sample(list(peers), min(10, len(peers))) if peers else []
        if not peer_sample:
            # No peers to contrast against; log a zero row to checkpoint it.
            append_log(word, "3", 0, 0, 0, 0)
            continue
        user_prompt = PHASE3_USER.format(
            word=word, categories=categories,
            peers=", ".join(peer_sample),
        )
        messages = [
            {"role": "system", "content": PHASE3_SYSTEM},
            {"role": "user", "content": user_prompt},
        ]
        if dry_run:
            if i < 3:
                print(f"\n--- Property prompt for '{word}' ---")
                print(f"User:\n{user_prompt}")
            elif i == 3:
                print(f"\n... ({total - 3} more words) ...")
            continue
        response = llm_chat_completion(messages)
        parsed = parse_property_response(response, word) if response else []
        # Keep only properties not already recorded as triples.
        accepted = []
        duplicates = 0
        for edge in parsed:
            triple = (edge["start_word"], edge["end_word"], edge["relation"])
            if triple in existing_triples:
                duplicates += 1
                continue
            existing_triples.add(triple)
            accepted.append(edge)
        if accepted:
            append_augmented_edges(accepted)
            for e in accepted:
                edges[(e["start_word"], e["relation"])].append(
                    (e["end_word"], e["weight"], e["surface_text"]))
        total_accepted += len(accepted)
        append_log(word, "3", len(parsed), len(accepted), duplicates, 0)
        if (i + 1) % 50 == 0:
            print(f" [{i+1}/{total}] {total_accepted} properties accepted so far")
        # Brief pause between words to avoid hammering the LLM server.
        time.sleep(0.1)
    if dry_run:
        print(f"\nDry run complete. Would process {total - total_skipped} words.")
    else:
        print(f"\nPhase 3 complete: {total_accepted} new HasProperty edges accepted.")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point: parse arguments, load data, run the selected phases."""
    parser = argparse.ArgumentParser(
        description="LLM-augmented graph enhancement for folksy subgraph."
    )
    selector = parser.add_mutually_exclusive_group(required=True)
    selector.add_argument("--phase", type=int, choices=[1, 2, 3],
                          help="Run a specific phase (1, 2, or 3)")
    selector.add_argument("--all", action="store_true",
                          help="Run all three phases in sequence")
    parser.add_argument("--dry-run", action="store_true",
                        help="Print prompts without calling LLM")
    args = parser.parse_args()

    vocab = load_vocab()
    edges, existing_triples = load_relations()
    checkpoint = load_checkpoint()
    print(f"Loaded {len(vocab)} vocab words, {len(existing_triples)} existing edge triples.")
    print(f"Checkpoint: {len(checkpoint)} (word, phase) pairs already processed.")

    # --all (or any falsy --phase) runs every phase in order.
    phases = [args.phase] if args.phase else [1, 2, 3]
    runners = {1: run_phase1, 2: run_phase2, 3: run_phase3}
    for phase in phases:
        banner = "=" * 60
        print(f"\n{banner}")
        print(f"Running Phase {phase}")
        print(f"{banner}")
        runners[phase](vocab, edges, existing_triples, checkpoint, args.dry_run)
        # Reload checkpoint after each phase for resumability
        checkpoint = load_checkpoint()
    print("\nDone.")


if __name__ == "__main__":
    main()