#!/usr/bin/env python3
"""LLM polish pipeline for raw folksy sayings.

Reads corpus_raw.jsonl, sends each to GLM4-32B for polish.
Output file is the checkpoint — append mode with resume detection.

Robust error handling:
  - Context size errors: truncates chain data and retries
  - JSON parse errors: retries, then marks as error
  - Transient HTTP errors: exponential backoff retry
  - Keyboard interrupt: flushes and exits cleanly
  - Safe resume: skips entries already in output file (including entries
    recorded with status "error" — those are not retried automatically)

Usage:
    python scripts/polish_corpus.py
    python scripts/polish_corpus.py --input corpus/corpus_raw.jsonl --output corpus/corpus_polished.jsonl
    python scripts/polish_corpus.py --endpoint http://host:port/v1/chat/completions --model MODEL
"""

import argparse
import json
import sys
import time
from pathlib import Path

SCRIPT_DIR = Path(__file__).parent
PROJECT_DIR = SCRIPT_DIR.parent
CORPUS_DIR = PROJECT_DIR / "corpus"

LLM_ENDPOINT = "http://192.168.1.100:8853/v1d/chat/completions"
LLM_MODEL = "THUDM-GLM4-32B"

# HTTP statuses treated as "server busy / gateway hiccup": retried with a
# longer backoff and reported as "server_overload" if they persist.
OVERLOAD_STATUSES = (429, 502, 503, 504)

# Labels the model sometimes echoes from the few-shot examples; matched
# case-insensitively and stripped from the front of a response.
RESPONSE_LABELS = ("polished:", "output:", "result:")

# (opening, closing) quote pairs stripped when they wrap an entire response.
WRAPPING_QUOTES = (('"', '"'), ("\u201c", "\u201d"))

# Abort the run after this many entries in a row fail with an LLM error.
MAX_CONSECUTIVE_ERRORS = 20

SYSTEM_PROMPT = """You are an editor specializing in folk sayings and rural proverbs.

You will receive a rough draft of a fake folksy saying along with the relationship chain it encodes. Your job:
1. Fix grammar, articles, and pluralization
2. Make it sound natural — like something a weathered farmer would say while leaning on a fence post
3. Preserve the core nouns and the relationship between them — do not swap out the key words
4. You MAY add small colorful details (adjectives, folksy verb choices, regional flavor) but keep it concise — real proverbs are short
5. You MAY lightly restructure the sentence for better rhythm, but keep the same meaning pattern
6. If the saying is unsalvageable nonsense (the nouns don't relate in any meaningful way, or the combination is unintentionally offensive), respond with exactly: DISCARD

Output ONLY the polished saying on a single line. No quotes, no explanation, no preamble.

Examples of good polish:

Raw: "Don't build the coffee and act surprised when the water show up."
Chain: coffee MadeOf water
Polished: Don't brew the coffee and act surprised when the water's all gone.

Raw: "The chest's children always goes without hold books."
Chain: chest UsedFor hold_books
Polished: The bookshelf-maker's kids always end up reading off the floor.

Raw: "A pineapple is just a nectarine that's got an attitude."
Chain: pineapple IsA fruit, nectarine IsA fruit, pineapple HasProperty prickly
Polished: A pineapple is just a peach that grew itself some armor.

Raw: "You know what they say, a steel with no iron is just a harder than gold iron."
Chain: steel MadeOf iron, steel HasProperty hard
Polished: You know what they say — steel without the iron is just a dream of being hard.

Raw: "Funny how the bamboo never has enough grow very quickly for itself."
Chain: bamboo CapableOf grow_quickly
Polished: DISCARD

Raw: "That's just funning the canoe and praying for boiling food."
Chain: canoe UsedFor transport, fire UsedFor boiling_food
Polished: DISCARD"""


# NOTE: these exception classes are not raised anywhere in this script;
# llm_chat_completion reports failures as (None, error_type) tuples instead.
class LLMError(Exception):
    """Base class for LLM errors."""


class ContextTooLong(LLMError):
    """Prompt exceeded context window."""


class TransientError(LLMError):
    """Recoverable error (network, server overload, etc.)."""


