#!/usr/bin/env python3
|
|
"""LLM-augmented graph enhancement for the folksy subgraph.
|
|
|
|
Three phases:
|
|
Phase 1: Per-word relationship expansion
|
|
Phase 2: Cross-word bridge discovery
|
|
Phase 3: Property enrichment for false_equivalence templates
|
|
|
|
Usage:
|
|
python scripts/enhance_graph.py --phase 1 # Run phase 1 only
|
|
python scripts/enhance_graph.py --phase 2 # Run phase 2 only
|
|
python scripts/enhance_graph.py --phase 3 # Run phase 3 only
|
|
python scripts/enhance_graph.py --all # Run all phases
|
|
python scripts/enhance_graph.py --phase 1 --dry-run # Print prompts without calling LLM
|
|
"""
|
|
|
|
import argparse
|
|
import csv
|
|
import os
|
|
import random
|
|
import re
|
|
import sys
|
|
import time
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
# Paths (resolved relative to this script so the tool works from any CWD)
SCRIPT_DIR = Path(__file__).parent
PROJECT_DIR = SCRIPT_DIR.parent
DATA_DIR = PROJECT_DIR / "data"

# OpenAI-compatible chat-completions endpoint and model name.
# NOTE(review): the "v1d" path segment is unusual (typically "v1") — confirm
# this is the intended endpoint.
LLM_ENDPOINT = "http://192.168.1.100:8853/v1d/chat/completions"
LLM_MODEL = "THUDM-GLM4-32B"

# ConceptNet-style relation types accepted from LLM output; lines with any
# other relation are silently dropped by the parsers below.
VALID_RELATIONS = {
    "AtLocation", "MadeOf", "PartOf", "UsedFor", "HasA", "HasProperty",
    "Causes", "HasPrerequisite", "CapableOf", "ReceivesAction", "Desires",
    "CausesDesire", "LocatedNear", "CreatedBy", "MotivatedByGoal", "HasSubevent",
}

# Output files (append-only; each writer adds a header on first creation).
AUGMENTED_CSV = DATA_DIR / "folksy_relations_augmented.csv"  # accepted new edges
CANDIDATE_CSV = DATA_DIR / "candidate_additions.csv"         # out-of-vocab suggestions
LOG_CSV = DATA_DIR / "enhancement_log.csv"                   # per-word checkpoint log
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Infrastructure
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def llm_chat_completion(messages, max_retries=3):
    """POST a chat-completion request with exponential-backoff retries.

    Args:
        messages: Chat messages in OpenAI format (list of role/content dicts).
        max_retries: Number of attempts before giving up.

    Returns:
        The assistant message content string, or None if every attempt failed.
    """
    import requests

    payload = {"model": LLM_MODEL, "messages": messages}
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(LLM_ENDPOINT, json=payload, timeout=120)
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except Exception as exc:
            print(f" LLM call failed (attempt {attempt}/{max_retries}): {exc}", file=sys.stderr)
            if attempt == max_retries:
                print(f" Giving up on this word.", file=sys.stderr)
                return None
            # Backoff doubles each round: 1s, 2s, 4s, ...
            backoff = 2 ** (attempt - 1)
            print(f" Retrying in {backoff}s...", file=sys.stderr)
            time.sleep(backoff)
|
|
|
|
|
|
def load_vocab():
    """Read folksy_vocab.csv into a dict keyed by word.

    Returns:
        {word: {"categories": [str], "tangibility": float, "edge_count": int}}
    """
    vocab = {}
    vocab_path = DATA_DIR / "folksy_vocab.csv"
    with open(vocab_path, newline="", encoding="utf-8") as handle:
        for record in csv.DictReader(handle):
            # Categories are a comma-separated list; drop empty fragments.
            categories = [part.strip()
                          for part in record["categories"].split(",")
                          if part.strip()]
            vocab[record["word"]] = {
                "categories": categories,
                "tangibility": float(record.get("tangibility_score", 0)),
                "edge_count": int(record.get("conceptnet_edge_count", 0)),
            }
    return vocab
