folksy_idioms/scripts/polish_corpus.py

400 lines
16 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""LLM polish pipeline for raw folksy sayings.
Reads corpus_raw.jsonl and sends each entry to GLM4-32B for polishing.
The output file doubles as the checkpoint: it is opened in append mode, and
entries already present in it are skipped on resume.
Robust error handling:
- Context size errors: retries with a truncated chain, then with the raw saying alone
- JSON parse errors: retries, then marks as error
- Transient HTTP errors: exponential backoff retry
- Keyboard interrupt: flushes and exits cleanly
- Safe resume: skips entries already in the output file (including ones recorded as errors)
Usage:
python scripts/polish_corpus.py
python scripts/polish_corpus.py --input corpus/corpus_raw.jsonl --output corpus/corpus_polished.jsonl
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
# Paths are resolved relative to this file: SCRIPT_DIR is the directory that
# contains it, PROJECT_DIR is that directory's parent, and the default input
# and output files used by main() live under PROJECT_DIR/corpus.
SCRIPT_DIR = Path(__file__).parent
PROJECT_DIR = SCRIPT_DIR.parent
CORPUS_DIR = PROJECT_DIR / "corpus"

# Chat-completions endpoint and model name. Either can be overridden through
# the environment (POLISH_LLM_ENDPOINT / POLISH_LLM_MODEL) so the script can
# target another server without editing the source; an unset or empty
# variable keeps the original hard-coded default.
# NOTE(review): "/v1d/" is unusual for an OpenAI-compatible server (those
# normally serve "/v1/chat/completions") — confirm this path is intentional.
LLM_ENDPOINT = (
    os.environ.get("POLISH_LLM_ENDPOINT")
    or "http://192.168.1.100:8853/v1d/chat/completions"
)
LLM_MODEL = os.environ.get("POLISH_LLM_MODEL") or "THUDM-GLM4-32B"
# System prompt for the polish step: editing rules followed by few-shot
# examples. A reply of exactly "DISCARD" is how the model rejects an entry;
# main() checks the reply for that token.
# NOTE(review): two examples swap core nouns (chest -> bookshelf-maker,
# nectarine -> peach) although rule 3 says not to, and the examples write
# chains as "a Rel b" while build_messages() sends format_chain() output
# ("a --Rel--> b (w:N)") plus meta-template and slot lines the examples never
# show. Consider aligning the examples with the real prompt shape.
SYSTEM_PROMPT = """You are an editor specializing in folk sayings and rural proverbs. You will receive a rough draft of a fake folksy saying along with the relationship chain it encodes.
Your job:
1. Fix grammar, articles, and pluralization
2. Make it sound natural like something a weathered farmer would say while leaning on a fence post
3. Preserve the core nouns and the relationship between them do not swap out the key words
4. You MAY add small colorful details (adjectives, folksy verb choices, regional flavor) but keep it concise real proverbs are short
5. You MAY lightly restructure the sentence for better rhythm, but keep the same meaning pattern
6. If the saying is unsalvageable nonsense (the nouns don't relate in any meaningful way, or the combination is unintentionally offensive), respond with exactly: DISCARD
Output ONLY the polished saying on a single line. No quotes, no explanation, no preamble.
Examples of good polish:
Raw: "Don't build the coffee and act surprised when the water show up."
Chain: coffee MadeOf water
Polished: Don't brew the coffee and act surprised when the water's all gone.
Raw: "The chest's children always goes without hold books."
Chain: chest UsedFor hold_books
Polished: The bookshelf-maker's kids always end up reading off the floor.
Raw: "A pineapple is just a nectarine that's got an attitude."
Chain: pineapple IsA fruit, nectarine IsA fruit, pineapple HasProperty prickly
Polished: A pineapple is just a peach that grew itself some armor.
Raw: "You know what they say, a steel with no iron is just a harder than gold iron."
Chain: steel MadeOf iron, steel HasProperty hard
Polished: You know what they say steel without the iron is just a dream of being hard.
Raw: "Funny how the bamboo never has enough grow very quickly for itself."
Chain: bamboo CapableOf grow_quickly
Polished: DISCARD
Raw: "That's just funning the canoe and praying for boiling food."
Chain: canoe UsedFor transport, fire UsedFor boiling_food
Polished: DISCARD"""
class LLMError(Exception):
    """Root of this script's LLM error hierarchy.

    NOTE(review): llm_chat_completion() reports failures as
    ``(None, error_type)`` tuples rather than raising these classes.
    """


