# ics-simlab-config-gen-claude/services/pipeline.py
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import Any, Optional
from openai import OpenAI
from services.generation import generate_json_with_llm, repair_with_llm
from helpers.helper import log, write_json_file
from services.patches import (
patch_fill_required_keys,
patch_lowercase_names,
patch_sanitize_network_names,
)
from services.validation import validate_basic
def _save_raw(raw: str) -> None:
    """Persist the latest raw LLM output to a fixed path for post-mortem debugging."""
    Path("outputs/last_raw_response.txt").write_text(raw, encoding="utf-8")
    log("Wrote outputs/last_raw_response.txt")


def run_pipeline(
    *,
    client: OpenAI,
    model: str,
    full_prompt: str,
    schema: Optional[dict[str, Any]],
    repair_template: str,
    user_input: str,
    out_path: Path,
    retries: int,
    max_output_tokens: int,
) -> None:
    """Generate a JSON config with an LLM, then validate/patch/repair it.

    Flow per attempt: parse the raw LLM text, apply deterministic patches,
    validate, and either save the result to ``out_path`` (success) or ask the
    LLM to repair it using the collected error messages. The latest raw LLM
    output is always mirrored to ``outputs/last_raw_response.txt``.

    Args:
        client: OpenAI client used for generation and repair calls.
        model: Model identifier passed through to the LLM helpers.
        full_prompt: Complete generation prompt.
        schema: Optional JSON schema forwarded to generation/repair.
        repair_template: Prompt template used by the repair step.
        user_input: Original user request, echoed into repair prompts.
        out_path: Destination file for the validated config.
        retries: Maximum number of validate/repair iterations.
        max_output_tokens: Token cap for each LLM call.

    Returns:
        None. On success the config is written to ``out_path``; if all
        attempts are exhausted, a failure is logged and nothing is saved.
    """
    Path("outputs").mkdir(parents=True, exist_ok=True)

    def _attempt_repair(current_raw: str, errors: list[str]) -> str:
        # Single call site for repairs so all failure paths plumb the same
        # arguments and leave the same debug artifact on disk.
        repaired = repair_with_llm(
            client=client,
            model=model,
            schema=schema,
            repair_template=repair_template,
            user_input=user_input,
            current_raw=current_raw,
            errors=errors,
            max_output_tokens=max_output_tokens,
        )
        _save_raw(repaired)
        return repaired

    log(f"Calling LLM (model={model}, max_output_tokens={max_output_tokens})...")
    t0 = time.time()
    raw = generate_json_with_llm(
        client=client,
        model=model,
        full_prompt=full_prompt,
        schema=schema,
        max_output_tokens=max_output_tokens,
    )
    dt = time.time() - t0
    log(f"LLM returned in {dt:.1f}s. Output chars={len(raw)}")
    _save_raw(raw)

    for attempt in range(retries):
        log(f"Validate/repair attempt {attempt+1}/{retries}")
        # A repair on the final attempt would never be re-validated, so it is
        # a wasted LLM call — skip repairs once no attempts remain.
        last_attempt = attempt + 1 == retries

        # 1) parse
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError as e:
            if last_attempt:
                break
            log(f"JSON decode error: {e}. Repairing...")
            raw = _attempt_repair(raw, [f"JSON decode error: {e}"])
            continue

        if not isinstance(obj, dict):
            if last_attempt:
                break
            log("Top-level is not a JSON object. Repairing...")
            raw = _attempt_repair(raw, ["Top-level JSON must be an object"])
            continue

        # 2) patches BEFORE validation (order matters)
        obj, patch_errors_0 = patch_fill_required_keys(obj)
        obj, patch_errors_1 = patch_lowercase_names(obj)
        obj, patch_errors_2 = patch_sanitize_network_names(obj)
        raw = json.dumps(obj, ensure_ascii=False)

        # 3) validate (patch-reported problems count as validation errors too)
        errors = patch_errors_0 + patch_errors_1 + patch_errors_2 + validate_basic(obj)
        if not errors:
            write_json_file(out_path, obj)
            log(f"Saved OK -> {out_path}")
            return

        log(f"Validation failed with {len(errors)} error(s). Repairing...")
        for e in errors[:12]:
            log(f" - {e}")
        if len(errors) > 12:
            log(f" ... (+{len(errors)-12} more)")

        # 4) repair
        if not last_attempt:
            raw = _attempt_repair(json.dumps(obj, ensure_ascii=False), errors)

    # All attempts exhausted (or retries == 0): make the failure explicit
    # instead of returning silently with nothing written.
    log(f"FAILED: no valid config produced after {retries} attempt(s); nothing saved to {out_path}")