|
|
|
|
|
|
def load_relations():
    """Load existing relations (ConceptNet + any existing augmented).

    Returns:
        edges: dict mapping (start_word, relation) -> [(end_word, weight, surface_text)]
        existing_triples: set of (start_word, end_word, relation) used for dedup
    """
    edges = defaultdict(list)
    existing_triples = set()

    for source_path in (DATA_DIR / "folksy_relations.csv", AUGMENTED_CSV):
        if not source_path.exists():
            continue
        with open(source_path, newline="", encoding="utf-8") as handle:
            for record in csv.DictReader(handle):
                start = record["start_word"]
                end = record["end_word"]
                relation = record["relation"]
                # Rows with a blank weight are treated as corrupt and skipped.
                if not record["weight"]:
                    continue
                edges[(start, relation)].append(
                    (end, float(record["weight"]), record.get("surface_text", "")))
                existing_triples.add((start, end, relation))

    return edges, existing_triples
|
|
|
|
|
|
def load_checkpoint():
    """Return the set of (word, phase) pairs already recorded in the log CSV."""
    if not LOG_CSV.exists():
        return set()
    with open(LOG_CSV, newline="", encoding="utf-8") as handle:
        return {(record["source_word"], record["phase"])
                for record in csv.DictReader(handle)}
|
|
|
|
|
|
def append_log(word, phase, edges_generated, edges_accepted, edges_duplicate, edges_oov):
    """Append one per-word bookkeeping row to the enhancement log CSV.

    The log doubles as the resume checkpoint: load_checkpoint() reads
    (source_word, phase) pairs back out of this file.
    """
    header_needed = not LOG_CSV.exists()
    columns = ["source_word", "phase", "timestamp",
               "edges_generated", "edges_accepted", "edges_duplicate", "edges_oov"]
    with open(LOG_CSV, "a", newline="", encoding="utf-8") as handle:
        sink = csv.writer(handle)
        if header_needed:
            sink.writerow(columns)
        sink.writerow([word, phase, datetime.now().isoformat(),
                       edges_generated, edges_accepted, edges_duplicate, edges_oov])
|
|
|
|
|
|
def append_augmented_edges(edges):
    """Append accepted edge dicts to the augmented relations CSV."""
    header_needed = not AUGMENTED_CSV.exists()
    columns = ["start_word", "end_word", "relation", "weight", "surface_text", "source"]
    with open(AUGMENTED_CSV, "a", newline="", encoding="utf-8") as handle:
        sink = csv.writer(handle)
        if header_needed:
            sink.writerow(columns)
        sink.writerows([edge[col] for col in columns] for edge in edges)
|
|
|
|
|
|
def append_candidates(candidates):
    """Append out-of-vocabulary candidate words to the candidate additions CSV."""
    header_needed = not CANDIDATE_CSV.exists()
    columns = ["word", "suggested_by", "relation_context", "frequency"]
    with open(CANDIDATE_CSV, "a", newline="", encoding="utf-8") as handle:
        sink = csv.writer(handle)
        if header_needed:
            sink.writerow(columns)
        sink.writerows([cand[col] for col in columns] for cand in candidates)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def parse_llm_relations(response_text, source_word):
    """Parse structured LLM output into edge dicts.

    Expected line format: ``RELATION_TYPE: target_word(s) | surface text``.
    Handles bullets, numbering, extra whitespace, and multi-word targets
    (spaces become underscores). Lines with an unknown relation type, a
    NONE sentinel target, or a self-loop are dropped.

    Args:
        response_text: Raw LLM completion text (may be None).
        source_word: The word the prompt was about; becomes start_word.

    Returns:
        List of edge dicts with keys start_word, end_word, relation,
        weight (fixed 0.8 for LLM-sourced edges), surface_text, source.
    """
    edges = []
    if not response_text:
        return edges

    for line in response_text.strip().split("\n"):
        line = line.strip()
        if not line:
            continue

        # Strip leading bullets/numbers: "- ", "1. ", "* ", etc.
        line = re.sub(r"^[\d]+[.)]\s*", "", line)
        line = re.sub(r"^[-*•]\s*", "", line)
        line = line.strip()
        if not line:
            continue

        # Match: RELATION_TYPE: target_word(s) | surface text
        match = re.match(r"^(\w+):\s*(.+?)\s*\|\s*(.+)$", line)
        if not match:
            # Non-matching lines (including bare "MadeOf: NONE") are skipped.
            continue

        relation, target_raw, surface = match.groups()
        relation = relation.strip()

        if relation not in VALID_RELATIONS:
            continue

        # Fix: only skip when the TARGET is the NONE sentinel. The previous
        # check ('"NONE" in line.upper()') dropped any valid line whose
        # surface text happened to contain the word "none".
        if target_raw.strip().upper() == "NONE":
            continue

        # Normalize target: lowercase, replace spaces with underscores for multi-word
        target = re.sub(r"\s+", "_", target_raw.strip().lower())

        # Skip self-loops
        if target == source_word:
            continue

        edges.append({
            "start_word": source_word,
            "end_word": target,
            "relation": relation,
            "weight": 0.8,
            "surface_text": surface.strip(),
            "source": "llm_augmented",
        })

    return edges
