# ics-simlab-config-gen-claude/main.py
#!/usr/bin/env python3
"""
prompts/input_testuale.txt -> LLM -> build_config -> outputs/configuration.json
Pipeline:
1. LLM genera configuration raw
2. JSON validation + basic patches
3. build_config: Pydantic validate -> enrich -> semantic validate
4. If semantic errors, repair with LLM and loop back
5. Output: configuration.json (versione completa)
"""
from __future__ import annotations
import argparse
import json
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Optional
from dotenv import load_dotenv
from openai import OpenAI
from helpers.helper import load_json_schema, log, read_text_file, write_json_file
from services.generation import generate_json_with_llm, repair_with_llm
from services.patches import (
patch_fill_required_keys,
patch_lowercase_names,
patch_sanitize_network_names,
)
from services.prompting import build_prompt
from services.validation import validate_basic
MAX_OUTPUT_TOKENS = 5000
def run_build_config(
    raw_path: Path,
    out_dir: Path,
    skip_semantic: bool = False,
) -> tuple[bool, list[str]]:
    """
    Run build_config (as a subprocess) on a raw configuration file.

    Args:
        raw_path: Path to the raw configuration JSON to validate/enrich.
        out_dir: Directory where build_config writes its outputs.
        skip_semantic: If True, pass --skip-semantic to build_config.

    Returns:
        (success, errors): success=True if build_config passed;
        errors=list of semantic error messages when exit code is 2,
        otherwise the raw stderr/stdout text of the failed run.
    """
    cmd = [
        sys.executable,
        "-m",
        "tools.build_config",
        "--config", str(raw_path),
        "--out-dir", str(out_dir),
        "--overwrite",
        "--json-errors",
    ]
    if skip_semantic:
        cmd.append("--skip-semantic")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode == 0:
        return True, []
    # Exit code 2 = semantic validation failure with a JSON report on stdout.
    if result.returncode == 2:
        try:
            stdout = result.stdout
            # Locate the '"semantic_errors"' marker, then the enclosing '{'
            # before it, and parse from there to the end of stdout.
            marker = stdout.rfind('"semantic_errors"')
            if marker >= 0:
                json_start = stdout.rfind('{', 0, marker)
                if json_start >= 0:
                    error_data = json.loads(stdout[json_start:])
                    errors = [
                        f"SEMANTIC ERROR in {e['entity']}: {e['message']}"
                        for e in error_data.get("semantic_errors", [])
                    ]
                    return False, errors
        # KeyError/TypeError: report parsed but entries lack the expected
        # 'entity'/'message' shape — fall through to the generic path
        # instead of crashing the repair loop.
        except (json.JSONDecodeError, KeyError, TypeError):
            pass
    # Other failures (Pydantic validation errors, crashes, etc.).
    stderr = result.stderr.strip() if result.stderr else ""
    stdout = result.stdout.strip() if result.stdout else ""
    error_msg = stderr or stdout or f"build_config failed with exit code {result.returncode}"
    return False, [error_msg]
def run_pipeline_with_semantic_validation(
    *,
    client: OpenAI,
    model: str,
    full_prompt: str,
    schema: Optional[dict[str, Any]],
    repair_template: str,
    user_input: str,
    raw_path: Path,
    out_path: Path,
    retries: int,
    max_output_tokens: int,
    skip_semantic: bool = False,
) -> None:
    """
    Run the full pipeline: LLM generation -> JSON validation -> build_config -> semantic validation.

    The repair loop handles both JSON structure errors AND semantic errors:
    on each failure the current raw output plus the error list is sent back
    to the LLM and the loop restarts from JSON parsing.

    Args:
        client: OpenAI client used for generation and repair calls.
        model: Model name passed to the LLM services.
        full_prompt: Fully-built generation prompt.
        schema: JSON schema for Structured Outputs, or None to disable them.
        repair_template: Prompt template used for repair calls.
        user_input: Original user input, echoed into repair prompts.
        raw_path: Where the pre-enrichment configuration JSON is saved.
        out_path: Final configuration path (build_config writes into its parent).
        retries: Maximum number of validate/repair attempts.
        max_output_tokens: Token cap for each LLM call.
        skip_semantic: If True, skip semantic validation in build_config.

    Raises:
        SystemExit: If no valid configuration is produced within `retries`.
    """
    last_response_path = Path("outputs/last_raw_response.txt")

    def _log_errors(label: str, errors: list[str]) -> None:
        # Show at most 12 errors to keep the log readable.
        log(f"{label} failed with {len(errors)} error(s). Repairing...")
        for e in errors[:12]:
            log(f" - {e}")
        if len(errors) > 12:
            log(f" ... (+{len(errors)-12} more)")

    def _repair(current_raw: str, errors: list[str]) -> str:
        # One LLM repair round; always snapshot the raw output for debugging.
        fixed = repair_with_llm(
            client=client,
            model=model,
            schema=schema,
            repair_template=repair_template,
            user_input=user_input,
            current_raw=current_raw,
            errors=errors,
            max_output_tokens=max_output_tokens,
        )
        last_response_path.write_text(fixed, encoding="utf-8")
        return fixed

    Path("outputs").mkdir(parents=True, exist_ok=True)
    log(f"Calling LLM (model={model}, max_output_tokens={max_output_tokens})...")
    t0 = time.time()
    raw = generate_json_with_llm(
        client=client,
        model=model,
        full_prompt=full_prompt,
        schema=schema,
        max_output_tokens=max_output_tokens,
    )
    dt = time.time() - t0
    log(f"LLM returned in {dt:.1f}s. Output chars={len(raw)}")
    last_response_path.write_text(raw, encoding="utf-8")
    log("Wrote outputs/last_raw_response.txt")
    for attempt in range(retries):
        log(f"Validate/repair attempt {attempt+1}/{retries}")
        # Phase 1: JSON parsing.
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError as e:
            log(f"JSON decode error: {e}. Repairing...")
            raw = _repair(raw, [f"JSON decode error: {e}"])
            continue
        if not isinstance(obj, dict):
            log("Top-level is not a JSON object. Repairing...")
            raw = _repair(raw, ["Top-level JSON must be an object"])
            continue
        # Phase 2: Deterministic patches (each may report its own errors).
        obj, patch_errors_0 = patch_fill_required_keys(obj)
        obj, patch_errors_1 = patch_lowercase_names(obj)
        obj, patch_errors_2 = patch_sanitize_network_names(obj)
        raw = json.dumps(obj, ensure_ascii=False)
        # Phase 3: Basic structural validation.
        basic_errors = patch_errors_0 + patch_errors_1 + patch_errors_2 + validate_basic(obj)
        if basic_errors:
            _log_errors("Basic validation", basic_errors)
            raw = _repair(json.dumps(obj, ensure_ascii=False), basic_errors)
            continue
        # Phase 4: Save raw config and run build_config (Pydantic + enrich + semantic).
        write_json_file(raw_path, obj)
        log(f"Saved raw config -> {raw_path}")
        log("Running build_config (Pydantic + enrich + semantic validation)...")
        success, semantic_errors = run_build_config(
            raw_path=raw_path,
            out_dir=out_path.parent,
            skip_semantic=skip_semantic,
        )
        if success:
            log(f"SUCCESS: Configuration built and validated -> {out_path}")
            return
        # Semantic validation failed — repair and retry.
        _log_errors("Semantic validation", semantic_errors)
        raw = _repair(json.dumps(obj, ensure_ascii=False), semantic_errors)
    raise SystemExit(
        f"ERROR: Failed to generate valid configuration after {retries} attempts. "
        f"Check outputs/last_raw_response.txt for the last attempt."
    )