def llm_chat_completion(messages, max_retries=3, endpoint=None, model=None):
    """Chat completion with retry logic and error classification.

    Args:
        messages: List of {"role": ..., "content": ...} dicts sent as the
            request's "messages" field.
        max_retries: Number of attempts before giving up.
        endpoint: Chat-completions URL; defaults to LLM_ENDPOINT.
        model: Model name sent in the request; defaults to LLM_MODEL.

    Returns (response_text, error_type) tuple.
    response_text is None on failure; error_type is None on success.
    error_type is one of "context_too_long", "http_400", "server_overload",
    "json_parse", "null_content", "json_structure", "timeout", "connection",
    "unexpected", or "exhausted_retries" (only when max_retries < 1).
    """
    # Imported lazily so the pure helpers in this module work without requests.
    import requests

    endpoint = endpoint or LLM_ENDPOINT
    model = model or LLM_MODEL

    for attempt in range(max_retries):
        try:
            resp = requests.post(endpoint, json={
                "model": model,
                "messages": messages,
                "temperature": 0.7,
            }, timeout=120)

            # Check for context length errors (HTTP 400 typically)
            if resp.status_code == 400:
                body = resp.text.lower()
                if any(kw in body for kw in ["context", "token", "length", "too long", "exceed"]):
                    # Not retryable as-is; the caller shrinks the prompt instead.
                    return None, "context_too_long"
                # Other 400 errors — log and retry
                print(f" HTTP 400 (attempt {attempt+1}): {resp.text[:200]}", file=sys.stderr)
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                return None, "http_400"

            if resp.status_code in OVERLOAD_STATUSES:
                wait = 2 ** (attempt + 1)
                print(f" HTTP {resp.status_code} (attempt {attempt+1}), waiting {wait}s...", file=sys.stderr)
                if attempt < max_retries - 1:
                    time.sleep(wait)
                    continue
                return None, "server_overload"

            # Any other 4xx/5xx raises HTTPError, handled by the generic branch below.
            resp.raise_for_status()

            # Parse JSON response
            try:
                data = resp.json()
            except (json.JSONDecodeError, ValueError) as e:
                print(f" JSON parse error (attempt {attempt+1}): {e}", file=sys.stderr)
                print(f" Response body: {resp.text[:300]}", file=sys.stderr)
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                return None, "json_parse"

            # Extract content from response. TypeError covers a top-level
            # value (or nested value) that is not a dict/list as expected.
            try:
                content = data["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError) as e:
                print(f" Unexpected JSON structure (attempt {attempt+1}): {e}", file=sys.stderr)
                print(f" Keys: {list(data.keys()) if isinstance(data, dict) else type(data)}", file=sys.stderr)
                if attempt < max_retries - 1:
                    time.sleep(1)
                    continue
                return None, "json_structure"

            if content is None:
                print(f" Null content in response (attempt {attempt+1})", file=sys.stderr)
                if attempt < max_retries - 1:
                    time.sleep(1)
                    continue
                return None, "null_content"
            return content.strip(), None

        except requests.exceptions.Timeout:
            wait = 2 ** (attempt + 1)
            print(f" Timeout (attempt {attempt+1}), waiting {wait}s...", file=sys.stderr)
            if attempt < max_retries - 1:
                time.sleep(wait)
                continue
            return None, "timeout"
        except requests.exceptions.ConnectionError as e:
            wait = 2 ** (attempt + 2)  # longer wait for connection errors
            print(f" Connection error (attempt {attempt+1}): {e}", file=sys.stderr)
            if attempt < max_retries - 1:
                time.sleep(wait)
                continue
            return None, "connection"
        except Exception as e:
            print(f" Unexpected error (attempt {attempt+1}): {type(e).__name__}: {e}", file=sys.stderr)
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)
                continue
            return None, "unexpected"

    return None, "exhausted_retries"


def format_weight(weight):
    """Render an edge weight as "w:<value>", with one decimal when numeric.

    Non-numeric values (None, arbitrary strings) are shown verbatim instead of
    crashing the format call.
    """
    try:
        return f"w:{float(weight):.1f}"
    except (TypeError, ValueError):
        return f"w:{weight}"


def format_chain(chain_edges, truncate=False):
    """Format chain_edges list into readable string for LLM context.

    Each edge dict is read for "start", "relation", "end" (and "weight" unless
    truncating); missing keys render as "?" (weight defaults to 0).
    If truncate=True, omit weights to reduce token count.
    """
    if not chain_edges:
        return "(no chain data)"
    parts = []
    for edge in chain_edges:
        start = edge.get("start", "?")
        rel = edge.get("relation", "?")
        end = edge.get("end", "?")
        if truncate:
            parts.append(f"{start} --{rel}--> {end}")
        else:
            parts.append(f"{start} --{rel}--> {end} ({format_weight(edge.get('weight', 0))})")
    return ", ".join(parts)


def format_slots(slots):
    """Format slots dict for LLM context as "key=value" pairs joined by ", "."""
    return ", ".join(f"{k}={v}" for k, v in slots.items())