|
|
|
|
|
|
def parse_bridge_response(response_text, word_a, word_b):
    """Parse bridge discovery LLM output into pairs of edge dicts.

    Expected line format:
        BRIDGE_WORD | relation_to_first: TYPE | relation_to_second: TYPE | explanation

    Each valid line yields two edges — word_a -> bridge and bridge -> word_b —
    both carrying the explanation as surface text and fixed weight 0.8.

    Args:
        response_text: Raw LLM completion text (may be None).
        word_a: First word of the pair (start of the first edge).
        word_b: Second word of the pair (end of the second edge).

    Returns:
        Flat list of edge dicts (two per accepted bridge line).
    """
    edges = []
    if not response_text:
        return edges

    for line in response_text.strip().split("\n"):
        line = line.strip()
        if not line:
            continue

        # Strip common prefixes: numbering, bullets, a literal "BRIDGE:" label.
        line = re.sub(r"^[\d]+[.)]\s*", "", line)
        line = re.sub(r"^[-*•]\s*", "", line)
        line = re.sub(r"^BRIDGE:\s*", "", line, flags=re.IGNORECASE)
        line = line.strip()
        if not line:
            continue

        # BRIDGE_WORD | relation_to_first: TYPE | relation_to_second: TYPE | explanation
        parts = [p.strip() for p in line.split("|")]
        if len(parts) < 3:
            continue

        bridge_word = parts[0].strip().lower().replace(" ", "_")

        # Fix: a bridge equal to either endpoint would produce a self-loop
        # edge (word_a -> word_a or word_b -> word_b). Drop such lines,
        # matching the self-loop guard in parse_llm_relations.
        if bridge_word in (word_a, word_b):
            continue

        # Labeled form: "relation_to_first: TYPE" / "relation_to_second: TYPE"
        rel1_match = re.search(r"(?:relation_to_first|first):\s*(\w+)", parts[1], re.IGNORECASE)
        rel2_match = re.search(r"(?:relation_to_second|second):\s*(\w+)", parts[2], re.IGNORECASE)

        if not rel1_match or not rel2_match:
            # Try simpler format: just the relation type, optionally after a colon
            rel1_match = re.match(r"(\w+)", parts[1].split(":")[-1].strip())
            rel2_match = re.match(r"(\w+)", parts[2].split(":")[-1].strip())

        if not rel1_match or not rel2_match:
            continue

        rel1 = rel1_match.group(1)
        rel2 = rel2_match.group(1)

        if rel1 not in VALID_RELATIONS or rel2 not in VALID_RELATIONS:
            continue

        explanation = parts[3].strip() if len(parts) > 3 else ""

        # Create edges: word_a -> bridge and bridge -> word_b
        edges.append({
            "start_word": word_a,
            "end_word": bridge_word,
            "relation": rel1,
            "weight": 0.8,
            "surface_text": explanation,
            "source": "llm_bridge",
        })
        edges.append({
            "start_word": bridge_word,
            "end_word": word_b,
            "relation": rel2,
            "weight": 0.8,
            "surface_text": explanation,
            "source": "llm_bridge",
        })

    return edges