def main() -> None:
    """CLI entry point: load prompts and schema, build the LLM client, and
    dispatch to either the legacy pipeline or the integrated one."""
    load_dotenv()
    arg_parser = argparse.ArgumentParser(description="Generate configuration.json from file input.")
    arg_parser.add_argument("--prompt-file", default="prompts/prompt_json_generation.txt")
    arg_parser.add_argument("--input-file", default="prompts/input_testuale.txt")
    arg_parser.add_argument("--repair-prompt-file", default="prompts/prompt_repair.txt")
    arg_parser.add_argument("--schema-file", default="models/schemas/ics_simlab_config_schema_v1.json")
    arg_parser.add_argument("--model", default="gpt-5-mini")
    arg_parser.add_argument("--out", default="outputs/configuration.json")
    arg_parser.add_argument("--retries", type=int, default=3)
    arg_parser.add_argument("--skip-enrich", action="store_true",
                            help="Skip build_config enrichment (output raw LLM config)")
    arg_parser.add_argument("--skip-semantic", action="store_true",
                            help="Skip semantic validation in build_config")
    opts = arg_parser.parse_args()

    # Fail fast before doing any file I/O if the API key is missing.
    if not os.getenv("OPENAI_API_KEY"):
        raise SystemExit("OPENAI_API_KEY non è impostata. Esegui: export OPENAI_API_KEY='...'")

    generation_template = read_text_file(Path(opts.prompt_file))
    scenario_text = read_text_file(Path(opts.input_file))
    repair_template = read_text_file(Path(opts.repair_prompt_file))
    prompt = build_prompt(generation_template, scenario_text)

    schema_path = Path(opts.schema_file)
    schema = load_json_schema(schema_path)
    if schema is not None:
        log(f"Structured Outputs ENABLED (schema loaded): {schema_path}")
    else:
        log(f"Structured Outputs DISABLED (schema not found/invalid): {schema_path}")

    client = OpenAI()
    final_path = Path(opts.out)
    raw_config_path = final_path.parent / "configuration_raw.json"

    if opts.skip_enrich:
        # Legacy pipeline: no build_config (no enrichment / semantic checks).
        from services.pipeline import run_pipeline
        run_pipeline(
            client=client,
            model=opts.model,
            full_prompt=prompt,
            schema=schema,
            repair_template=repair_template,
            user_input=scenario_text,
            out_path=final_path,
            retries=opts.retries,
            max_output_tokens=MAX_OUTPUT_TOKENS,
        )
        log(f"Output (raw LLM): {final_path}")
        return

    # Integrated pipeline: semantic validation participates in the repair loop.
    run_pipeline_with_semantic_validation(
        client=client,
        model=opts.model,
        full_prompt=prompt,
        schema=schema,
        repair_template=repair_template,
        user_input=scenario_text,
        raw_path=raw_config_path,
        out_path=final_path,
        retries=opts.retries,
        max_output_tokens=MAX_OUTPUT_TOKENS,
        skip_semantic=opts.skip_semantic,
    )


if __name__ == "__main__":
    main()