def build_messages(entry, truncate_chain=False):
    """Build the [system, user] messages list for a single entry."""
    raw_text = entry.get("raw_text", "")
    meta_template = entry.get("meta_template", "")
    chain = format_chain(entry.get("chain", []), truncate=truncate_chain)
    slots = format_slots(entry.get("slots", {}))

    user_prompt = (
        f"Meta-template: {meta_template}\n"
        f"Relationship chain: {chain}\n"
        f"Slot fills: {slots}\n"
        f"Raw saying: {raw_text}"
    )

    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]


def _strip_wrappers(text):
    """Repeatedly strip a leading echoed label and whole-string quotes from text."""
    text = text.strip()
    while text:
        label = next((p for p in RESPONSE_LABELS if text[:len(p)].lower() == p), None)
        if label is not None:
            text = text[len(label):].strip()
            continue
        if len(text) >= 2 and any(text.startswith(o) and text.endswith(c) for o, c in WRAPPING_QUOTES):
            text = text[1:-1].strip()
            continue
        break
    return text


def clean_response(text):
    """Normalize a raw LLM reply into a single polished line.

    Strips echoed labels such as "Polished:" and surrounding quotes (in any
    order), and keeps only the first non-empty line, since the prompt asks
    for a single line and anything after it is commentary.
    Returns "" when nothing usable remains.
    """
    if not text:
        return ""
    unwrapped = _strip_wrappers(text)
    first_line = next((ln for ln in unwrapped.splitlines() if ln.strip()), "")
    return _strip_wrappers(first_line)


def is_discard(text):
    """Return True if text is the model's DISCARD verdict.

    Case-insensitive, ignoring trailing "." / "!".
    """
    return text.strip().rstrip(".!").strip().upper() == "DISCARD"


def load_already_processed(output_path):
    """Load set of raw_text strings already processed (for resume).

    Also returns counts of each status for accurate progress reporting.
    Blank lines, malformed JSON (e.g. a truncated last line) and non-object
    JSON values are skipped.
    """
    processed = set()
    counts = {"polished": 0, "discarded": 0, "error": 0}
    if output_path.exists():
        with open(output_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(entry, dict):
                    continue
                processed.add(entry.get("raw_text", ""))
                status = entry.get("status", "")
                if status in counts:
                    counts[status] += 1
    return processed, counts


def ensure_trailing_newline(path):
    """Append a newline to path if it exists, is non-empty and lacks one.

    A run killed mid-write can leave a partial last line; without this, the
    next session's first appended record would be glued onto it and become
    unparseable too.
    """
    if not path.exists() or path.stat().st_size == 0:
        return
    with open(path, "rb") as f:
        f.seek(-1, 2)  # whence=2: relative to end of file
        last_byte = f.read(1)
    if last_byte != b"\n":
        with open(path, "ab") as f:
            f.write(b"\n")


def request_polish(entry, index, endpoint=None, model=None):
    """Ask the LLM to polish one entry, shrinking the prompt on context errors.

    Tries the full prompt, then one with truncated chain data, then a minimal
    prompt holding only the raw saying. index is the 1-based position used in
    log messages. Returns (response_text, error_type) like llm_chat_completion.
    """
    messages = build_messages(entry, truncate_chain=False)
    response, error_type = llm_chat_completion(messages, endpoint=endpoint, model=model)

    # If context too long, retry with truncated chain
    if error_type == "context_too_long":
        print(f" #{index}: context too long, retrying with truncated chain...", file=sys.stderr)
        messages = build_messages(entry, truncate_chain=True)
        response, error_type = llm_chat_completion(messages, endpoint=endpoint, model=model)

    # If still too long, try with just the raw text
    if error_type == "context_too_long":
        print(f" #{index}: still too long, retrying with minimal prompt...", file=sys.stderr)
        raw_text = entry.get("raw_text", "")
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Raw saying: {raw_text}"},
        ]
        response, error_type = llm_chat_completion(messages, endpoint=endpoint, model=model)

    return response, error_type