|
|
|
|
|
|
def parse_property_response(response_text, word):
    """Parse property enrichment LLM output into HasProperty edge dicts.

    Expected line format: ``PROPERTY | brief explanation``. The explanation
    is optional; when absent a default "<word> is <prop>" phrase is used.

    Args:
        response_text: Raw LLM completion text (may be None).
        word: The word being enriched; becomes start_word of every edge.

    Returns:
        List of HasProperty edge dicts with fixed weight 0.8.
    """
    edges = []
    if not response_text:
        return edges

    for line in response_text.strip().split("\n"):
        line = line.strip()
        if not line:
            continue

        # Strip leading bullets/numbers: "- ", "1. ", "* ", etc.
        line = re.sub(r"^[\d]+[.)]\s*", "", line)
        line = re.sub(r"^[-*•]\s*", "", line)
        line = line.strip()
        if not line:
            continue

        # PROPERTY | explanation. (str.split always yields at least one part,
        # so the former `len(parts) < 1` guard was unreachable and is removed.)
        parts = [p.strip() for p in line.split("|")]

        prop = parts[0].strip().lower().replace(" ", "_")
        explanation = parts[1].strip() if len(parts) > 1 else f"{word} is {prop}"

        # Skip empties, self-references, and a literal NONE sentinel.
        if not prop or prop == word or prop == "none":
            continue

        edges.append({
            "start_word": word,
            "end_word": prop,
            "relation": "HasProperty",
            "weight": 0.8,
            "surface_text": explanation,
            "source": "llm_property",
        })

    return edges
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Phase 1: Per-Word Expansion
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# System prompt for Phase 1. Instructs the model to emit one
# "RELATION_TYPE: target | phrasing" line per new relationship, which
# parse_llm_relations() consumes.
PHASE1_SYSTEM = """You are a commonsense knowledge annotator. You will be given a concrete noun and its known relationships. Your job is to generate ADDITIONAL commonsense relationships that are missing.

Rules:
- Only generate relationships involving concrete, tangible things (animals, foods, tools, plants, buildings, weather, landscape, household objects)
- Every relationship must be something a typical adult would agree is true
- Do not repeat any relationship already listed as "known"
- Target words should be common English words (top 3000 frequency preferred)
- Output ONLY the structured format shown below, one relationship per line
- If you cannot think of good relationships for a given type, output NONE for that type
- Aim for 3-5 relationships per type where possible

Output format (one per line):
RELATION_TYPE: target_word | short natural phrasing

Example output:
AtLocation: barn | you find a horse in a barn
UsedFor: riding | a horse is used for riding
HasA: mane | a horse has a mane
CapableOf: gallop | a horse can gallop
MadeOf: NONE
PartOf: herd | a horse is part of a herd"""


# User prompt template for Phase 1; filled per word with its categories and
# the edges already known (rendered by format_existing_edges).
PHASE1_USER = """Word: {word}
Categories: {categories}

Known relationships:
{existing_edges}

Generate additional relationships for these types:
- AtLocation (where is it found?)
- UsedFor (what is it used for?)
- HasA (what does it have / contain?)
- PartOf (what is it part of?)
- CapableOf (what can it do?)
- MadeOf (what is it made of?)
- HasPrerequisite (what do you need before you can have/use it?)
- Causes (what does it cause or lead to?)
- HasProperty (what adjectives describe it? — limit to physical/sensory properties)"""
|
|
|
|
|
|
def format_existing_edges(edges_dict, word):
    """Render a word's known edges for prompt context, one line per relation.

    Args:
        edges_dict: mapping (word, relation) -> [(end_word, weight, surface), ...]
        word: the word whose edges to render.

    Returns:
        Multi-line string, one line per Phase-1 relation type, each listing
        up to 10 known targets or a "(none in database)" placeholder.
    """
    ordered_relations = ("AtLocation", "UsedFor", "HasA", "PartOf", "CapableOf",
                         "MadeOf", "HasPrerequisite", "Causes", "HasProperty")

    rendered = []
    for relation in ordered_relations:
        known = edges_dict.get((word, relation), [])
        if not known:
            rendered.append(f"{relation}: (none in database)")
            continue
        # Cap at 10 targets to keep the prompt bounded.
        summary = ", ".join(f"{end} (weight {weight:.1f})"
                            for end, weight, *_ in known[:10])
        rendered.append(f"{relation}: {summary}")
    return "\n".join(rendered)
