120 lines
3.8 KiB
Python
120 lines
3.8 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Optional
|
|
|
|
from openai import OpenAI
|
|
|
|
from services.generation import generate_json_with_llm, repair_with_llm
|
|
from helpers.helper import log, write_json_file
|
|
from services.patches import (
|
|
patch_fill_required_keys,
|
|
patch_lowercase_names,
|
|
patch_sanitize_network_names,
|
|
)
|
|
from services.validation import validate_basic
|
|
|
|
|
|
# Snapshot file for the most recent raw LLM output (debugging aid).
_RAW_SNAPSHOT = Path("outputs/last_raw_response.txt")


def _write_raw_snapshot(raw: str) -> None:
    """Persist the latest raw LLM output so failed runs can be post-mortemed."""
    _RAW_SNAPSHOT.write_text(raw, encoding="utf-8")
    log("Wrote outputs/last_raw_response.txt")


def run_pipeline(
    *,
    client: OpenAI,
    model: str,
    full_prompt: str,
    schema: Optional[dict[str, Any]],
    repair_template: str,
    user_input: str,
    out_path: Path,
    retries: int,
    max_output_tokens: int,
) -> None:
    """Generate JSON with an LLM, then parse/patch/validate/repair until valid.

    Flow per attempt: parse the raw output -> apply deterministic patches ->
    validate; on any failure, ask the LLM to repair its own output and retry.
    On success the object is written to *out_path* and the function returns.

    Args:
        client: Configured OpenAI client used for generation and repair calls.
        model: Model identifier passed through to the LLM services.
        full_prompt: Complete generation prompt.
        schema: Optional JSON schema forwarded to generation/repair
            (semantics defined by the services module).
        repair_template: Prompt template used by ``repair_with_llm``.
        user_input: Original user input, echoed into repair prompts.
        out_path: Destination for the validated JSON object.
        retries: Number of validate/repair attempts. With ``retries=0`` the
            loop never runs and nothing is saved.
        max_output_tokens: Token budget for each LLM call.

    Returns:
        None. Side effects: writes ``outputs/last_raw_response.txt`` after
        every LLM call, and writes *out_path* only on success.
    """
    Path("outputs").mkdir(parents=True, exist_ok=True)

    log(f"Calling LLM (model={model}, max_output_tokens={max_output_tokens})...")
    t0 = time.time()
    raw = generate_json_with_llm(
        client=client,
        model=model,
        full_prompt=full_prompt,
        schema=schema,
        max_output_tokens=max_output_tokens,
    )
    dt = time.time() - t0
    log(f"LLM returned in {dt:.1f}s. Output chars={len(raw)}")
    _write_raw_snapshot(raw)

    def _repair(current_raw: str, errors: list[str]) -> str:
        # One repair round-trip; every repair result is snapshotted so the
        # on-disk raw file always reflects the latest LLM output.
        fixed = repair_with_llm(
            client=client,
            model=model,
            schema=schema,
            repair_template=repair_template,
            user_input=user_input,
            current_raw=current_raw,
            errors=errors,
            max_output_tokens=max_output_tokens,
        )
        _write_raw_snapshot(fixed)
        return fixed

    for attempt in range(retries):
        log(f"Validate/repair attempt {attempt+1}/{retries}")

        # 1) parse
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError as e:
            log(f"JSON decode error: {e}. Repairing...")
            raw = _repair(raw, [f"JSON decode error: {e}"])
            continue

        if not isinstance(obj, dict):
            log("Top-level is not a JSON object. Repairing...")
            raw = _repair(raw, ["Top-level JSON must be an object"])
            continue

        # 2) patches BEFORE validation (order matters)
        obj, patch_errors_0 = patch_fill_required_keys(obj)
        obj, patch_errors_1 = patch_lowercase_names(obj)
        obj, patch_errors_2 = patch_sanitize_network_names(obj)

        # Keep `raw` in sync with the patched object for the next iteration.
        raw = json.dumps(obj, ensure_ascii=False)

        # 3) validate — patch-reported errors count as validation failures too.
        errors = patch_errors_0 + patch_errors_1 + patch_errors_2 + validate_basic(obj)

        if not errors:
            write_json_file(out_path, obj)
            log(f"Saved OK -> {out_path}")
            return

        log(f"Validation failed with {len(errors)} error(s). Repairing...")
        for e in errors[:12]:
            log(f" - {e}")
        if len(errors) > 12:
            log(f" ... (+{len(errors)-12} more)")

        # 4) repair — `raw` already holds the serialized patched object, so no
        # need to re-dump it (the original redundantly serialized twice here).
        raw = _repair(raw, errors)

    # FIX: previously the function fell off the end silently when every attempt
    # failed, giving callers no signal at all. Make exhaustion explicit.
    # NOTE(review): the raw output from the final repair call is never
    # re-validated — raising retries by one effectively adds a check for it.
    log(f"Giving up after {retries} attempt(s); no valid JSON saved to {out_path}")