#!/usr/bin/env python3
"""
prompts/input_testuale.txt -> LLM -> build_config -> outputs/configuration.json

Pipeline:
1. LLM generates the raw configuration
2. JSON validation + basic patches
3. build_config: Pydantic validate -> enrich -> semantic validate
4. If semantic errors occur, repair with the LLM and loop back
5. Output: configuration.json (complete version)
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Optional

from dotenv import load_dotenv
from openai import OpenAI

from helpers.helper import load_json_schema, log, read_text_file, write_json_file
from services.generation import generate_json_with_llm, repair_with_llm
from services.patches import (
    patch_fill_required_keys,
    patch_lowercase_names,
    patch_sanitize_network_names,
)
from services.prompting import build_prompt
from services.validation import validate_basic

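# Upper bound on output tokens per LLM call (used for both generation and repair).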
MAX_OUTPUT_TOKENS = 5000


def run_build_config(
    raw_path: Path,
    out_dir: Path,
    skip_semantic: bool = False,
) -> tuple[bool, list[str]]:
    """
    Run build_config on a raw configuration file.

    Returns:
        (success, errors): success=True if build_config passed,
        errors=list of semantic error messages if it failed
    """
    cmd = [
        sys.executable,
        "-m",
        "tools.build_config",
        "--config", str(raw_path),
        "--out-dir", str(out_dir),
        "--overwrite",
        "--json-errors",
    ]
    if skip_semantic:
        cmd.append("--skip-semantic")

    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode == 0:
        return True, []
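
    # The error payload's shape is inferred from the parsing below (an
    # assumption about build_config's --json-errors output):
    #   {"semantic_errors": [{"entity": "...", "message": "..."}]}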
    # Exit code 2 = semantic validation failure with JSON output
    if result.returncode == 2:
        # Parse JSON errors from stdout (find the last JSON object)
        try:
            stdout = result.stdout
            # Look for the "semantic_errors" marker, then find the enclosing { before it
            marker = stdout.rfind('"semantic_errors"')
            if marker >= 0:
                json_start = stdout.rfind('{', 0, marker)
                if json_start >= 0:
                    error_data = json.loads(stdout[json_start:])
                    errors = [
                        f"SEMANTIC ERROR in {e['entity']}: {e['message']}"
                        for e in error_data.get("semantic_errors", [])
                    ]
                    return False, errors
        except (json.JSONDecodeError, KeyError, TypeError):
            # Malformed or unexpected payload: fall through to generic handling
            pass

    # Other failures (Pydantic, etc.)
    stderr = result.stderr.strip() if result.stderr else ""
    stdout = result.stdout.strip() if result.stdout else ""
    error_msg = stderr or stdout or f"build_config failed with exit code {result.returncode}"
    return False, [error_msg]
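
# Example call (hypothetical paths):
#   ok, errors = run_build_config(Path("outputs/configuration_raw.json"), Path("outputs"))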


def run_pipeline_with_semantic_validation(
    *,
    client: OpenAI,
    model: str,
    full_prompt: str,
    schema: Optional[dict[str, Any]],
    repair_template: str,
    user_input: str,
    raw_path: Path,
    out_path: Path,
    retries: int,
    max_output_tokens: int,
    skip_semantic: bool = False,
) -> None:
    """
    Run the full pipeline: LLM generation -> JSON validation -> build_config -> semantic validation.

    The loop repairs both JSON structure errors AND semantic errors.
    """
Path("outputs").mkdir(parents=True, exist_ok=True)

    log(f"Calling LLM (model={model}, max_output_tokens={max_output_tokens})...")
    t0 = time.time()
    raw = generate_json_with_llm(
        client=client,
        model=model,
        full_prompt=full_prompt,
        schema=schema,
        max_output_tokens=max_output_tokens,
    )
    dt = time.time() - t0
    log(f"LLM returned in {dt:.1f}s. Output chars={len(raw)}")
    Path("outputs/last_raw_response.txt").write_text(raw, encoding="utf-8")
    log("Wrote outputs/last_raw_response.txt")

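    # Each iteration: parse -> patch -> basic-validate -> build_config.
    # Any failure triggers one LLM repair round and restarts the loop body.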
    for attempt in range(retries):
        log(f"Validate/repair attempt {attempt+1}/{retries}")

        # Phase 1: JSON parsing
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError as e:
            log(f"JSON decode error: {e}. Repairing...")
            raw = repair_with_llm(
                client=client,
                model=model,
                schema=schema,
                repair_template=repair_template,
                user_input=user_input,
                current_raw=raw,
                errors=[f"JSON decode error: {e}"],
                max_output_tokens=max_output_tokens,
            )
            Path("outputs/last_raw_response.txt").write_text(raw, encoding="utf-8")
            continue

        if not isinstance(obj, dict):
            log("Top-level is not a JSON object. Repairing...")
            raw = repair_with_llm(
                client=client,
                model=model,
                schema=schema,
                repair_template=repair_template,
                user_input=user_input,
                current_raw=raw,
                errors=["Top-level JSON must be an object"],
                max_output_tokens=max_output_tokens,
            )
            Path("outputs/last_raw_response.txt").write_text(raw, encoding="utf-8")
            continue

        # Phase 2: Patches
        obj, patch_errors_0 = patch_fill_required_keys(obj)
        obj, patch_errors_1 = patch_lowercase_names(obj)
        obj, patch_errors_2 = patch_sanitize_network_names(obj)
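        # Keep `raw` in sync with the patched object so any later parse or
        # repair starts from the patched JSON, not the original LLM output.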
        raw = json.dumps(obj, ensure_ascii=False)

        # Phase 3: Basic validation
        basic_errors = patch_errors_0 + patch_errors_1 + patch_errors_2 + validate_basic(obj)

        if basic_errors:
            log(f"Basic validation failed with {len(basic_errors)} error(s). Repairing...")
            for e in basic_errors[:12]:
                log(f"  - {e}")
            if len(basic_errors) > 12:
                log(f"  ... (+{len(basic_errors)-12} more)")

            raw = repair_with_llm(
                client=client,
                model=model,
                schema=schema,
                repair_template=repair_template,
                user_input=user_input,
                current_raw=json.dumps(obj, ensure_ascii=False),
                errors=basic_errors,
                max_output_tokens=max_output_tokens,
            )
            Path("outputs/last_raw_response.txt").write_text(raw, encoding="utf-8")
            continue

        # Phase 4: Save raw config and run build_config (Pydantic + enrich + semantic)
        write_json_file(raw_path, obj)
        log(f"Saved raw config -> {raw_path}")

        log("Running build_config (Pydantic + enrich + semantic validation)...")
        success, semantic_errors = run_build_config(
            raw_path=raw_path,
            out_dir=out_path.parent,
            skip_semantic=skip_semantic,
        )

        if success:
            log(f"SUCCESS: Configuration built and validated -> {out_path}")
            return

        # Semantic validation failed: repair and retry
        log(f"Semantic validation failed with {len(semantic_errors)} error(s). Repairing...")
        for e in semantic_errors[:12]:
            log(f"  - {e}")
        if len(semantic_errors) > 12:
            log(f"  ... (+{len(semantic_errors)-12} more)")

        raw = repair_with_llm(
            client=client,
            model=model,
            schema=schema,
            repair_template=repair_template,
            user_input=user_input,
            current_raw=json.dumps(obj, ensure_ascii=False),
            errors=semantic_errors,
            max_output_tokens=max_output_tokens,
        )
        Path("outputs/last_raw_response.txt").write_text(raw, encoding="utf-8")

    raise SystemExit(
        f"ERROR: Failed to generate valid configuration after {retries} attempts. "
        f"Check outputs/last_raw_response.txt for the last attempt."
    )


def main() -> None:
    load_dotenv()

    parser = argparse.ArgumentParser(description="Generate configuration.json from file input.")
    parser.add_argument("--prompt-file", default="prompts/prompt_json_generation.txt")
    parser.add_argument("--input-file", default="prompts/input_testuale.txt")
    parser.add_argument("--repair-prompt-file", default="prompts/prompt_repair.txt")
    parser.add_argument("--schema-file", default="models/schemas/ics_simlab_config_schema_v1.json")
    parser.add_argument("--model", default="gpt-5-mini")
    parser.add_argument("--out", default="outputs/configuration.json")
    parser.add_argument("--retries", type=int, default=3)
    parser.add_argument("--skip-enrich", action="store_true",
                        help="Skip build_config enrichment (output raw LLM config)")
    parser.add_argument("--skip-semantic", action="store_true",
                        help="Skip semantic validation in build_config")
    args = parser.parse_args()

    if not os.getenv("OPENAI_API_KEY"):
        raise SystemExit("OPENAI_API_KEY is not set. Run: export OPENAI_API_KEY='...'")

    prompt_template = read_text_file(Path(args.prompt_file))
    user_input = read_text_file(Path(args.input_file))
    repair_template = read_text_file(Path(args.repair_prompt_file))
    full_prompt = build_prompt(prompt_template, user_input)

    schema_path = Path(args.schema_file)
    schema = load_json_schema(schema_path)
    if schema is None:
        log(f"Structured Outputs DISABLED (schema not found/invalid): {schema_path}")
    else:
        log(f"Structured Outputs ENABLED (schema loaded): {schema_path}")

    client = OpenAI()
    out_path = Path(args.out)
    raw_path = out_path.parent / "configuration_raw.json"
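    # The raw LLM config is staged next to the final output; build_config reads
    # it and writes the enriched configuration into the same directory.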

    if args.skip_enrich:
        # Use the old pipeline (no build_config)
        from services.pipeline import run_pipeline
        run_pipeline(
            client=client,
            model=args.model,
            full_prompt=full_prompt,
            schema=schema,
            repair_template=repair_template,
            user_input=user_input,
            out_path=out_path,
            retries=args.retries,
            max_output_tokens=MAX_OUTPUT_TOKENS,
        )
        log(f"Output (raw LLM): {out_path}")
    else:
        # Use the integrated pipeline with semantic validation in the repair loop
        run_pipeline_with_semantic_validation(
            client=client,
            model=args.model,
            full_prompt=full_prompt,
            schema=schema,
            repair_template=repair_template,
            user_input=user_input,
            raw_path=raw_path,
            out_path=out_path,
            retries=args.retries,
            max_output_tokens=MAX_OUTPUT_TOKENS,
            skip_semantic=args.skip_semantic,
        )


if __name__ == "__main__":
    main()