|
|
|
|
|
|
def run_phase1(vocab, edges, existing_triples, checkpoint, dry_run=False):
    """Phase 1: Per-word relationship expansion.

    For every vocab word not yet checkpointed for phase "1", prompts the LLM
    for additional relationships, then splits parsed edges into:
      - accepted: target in vocab, appended to AUGMENTED_CSV;
      - candidates: target out of vocab, appended to CANDIDATE_CSV;
      - duplicates: triple already in existing_triples, dropped.
    Mutates `edges` and `existing_triples` in place so later words see the
    newly accepted edges, and logs one checkpoint row per word.

    Args:
        vocab: {word: {"categories": [...], ...}} from load_vocab().
        edges: (word, relation) -> [(end, weight, surface)] mapping; mutated.
        existing_triples: set of (start, end, relation); mutated.
        checkpoint: set of (word, phase) pairs already processed.
        dry_run: when True, print a few sample prompts and make no LLM calls.
    """
    words = sorted(vocab.keys())
    total = len(words)
    total_accepted = 0
    total_skipped = 0

    print(f"Phase 1: Processing {total} words...")

    for i, word in enumerate(words):
        # Resume support: skip words already logged for phase 1.
        if (word, "1") in checkpoint:
            total_skipped += 1
            continue

        categories = ", ".join(vocab[word]["categories"])
        existing = format_existing_edges(edges, word)

        user_prompt = PHASE1_USER.format(
            word=word, categories=categories, existing_edges=existing
        )

        messages = [
            {"role": "system", "content": PHASE1_SYSTEM},
            {"role": "user", "content": user_prompt},
        ]

        if dry_run:
            if i < 3:  # Show first 3 prompts
                print(f"\n--- Prompt for '{word}' ---")
                print(f"System: {PHASE1_SYSTEM[:200]}...")
                print(f"User:\n{user_prompt}")
            elif i == 3:
                print(f"\n... ({total - 3} more words) ...")
            continue

        response = llm_chat_completion(messages)
        parsed = parse_llm_relations(response, word) if response else []

        # Classify edges: duplicate, accepted (in-vocab), or candidate (OOV).
        accepted = []
        candidates = []
        duplicates = 0

        for edge in parsed:
            triple = (edge["start_word"], edge["end_word"], edge["relation"])
            if triple in existing_triples:
                duplicates += 1
                continue

            existing_triples.add(triple)

            if edge["end_word"] in vocab:
                accepted.append(edge)
            else:
                # Out-of-vocab target: record as a candidate word to maybe
                # add to the vocabulary later, not as a graph edge.
                candidates.append({
                    "word": edge["end_word"],
                    "suggested_by": word,
                    "relation_context": f"{edge['relation']}: {edge['surface_text']}",
                    "frequency": 1,
                })

        if accepted:
            append_augmented_edges(accepted)
            # Also update in-memory edges for subsequent words
            for e in accepted:
                edges[(e["start_word"], e["relation"])].append(
                    (e["end_word"], e["weight"], e["surface_text"]))

        if candidates:
            append_candidates(candidates)

        total_accepted += len(accepted)

        # Checkpoint this word so a restart resumes after it.
        append_log(word, "1", len(parsed), len(accepted), duplicates, len(candidates))

        if (i + 1) % 50 == 0:
            print(f" [{i+1}/{total}] {total_accepted} edges accepted so far")

        # Small pause between LLM calls to avoid hammering the endpoint.
        time.sleep(0.1)

    if dry_run:
        print(f"\nDry run complete. Would process {total - total_skipped} words.")
    else:
        print(f"\nPhase 1 complete: {total_accepted} new edges accepted.")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Phase 2: Cross-Word Bridge Discovery
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# System prompt for Phase 2 bridge discovery; parse_bridge_response()
# consumes the pipe-delimited output format described here.
# NOTE(review): the "Valid relationship types" list omits MotivatedByGoal and
# HasSubevent, which ARE in VALID_RELATIONS — confirm whether intentional.
PHASE2_SYSTEM = """You are a commonsense knowledge annotator. You will be given two concrete nouns. Your job is to identify a BRIDGE word that connects them — something that relates to both.

Rules:
- The bridge word must be a common, concrete noun
- State the relationship type for each connection
- Valid relationship types: AtLocation, UsedFor, HasA, PartOf, CapableOf, MadeOf, HasPrerequisite, Causes, HasProperty, ReceivesAction, Desires, CausesDesire, LocatedNear, CreatedBy
- Output format: BRIDGE_WORD | relation_to_first: TYPE | relation_to_second: TYPE | explanation

Example:
Words: "cow" and "butter"
milk | relation_to_first: CapableOf | relation_to_second: MadeOf | milk connects production to product"""