class ContextTooLong(LLMError):
    """The prompt did not fit in the model's context window."""


class TransientError(LLMError):
    """A failure worth retrying (network trouble, an overloaded server, ...)."""
def llm_chat_completion(messages, max_retries=3):
    """Send one chat-completion request, retrying recoverable failures.

    Failures are classified rather than raised so the caller can pick a
    recovery strategy (e.g. resend a shorter prompt on "context_too_long").

    Args:
        messages: Chat messages (dicts with "role" and "content") posted as-is
            to LLM_ENDPOINT together with LLM_MODEL and temperature 0.7.
        max_retries: Total number of attempts, including the first.

    Returns:
        (response_text, error_type). On success response_text is the stripped
        message content and error_type is None. On failure response_text is
        None and error_type is one of "context_too_long", "http_400",
        "server_overload", "json_parse", "null_content", "json_structure",
        "timeout", "connection", "unexpected", or "exhausted_retries" (the
        last only when max_retries < 1).

    Raises:
        ImportError: if the third-party ``requests`` package is not installed.
    """
    import requests

    # Statuses meaning "server busy / briefly unavailable" (including gateway
    # errors from a proxy in front of it): back off and resend unchanged.
    transient_statuses = (429, 502, 503, 504)
    # Heuristic: a 400 whose body mentions any of these words is taken to be a
    # context-window overflow.
    context_keywords = ("context", "token", "length", "too long", "exceed")

    for attempt in range(max_retries):
        retries_left = attempt < max_retries - 1
        try:
            resp = requests.post(LLM_ENDPOINT, json={
                "model": LLM_MODEL,
                "messages": messages,
                "temperature": 0.7,
            }, timeout=120)

            if resp.status_code == 400:
                body = resp.text.lower()
                # Not retried here: the same prompt would overflow again, so the
                # caller is expected to resend a shorter one.
                if any(kw in body for kw in context_keywords):
                    return None, "context_too_long"
                # Other 400 errors — log and retry
                print(f" HTTP 400 (attempt {attempt+1}): {resp.text[:200]}", file=sys.stderr)
                if retries_left:
                    time.sleep(2 ** attempt)
                    continue
                return None, "http_400"

            if resp.status_code in transient_statuses:
                wait = 2 ** (attempt + 1)
                if retries_left:
                    print(f" HTTP {resp.status_code} (attempt {attempt+1}), waiting {wait}s...",
                          file=sys.stderr)
                    time.sleep(wait)
                    continue
                # Final attempt: don't claim to be waiting when giving up.
                print(f" HTTP {resp.status_code} (attempt {attempt+1}), giving up.",
                      file=sys.stderr)
                return None, "server_overload"

            # Any other 4xx/5xx raises HTTPError, handled by the generic branch below.
            resp.raise_for_status()

            # Parse JSON response
            try:
                data = resp.json()
            except (json.JSONDecodeError, ValueError) as e:
                print(f" JSON parse error (attempt {attempt+1}): {e}", file=sys.stderr)
                print(f" Response body: {resp.text[:300]}", file=sys.stderr)
                if retries_left:
                    time.sleep(2 ** attempt)
                    continue
                return None, "json_parse"

            # Extract content from response
            try:
                content = data["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError) as e:
                # TypeError: valid JSON of the wrong shape (a bare list, string,
                # null, or "message": null) — a structure problem, not an
                # unexpected crash.
                print(f" Unexpected JSON structure (attempt {attempt+1}): {e}", file=sys.stderr)
                print(f" Keys: {list(data.keys()) if isinstance(data, dict) else type(data)}",
                      file=sys.stderr)
                if retries_left:
                    time.sleep(1)
                    continue
                return None, "json_structure"

            if content is None:
                print(f" Null content in response (attempt {attempt+1})", file=sys.stderr)
                if retries_left:
                    time.sleep(1)
                    continue
                return None, "null_content"
            return content.strip(), None

        except requests.exceptions.Timeout:
            wait = 2 ** (attempt + 1)
            if retries_left:
                print(f" Timeout (attempt {attempt+1}), waiting {wait}s...", file=sys.stderr)
                time.sleep(wait)
                continue
            print(f" Timeout (attempt {attempt+1}), giving up.", file=sys.stderr)
            return None, "timeout"
        except requests.exceptions.ConnectionError as e:
            print(f" Connection error (attempt {attempt+1}): {e}", file=sys.stderr)
            if retries_left:
                time.sleep(2 ** (attempt + 2))  # longer wait for connection errors
                continue
            return None, "connection"
        except Exception as e:
            print(f" Unexpected error (attempt {attempt+1}): {type(e).__name__}: {e}",
                  file=sys.stderr)
            if retries_left:
                time.sleep(2 ** attempt)
                continue
            return None, "unexpected"

    return None, "exhausted_retries"
def format_chain(chain_edges, truncate=False):
    """Render chain edges as a compact, human-readable string for the prompt.

    Each edge becomes ``start --relation--> end`` (missing fields show as
    ``"?"``). Unless *truncate* is set, the edge weight is appended as
    ``(w:N.N)``; a missing weight counts as 0. Edges are joined with ``", "``.

    Args:
        chain_edges: Iterable of edge dicts with optional keys ``start``,
            ``relation``, ``end`` and ``weight``. A falsy value (``None``,
            ``[]``) yields ``"(no chain data)"``.
        truncate: Omit the weights to shorten the prompt (main() uses this
            after the full prompt was rejected as too long).

    Returns:
        The formatted chain string.
    """
    if not chain_edges:
        return "(no chain data)"
    parts = []
    for edge in chain_edges:
        start = edge.get("start", "?")
        rel = edge.get("relation", "?")
        end = edge.get("end", "?")
        text = f"{start} --{rel}--> {end}"
        if not truncate:
            weight = edge.get("weight", 0)
            # A null or non-numeric weight used to raise inside the format
            # spec and abort the whole run; drop the annotation for that edge
            # instead (the weight is only context for the model).
            try:
                text += f" (w:{float(weight):.1f})"
            except (TypeError, ValueError):
                pass
        parts.append(text)
    return ", ".join(parts)
def format_slots(slots):
    """Format slot fills as comma-separated ``key=value`` pairs for the prompt.

    Args:
        slots: Mapping of slot name to fill value, rendered in iteration
            order. ``None`` or an empty mapping yields ``""``.

    Returns:
        e.g. ``"noun=coffee, verb=brew"``.
    """
    # build_messages() only defaults a *missing* "slots" key to {}; an input
    # row with "slots": null used to reach .items() and raise AttributeError.
    if not slots:
        return ""
    return ", ".join(f"{k}={v}" for k, v in slots.items())
def build_messages(entry, truncate_chain=False):
    """Return the two-message chat payload (system, then user) for one entry.

    The user message lists the entry's meta-template, relationship chain,
    slot fills and raw saying, one labelled field per line; missing keys fall
    back to defaults. *truncate_chain* is forwarded to ``format_chain`` to
    drop edge weights when a shorter prompt is needed.
    """
    chain_text = format_chain(entry.get("chain", []), truncate=truncate_chain)
    slot_text = format_slots(entry.get("slots", {}))
    labelled_fields = [
        ("Meta-template", entry.get("meta_template", "")),
        ("Relationship chain", chain_text),
        ("Slot fills", slot_text),
        ("Raw saying", entry.get("raw_text", "")),
    ]
    user_prompt = "\n".join(f"{label}: {value}" for label, value in labelled_fields)
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]
def load_already_processed(output_path):
    """Scan an existing output file so an interrupted run can resume.

    Every record's ``raw_text`` counts as processed regardless of its status,
    so entries previously recorded as ``"error"`` are NOT retried on resume;
    delete those lines from the output file to retry them.

    Args:
        output_path: ``Path`` of the JSONL output/checkpoint file. A missing
            file means nothing has been processed yet.

    Returns:
        ``(processed, counts)``: the set of ``raw_text`` strings already
        written, and a dict mapping ``"polished"``, ``"discarded"`` and
        ``"error"`` to how many records carry that status.
    """
    processed = set()
    counts = {"polished": 0, "discarded": 0, "error": 0}
    if not output_path.exists():
        return processed, counts

    skipped = 0
    with open(output_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            # Records are JSON objects; any other JSON value (a stray list or
            # number) used to crash on .get() and block resuming entirely.
            if not isinstance(entry, dict):
                skipped += 1
                continue
            processed.add(entry.get("raw_text", ""))
            status = entry.get("status", "")
            if status in counts:
                counts[status] += 1

    if skipped:
        # e.g. a partial last line from a run killed mid-write; the affected
        # entries are simply not counted as processed and will be redone.
        print(f"Warning: ignored {skipped} malformed line(s) in {output_path}",
              file=sys.stderr)
    return processed, counts
def _request_polish(entry, index):
    """Ask the LLM to polish one entry, shrinking the prompt on context overflow.

    Tries the full prompt, then one with a weight-free chain, then the raw
    saying alone. Returns the final ``(response, error_type)`` pair from
    ``llm_chat_completion``. *index* is the 1-based position used in logs.
    """
    # First attempt with full chain data
    response, error_type = llm_chat_completion(build_messages(entry, truncate_chain=False))
    # If context too long, retry with truncated chain
    if error_type == "context_too_long":
        print(f" #{index}: context too long, retrying with truncated chain...",
              file=sys.stderr)
        response, error_type = llm_chat_completion(build_messages(entry, truncate_chain=True))
    # If still too long, try with just the raw text
    if error_type == "context_too_long":
        print(f" #{index}: still too long, retrying with minimal prompt...",
              file=sys.stderr)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Raw saying: {entry.get('raw_text', '')}"},
        ]
        response, error_type = llm_chat_completion(messages)
    return response, error_type