def main():
    """Polish every not-yet-processed input entry and append results to the output.

    Each written record is the input entry plus "status" ("polished",
    "discarded" or "error") and either "polished_text" or "error_type".
    Exits with status 1 if the input is missing or after
    MAX_CONSECUTIVE_ERRORS consecutive LLM failures.
    """
    parser = argparse.ArgumentParser(description="LLM polish pipeline for folksy sayings.")
    parser.add_argument("--input", default=str(CORPUS_DIR / "corpus_raw.jsonl"),
                        help="Input JSONL file")
    parser.add_argument("--output", default=str(CORPUS_DIR / "corpus_polished.jsonl"),
                        help="Output JSONL file (also serves as checkpoint)")
    parser.add_argument("--endpoint", default=LLM_ENDPOINT,
                        help="Chat-completions URL (default: %(default)s)")
    parser.add_argument("--model", default=LLM_MODEL,
                        help="Model name (default: %(default)s)")
    args = parser.parse_args()

    input_path = Path(args.input)
    output_path = Path(args.output)

    if not input_path.exists():
        print(f"Error: {input_path} not found.", file=sys.stderr)
        sys.exit(1)

    # Load raw entries
    raw_entries = []
    with open(input_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Warning: skipping malformed input line: {e}", file=sys.stderr)
                continue
            if not isinstance(record, dict):
                print("Warning: skipping non-object input line", file=sys.stderr)
                continue
            raw_entries.append(record)

    print(f"Loaded {len(raw_entries)} raw entries from {input_path}")

    # Check what's already been processed
    already_processed, prev_counts = load_already_processed(output_path)
    remaining = [e for e in raw_entries if e.get("raw_text", "") not in already_processed]
    print(f"Already processed: {len(already_processed)} "
          f"(polished={prev_counts['polished']}, "
          f"discarded={prev_counts['discarded']}, "
          f"errors={prev_counts['error']})")
    print(f"Remaining: {len(remaining)}")

    if not remaining:
        print("Nothing to process.")
        return

    # Terminate any partial line left by a killed run before appending.
    ensure_trailing_newline(output_path)

    discards = 0
    polished = 0
    errors = 0
    error_types = {}
    consecutive_errors = 0
    start_time = time.time()
    i = 0  # index of the entry in progress; read by the interrupt handler

    try:
        with open(output_path, "a", encoding="utf-8") as out:
            for i, entry in enumerate(remaining):
                response, error_type = request_polish(entry, i + 1,
                                                      endpoint=args.endpoint, model=args.model)

                cleaned = clean_response(response) if response is not None else ""
                if response is not None and not cleaned:
                    # Nothing usable after stripping labels/quotes/whitespace.
                    response, error_type = None, "empty_response"

                if response is None:
                    error_type = error_type or "unknown"
                    entry["status"] = "error"
                    entry["error_type"] = error_type
                    errors += 1
                    consecutive_errors += 1
                    error_types[error_type] = error_types.get(error_type, 0) + 1

                    # Many errors in a row means the server/config is broken.
                    # NOTE: entries already recorded as errors are skipped on
                    # resume, not retried.
                    if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
                        print(f"\nFATAL: {consecutive_errors} consecutive errors. "
                              f"Last error type: {error_type}", file=sys.stderr)
                        print("Flushing output and stopping. Re-run to resume.", file=sys.stderr)
                        out.write(json.dumps(entry, ensure_ascii=False) + "\n")
                        out.flush()
                        sys.exit(1)
                elif is_discard(cleaned):
                    entry["status"] = "discarded"
                    discards += 1
                    consecutive_errors = 0
                else:
                    entry["polished_text"] = cleaned
                    entry["status"] = "polished"
                    polished += 1
                    consecutive_errors = 0

                out.write(json.dumps(entry, ensure_ascii=False) + "\n")

                # Flush every 10 entries for fine-grained resume safety
                if (i + 1) % 10 == 0:
                    out.flush()

                # Progress report every 100 entries
                if (i + 1) % 100 == 0:
                    total_done = len(already_processed) + i + 1
                    elapsed = time.time() - start_time
                    rate = (i + 1) / elapsed
                    eta_sec = (len(remaining) - (i + 1)) / rate if rate > 0 else 0
                    eta_min = eta_sec / 60
                    print(f" [{total_done}/{len(raw_entries)}] "
                          f"polished={polished}, discarded={discards}, errors={errors} "
                          f"({rate:.1f}/s, ETA {eta_min:.0f}m)")

                time.sleep(0.1)

    except KeyboardInterrupt:
        print(f"\nInterrupted at entry {i+1}/{len(remaining)}. "
              f"Progress saved — re-run to resume.", file=sys.stderr)

    # Final report
    elapsed = time.time() - start_time
    total_done = len(already_processed) + polished + discards + errors
    print(f"\nSession complete: {polished + discards + errors} entries processed "
          f"in {elapsed/60:.1f} minutes.")
    print(f" Polished: {polished}")
    print(f" Discarded: {discards}")
    print(f" Errors: {errors}")
    if error_types:
        print(f" Error breakdown: {error_types}")
    if polished + discards > 0:
        print(f" Discard rate: {discards/(polished+discards)*100:.1f}%")
    print(f" Total across all sessions: {total_done}/{len(raw_entries)}")
    print(f"Output: {output_path}")


if __name__ == "__main__":
    main()