# User prompt template for Phase 2; filled with the word pair and their
# vocabulary categories.
PHASE2_USER = """Words: "{word_a}" and "{word_b}"
Categories: {word_a} is {categories_a}, {word_b} is {categories_b}
Find 1-3 bridge words that connect them."""
|
|
|
|
|
|
def build_reachability(vocab, edges):
    """Build 2-hop reachability from vocab words to other vocab words.

    The intermediate hop must itself be an in-vocab word (matching the
    original traversal, which only expanded in-vocab neighbors), and a word
    never counts as reachable from itself.

    Performance fix: the original rescanned the entire `edges` dict once per
    word and again per neighbor (O(V * E^2) key scans). We now precompute a
    start_word -> {end_word} adjacency index once, reducing the walk to
    O(V * deg^2) set operations.

    Args:
        vocab: {word: info} — only the keys are used.
        edges: (start_word, relation) -> [(end_word, weight, surface)] mapping.

    Returns:
        defaultdict(set): word -> set of vocab words reachable within 2 hops.
    """
    vocab_set = set(vocab.keys())

    # Collapse relations: start_word -> set of end_words, built in one pass.
    neighbors = defaultdict(set)
    for (sw, _rel), targets in edges.items():
        for (ew, _w, _s) in targets:
            neighbors[sw].add(ew)

    reachable = defaultdict(set)
    for word in vocab:
        for ew in neighbors.get(word, ()):
            # Direct (1-hop) neighbors must be in vocab to count (and to be
            # expanded further).
            if ew in vocab_set and ew != word:
                reachable[word].add(ew)
                # 2-hop targets via this in-vocab neighbor.
                for ew2 in neighbors.get(ew, ()):
                    if ew2 in vocab_set and ew2 != word:
                        reachable[word].add(ew2)

    return reachable