def _clean_response(response):
    """Normalize a raw model reply into polished text.

    Strips surrounding double quotes and a leading label such as
    ``Polished:`` that models sometimes add despite the prompt.

    Returns:
        ``None`` if the model answered DISCARD (also when quoted, labelled,
        or followed by "." / "!"); otherwise the cleaned text, which may be
        empty.
    """
    text = response.strip()
    # len >= 2: a lone '"' must not be "unquoted" into an empty string.
    if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
        text = text[1:-1].strip()
    # Sometimes the LLM prefixes with "Polished:" or similar
    for prefix in ("Polished:", "polished:", "Output:", "Result:"):
        if text.startswith(prefix):
            text = text[len(prefix):].strip()
    # A label can precede a quoted saying ('Polished: "..."'), so unquote again.
    if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
        text = text[1:-1].strip()
    if text.rstrip(".!").upper() == "DISCARD":
        return None
    return text


def main():
    """Polish every raw entry that is not yet in the output file.

    Reads ``--input`` (JSONL of raw entries), skips entries whose ``raw_text``
    already appears in ``--output``, and appends one JSON record per processed
    entry to ``--output``. Each record has ``status`` set to ``"polished"``
    (plus ``polished_text``), ``"discarded"``, or ``"error"`` (plus
    ``error_type``). Exits with status 1 if the input is missing or after 20
    consecutive errors; Ctrl-C stops early with the output file closed cleanly.
    """
    parser = argparse.ArgumentParser(description="LLM polish pipeline for folksy sayings.")
    parser.add_argument("--input", default=str(CORPUS_DIR / "corpus_raw.jsonl"),
                        help="Input JSONL file")
    parser.add_argument("--output", default=str(CORPUS_DIR / "corpus_polished.jsonl"),
                        help="Output JSONL file (also serves as checkpoint)")
    args = parser.parse_args()
    input_path = Path(args.input)
    output_path = Path(args.output)
    if not input_path.exists():
        print(f"Error: {input_path} not found.", file=sys.stderr)
        sys.exit(1)

    # Load raw entries
    raw_entries = []
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                raw_entries.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Warning: skipping malformed input line: {e}", file=sys.stderr)
    print(f"Loaded {len(raw_entries)} raw entries from {input_path}")

    # Check what's already been processed
    already_processed, prev_counts = load_already_processed(output_path)
    remaining = [e for e in raw_entries if e.get("raw_text", "") not in already_processed]
    print(f"Already processed: {len(already_processed)} "
          f"(polished={prev_counts['polished']}, "
          f"discarded={prev_counts['discarded']}, "
          f"errors={prev_counts['error']})")
    print(f"Remaining: {len(remaining)}")
    if not remaining:
        print("Nothing to process.")
        return

    # A run killed mid-write can leave a partial last line with no trailing
    # newline; appending straight after it would glue the next record onto the
    # broken line (and lose that record too), so terminate it first.
    needs_newline = False
    if output_path.exists() and output_path.stat().st_size > 0:
        with open(output_path, "rb") as existing:
            existing.seek(-1, os.SEEK_END)
            needs_newline = existing.read(1) != b"\n"

    # Past this many errors in a row something is seriously wrong (e.g. the
    # server is down); stop instead of marking the rest of the corpus as errors.
    max_consecutive_errors = 20
    discards = 0
    polished = 0
    errors = 0
    error_types = {}
    consecutive_errors = 0
    index = 0  # 1-based position in `remaining`; 0 until the loop starts
    start_time = time.time()
    try:
        with open(output_path, "a", encoding="utf-8") as out:
            if needs_newline:
                out.write("\n")
            for index, entry in enumerate(remaining, start=1):
                response, error_type = _request_polish(entry, index)

                cleaned = None
                if response is not None:
                    cleaned = _clean_response(response)
                    if cleaned == "":
                        # An empty reply is not a usable polish; record it as an
                        # error instead of a blank "polished" saying.
                        response, error_type = None, "empty_response"

                if response is None:
                    entry["status"] = "error"
                    entry["error_type"] = error_type or "unknown"
                    errors += 1
                    consecutive_errors += 1
                    error_types[error_type] = error_types.get(error_type, 0) + 1
                    if consecutive_errors >= max_consecutive_errors:
                        print(f"\nFATAL: {consecutive_errors} consecutive errors. "
                              f"Last error type: {error_type}", file=sys.stderr)
                        print("Flushing output and stopping. Re-run to resume.", file=sys.stderr)
                        print("Note: entries already recorded with status \"error\" are "
                              "skipped on re-run; delete those lines from the output "
                              "file to retry them.", file=sys.stderr)
                        out.write(json.dumps(entry, ensure_ascii=False) + "\n")
                        out.flush()
                        sys.exit(1)
                elif cleaned is None:
                    entry["status"] = "discarded"
                    discards += 1
                    consecutive_errors = 0
                else:
                    entry["polished_text"] = cleaned
                    entry["status"] = "polished"
                    polished += 1
                    consecutive_errors = 0

                out.write(json.dumps(entry, ensure_ascii=False) + "\n")
                # Flush every 10 entries for fine-grained resume safety
                if index % 10 == 0:
                    out.flush()
                # Progress report every 100 entries
                if index % 100 == 0:
                    total_done = len(already_processed) + index
                    elapsed = time.time() - start_time
                    rate = index / elapsed
                    eta_sec = (len(remaining) - index) / rate if rate > 0 else 0
                    eta_min = eta_sec / 60
                    print(f" [{total_done}/{len(raw_entries)}] "
                          f"polished={polished}, discarded={discards}, errors={errors} "
                          f"({rate:.1f}/s, ETA {eta_min:.0f}m)")
                time.sleep(0.1)
    except KeyboardInterrupt:
        # `index` is initialised before the try, so this is safe even if the
        # interrupt lands before the first entry.
        print(f"\nInterrupted at entry {index}/{len(remaining)}. "
              f"Progress saved — re-run to resume.", file=sys.stderr)

    # Final report
    elapsed = time.time() - start_time
    total_done = len(already_processed) + polished + discards + errors
    print(f"\nSession complete: {polished + discards + errors} entries processed "
          f"in {elapsed/60:.1f} minutes.")
    print(f" Polished: {polished}")
    print(f" Discarded: {discards}")
    print(f" Errors: {errors}")
    if error_types:
        print(f" Error breakdown: {error_types}")
    if polished + discards > 0:
        print(f" Discard rate: {discards/(polished+discards)*100:.1f}%")
    print(f" Total across all sessions: {total_done}/{len(raw_entries)}")
    print(f"Output: {output_path}")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()