|
|
|
|
|
|
def run_phase2(vocab, edges, existing_triples, checkpoint, dry_run=False):
    """Phase 2: Cross-word bridge discovery.

    Targets poorly-connected words: for each word that reaches fewer than 10
    other vocab words within 2 hops, samples up to 10 same-category peers it
    cannot reach and asks the LLM for bridge words connecting each pair.
    Accepted bridge edges (at least one endpoint in vocab) are appended to
    the augmented CSV; progress is checkpointed per word.

    Args:
        vocab: {word: {"categories": [...], ...}} from load_vocab().
        edges: (word, relation) -> [(end, weight, surface)] mapping; mutated.
        existing_triples: set of (start, end, relation); mutated.
        checkpoint: set of (word, phase) pairs already processed.
        dry_run: when True, print a few sample prompts and make no LLM calls.
    """
    print("Phase 2: Building reachability matrix...")
    reachable = build_reachability(vocab, edges)

    # Find low-connectivity words
    vocab_set = set(vocab.keys())
    low_connectivity = []
    for word in vocab:
        reach_count = len(reachable.get(word, set()))
        if reach_count < 10:
            low_connectivity.append((word, reach_count))

    # Process the least-connected words first.
    low_connectivity.sort(key=lambda x: x[1])
    print(f" {len(low_connectivity)} words with <10 reachable vocab words")

    # Build category index: category -> [words], for sampling peers.
    by_category = defaultdict(list)
    for word, info in vocab.items():
        for cat in info["categories"]:
            by_category[cat].append(word)

    total_accepted = 0
    pairs_processed = 0
    total_skipped = 0

    for word, reach_count in low_connectivity:
        # Resume support: skip words already logged for phase 2.
        if (word, "2") in checkpoint:
            total_skipped += 1
            continue

        word_cats = vocab[word]["categories"]
        word_reachable = reachable.get(word, set())

        # Find same-category words that are unreachable.
        # NOTE(review): a peer sharing several categories with `word` appears
        # multiple times here, which skews random.sample — confirm intent.
        unreachable = []
        for cat in word_cats:
            for peer in by_category.get(cat, []):
                if peer != word and peer not in word_reachable:
                    unreachable.append(peer)

        if not unreachable:
            # Nothing to bridge — log a no-op row so we don't retry this word.
            append_log(word, "2", 0, 0, 0, 0)
            continue

        # Sample 5-10 unreachable peers
        sample = random.sample(unreachable, min(10, len(unreachable)))

        accepted_for_word = 0

        for peer in sample:
            pair_key = f"{word}:{peer}"
            # NOTE(review): only (word, "2") rows are ever written to the log,
            # so this pair-level checkpoint test can never match — confirm
            # whether pair-level checkpointing was meant to be logged too.
            if (pair_key, "2") in checkpoint:
                continue

            categories_a = ", ".join(vocab[word]["categories"])
            categories_b = ", ".join(vocab[peer]["categories"])

            user_prompt = PHASE2_USER.format(
                word_a=word, word_b=peer,
                categories_a=categories_a, categories_b=categories_b,
            )

            messages = [
                {"role": "system", "content": PHASE2_SYSTEM},
                {"role": "user", "content": user_prompt},
            ]

            if dry_run:
                if pairs_processed < 3:
                    print(f"\n--- Bridge prompt: '{word}' <-> '{peer}' ---")
                    print(f"User:\n{user_prompt}")
                elif pairs_processed == 3:
                    print(f"\n... (more pairs) ...")
                pairs_processed += 1
                continue

            response = llm_chat_completion(messages)
            parsed = parse_bridge_response(response, word, peer) if response else []

            accepted = []
            duplicates = 0
            oov = 0

            for edge in parsed:
                triple = (edge["start_word"], edge["end_word"], edge["relation"])
                if triple in existing_triples:
                    duplicates += 1
                    continue
                existing_triples.add(triple)

                # For bridge edges, both endpoints should ideally be in vocab
                if edge["start_word"] in vocab_set and edge["end_word"] in vocab_set:
                    accepted.append(edge)
                elif edge["start_word"] in vocab_set or edge["end_word"] in vocab_set:
                    # At least one end in vocab — still useful
                    accepted.append(edge)
                else:
                    oov += 1

            if accepted:
                append_augmented_edges(accepted)
                # Keep the in-memory graph current for later pairs/words.
                for e in accepted:
                    edges[(e["start_word"], e["relation"])].append(
                        (e["end_word"], e["weight"], e["surface_text"]))
                accepted_for_word += len(accepted)

            pairs_processed += 1
            # Small pause between LLM calls to avoid hammering the endpoint.
            time.sleep(0.1)

        total_accepted += accepted_for_word
        # NOTE(review): the generated/duplicate/oov counts computed above are
        # logged as zeros here — confirm whether that loss is intended.
        append_log(word, "2", 0, accepted_for_word, 0, 0)

        if (pairs_processed) % 20 == 0:
            print(f" {pairs_processed} pairs processed, {total_accepted} edges accepted")

    if dry_run:
        print(f"\nDry run complete. Would process {pairs_processed} word pairs.")
    else:
        print(f"\nPhase 2 complete: {total_accepted} bridge edges accepted from {pairs_processed} pairs.")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Phase 3: Property Enrichment
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# System prompt for Phase 3 property enrichment; parse_property_response()
# consumes the "PROPERTY | explanation" output format described here.
PHASE3_SYSTEM = """You are a commonsense knowledge annotator. Given a concrete noun, list its most distinctive physical or sensory properties — things you could see, touch, hear, smell, or taste. Also list behavioral properties for animals.

Rules:
- Only physical/sensory/behavioral properties, not abstract qualities
- Properties should DISTINGUISH this thing from similar things in its category
- Output one property per line as: PROPERTY | brief explanation
- Aim for 5-8 properties"""


# User prompt template for Phase 3; filled with the word, its categories,
# and a sample of same-category peers to contrast against.
PHASE3_USER = """Word: {word}
Category: {categories}
Other words in same category: {peers}

What properties distinguish {word} from the others listed?"""
|
|
|
|
def run_phase3(vocab, edges, existing_triples, checkpoint, dry_run=False):
    """Phase 3: Property enrichment for false_equivalence templates.

    For each word, samples up to 10 same-category peers and asks the LLM
    which physical/sensory properties DISTINGUISH the word from those peers.
    Accepted HasProperty edges go to the augmented CSV; progress is
    checkpointed per word in the enhancement log.

    Args:
        vocab: {word: {"categories": [...], ...}} from load_vocab().
        edges: (word, relation) -> [(end, weight, surface)] mapping; mutated.
        existing_triples: set of (start, end, relation); mutated.
        checkpoint: set of (word, phase) pairs already processed.
        dry_run: when True, print a few sample prompts and make no LLM calls.
    """
    # Index vocab words by category so peers can be sampled per word.
    by_category = defaultdict(list)
    for word, info in vocab.items():
        for cat in info["categories"]:
            by_category[cat].append(word)

    words = sorted(vocab.keys())
    total = len(words)
    total_accepted = 0
    total_skipped = 0

    print(f"Phase 3: Property enrichment for {total} words...")

    for i, word in enumerate(words):
        # Resume support: skip words already logged for phase 3.
        if (word, "3") in checkpoint:
            total_skipped += 1
            continue

        word_cats = vocab[word]["categories"]
        categories = ", ".join(word_cats)

        # Gather same-category peers (sample of 10)
        peers = set()
        for cat in word_cats:
            for peer in by_category.get(cat, []):
                if peer != word:
                    peers.add(peer)
        # NOTE(review): random.sample is unseeded, so prompts differ per run.
        peer_sample = random.sample(list(peers), min(10, len(peers))) if peers else []

        if not peer_sample:
            # No peers to contrast against — log a no-op so we don't retry.
            append_log(word, "3", 0, 0, 0, 0)
            continue

        user_prompt = PHASE3_USER.format(
            word=word, categories=categories,
            peers=", ".join(peer_sample),
        )

        messages = [
            {"role": "system", "content": PHASE3_SYSTEM},
            {"role": "user", "content": user_prompt},
        ]

        if dry_run:
            if i < 3:
                print(f"\n--- Property prompt for '{word}' ---")
                print(f"User:\n{user_prompt}")
            elif i == 3:
                print(f"\n... ({total - 3} more words) ...")
            continue

        response = llm_chat_completion(messages)
        parsed = parse_property_response(response, word) if response else []

        # Deduplicate against everything already known or accepted this run.
        accepted = []
        duplicates = 0

        for edge in parsed:
            triple = (edge["start_word"], edge["end_word"], edge["relation"])
            if triple in existing_triples:
                duplicates += 1
                continue
            existing_triples.add(triple)
            accepted.append(edge)

        if accepted:
            append_augmented_edges(accepted)
            # Keep the in-memory graph current for later words.
            for e in accepted:
                edges[(e["start_word"], e["relation"])].append(
                    (e["end_word"], e["weight"], e["surface_text"]))

        total_accepted += len(accepted)
        # Checkpoint this word so a restart resumes after it.
        append_log(word, "3", len(parsed), len(accepted), duplicates, 0)

        if (i + 1) % 50 == 0:
            print(f" [{i+1}/{total}] {total_accepted} properties accepted so far")

        # Small pause between LLM calls to avoid hammering the endpoint.
        time.sleep(0.1)

    if dry_run:
        print(f"\nDry run complete. Would process {total - total_skipped} words.")
    else:
        print(f"\nPhase 3 complete: {total_accepted} new HasProperty edges accepted.")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def main():
    """CLI entry point: parse arguments, load state, and run requested phases."""
    parser = argparse.ArgumentParser(
        description="LLM-augmented graph enhancement for folksy subgraph."
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--phase", type=int, choices=[1, 2, 3],
                       help="Run a specific phase (1, 2, or 3)")
    group.add_argument("--all", action="store_true",
                       help="Run all three phases in sequence")
    parser.add_argument("--dry-run", action="store_true",
                        help="Print prompts without calling LLM")
    args = parser.parse_args()

    # Shared state: vocabulary, current graph, and the resume checkpoint.
    vocab = load_vocab()
    edges, existing_triples = load_relations()
    checkpoint = load_checkpoint()

    print(f"Loaded {len(vocab)} vocab words, {len(existing_triples)} existing edge triples.")
    print(f"Checkpoint: {len(checkpoint)} (word, phase) pairs already processed.")

    runners = {1: run_phase1, 2: run_phase2, 3: run_phase3}
    selected_phases = [args.phase] if args.phase else [1, 2, 3]

    for phase in selected_phases:
        banner = "=" * 60
        print(f"\n{banner}")
        print(f"Running Phase {phase}")
        print(f"{banner}")

        runners[phase](vocab, edges, existing_triples, checkpoint, args.dry_run)

        # Reload checkpoint after each phase for resumability
        checkpoint = load_checkpoint()

    print("\nDone.")


if __name__ == "__main__":
    main()
|