ics-simlab-config-gen-claude/tools/safe_eval.py

265 lines
7.6 KiB
Python

"""
Safe expression evaluation using AST parsing.
This module provides safe_eval() for evaluating expressions from control plans.
Only allowlisted AST node types and allowlisted function calls are permitted.
Security model:
- Parse expression with ast.parse(mode='eval')
- Walk AST and verify all nodes are in the whitelist
- If valid, compile and eval with restricted locals/globals
Allowed:
- Constants: numbers, strings, booleans, None
- Names: variable references (resolved from provided namespace)
- BinOp: +, -, *, /, //, %, **
- UnaryOp: -, +, not
- BoolOp: and, or
- Compare: ==, !=, <, <=, >, >=, in, not in
- Call: only allowlisted functions (min, max, abs, int, float, bool, clamp)
- IfExp: ternary (x if cond else y)
Forbidden:
- Attribute access (obj.attr)
- Subscript (obj[key])
- Lambda
- Comprehensions (list, dict, set, generator)
- Import
- Call to non-allowlisted functions
- Assignment expressions (:=)
"""
from __future__ import annotations
import ast
from typing import Any, Dict, Optional, Set
class UnsafeExpressionError(Exception):
    """Signal that an expression uses a construct outside the safe subset.

    Raised during validation when the parsed expression contains an AST node
    type that is not allowlisted, or calls something other than an
    allowlisted function name.
    """
# Names an expression is permitted to call; a call to any other name is
# rejected during validation. "clamp" is not a Python builtin -- it is the
# custom clamp function supplied through the evaluation builtins.
SAFE_FUNCTIONS: Set[str] = {
    "abs",
    "bool",
    "clamp",
    "float",
    "int",
    "max",
    "min",
}
# AST node classes an expression may contain. Validation rejects an
# expression as soon as it meets a node whose exact type is not listed here.
SAFE_NODES: Set[type] = {
    # --- expression structure ---
    ast.Expression,  # root node produced by ast.parse(mode='eval')
    ast.Constant,    # literal values (numbers, strings, booleans, None)
    ast.Name,        # references to variables / function names
    ast.Load,        # "read" context attached to names
    ast.BinOp,       # binary arithmetic
    ast.UnaryOp,     # unary operations
    ast.BoolOp,      # and / or
    ast.Compare,     # comparisons
    ast.Call,        # calls; additionally limited to SAFE_FUNCTIONS
    ast.IfExp,       # x if cond else y
    # --- binary arithmetic operators ---
    ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
    # --- unary operators: +x, -x, not x ---
    ast.UAdd, ast.USub, ast.Not,
    # --- boolean operators ---
    ast.And, ast.Or,
    # --- comparison operators ---
    ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.In, ast.NotIn,
}
def _clamp(x: float, lo: float, hi: float) -> float:
    """Limit ``x`` to the closed interval [lo, hi].

    The lower bound is tested first, then the upper bound; a value that
    fails neither comparison is returned unchanged. Consequently, when
    ``lo > hi`` any ``x`` below ``lo`` yields ``lo``.
    """
    if x < lo:
        return lo
    if x > hi:
        return hi
    return x
def validate_expression(expr: str) -> None:
    """
    Validate that an expression is safe to evaluate.

    The expression is parsed in ``eval`` mode and every node of the resulting
    AST is checked against SAFE_NODES. Call nodes are additionally required
    to target a plain name that appears in SAFE_FUNCTIONS.

    NOTE(review): this bounds the *syntax* of the expression, not the cost of
    evaluating it. ``**`` and ``*`` on constants are allowlisted, so an
    expression such as ``9 ** 9 ** 9`` still passes validation.

    Args:
        expr: Python expression string to check.

    Raises:
        UnsafeExpressionError: if expression contains unsafe constructs
        SyntaxError: if expression is not valid Python
    """
    try:
        tree = ast.parse(expr, mode="eval")
    except (SyntaxError, ValueError) as e:
        # compile() is documented to raise ValueError (not SyntaxError) for
        # source containing null bytes; fold it into the documented
        # SyntaxError contract. ``from e`` keeps the parser's own error as
        # the explicit cause.
        raise SyntaxError(f"Invalid expression: {e}") from e

    for node in ast.walk(tree):
        node_type = type(node)

        # Exact-type match on purpose: a subclass of an allowed node is not
        # automatically allowed.
        if node_type not in SAFE_NODES:
            raise UnsafeExpressionError(
                f"Unsafe AST node type: {node_type.__name__} in expression: {expr}"
            )

        if not isinstance(node, ast.Call):
            continue

        # The call target must be a bare name; anything else (attribute
        # access, call result, ...) cannot be matched against the allowlist.
        if not isinstance(node.func, ast.Name):
            raise UnsafeExpressionError(
                f"Unsafe function call (not a simple name) in expression: {expr}"
            )

        func_name = node.func.id
        if func_name not in SAFE_FUNCTIONS:
            raise UnsafeExpressionError(
                f"Unsafe function call: {func_name} in expression: {expr}. "
                f"Allowed: {', '.join(sorted(SAFE_FUNCTIONS))}"
            )
def safe_eval(expr: str, namespace: Dict[str, Any]) -> Any:
    """
    Safely evaluate an expression with the given namespace.

    The expression is validated first, then compiled and evaluated with a
    ``__builtins__`` mapping that contains only the allowlisted functions.

    Names of allowlisted functions are reserved: a ``namespace`` entry with
    the same key (e.g. ``"max"``) is ignored during evaluation. Without this,
    the entry would be found before the builtins, and a call that validation
    approved by name would run whatever object the namespace holds.

    Args:
        expr: Python expression string
        namespace: dict mapping variable names to values; it is not modified

    Returns:
        The result of evaluating the expression

    Raises:
        UnsafeExpressionError: if expression contains unsafe constructs
        SyntaxError: if expression is not valid Python
        NameError: if expression references undefined variables
        Exception: for runtime errors (division by zero, etc.)
    """
    # Validate expression safety before anything is compiled.
    validate_expression(expr)

    # The only callables an expression can reach.
    safe_callables: Dict[str, Any] = {
        "min": min,
        "max": max,
        "abs": abs,
        "int": int,
        "float": float,
        "bool": bool,
        "clamp": _clamp,
    }

    # Supplying our own "__builtins__" stops eval() from injecting the real
    # builtins module into the globals.
    safe_globals: Dict[str, Any] = {
        "__builtins__": {
            **safe_callables,
            "True": True,
            "False": False,
            "None": None,
        }
    }

    # eval() resolves locals before builtins, so drop any namespace key that
    # would shadow an allowlisted function. The common case (no collision)
    # uses the caller's mapping directly, without copying.
    local_scope = namespace
    if any(name in namespace for name in safe_callables):
        local_scope = {
            key: value
            for key, value in namespace.items()
            if key not in safe_callables
        }

    # NOTE(review): this is a restricted eval(), not a sandbox for arbitrary
    # input; its safety rests entirely on validate_expression().
    code = compile(expr, "<control_plan>", "eval")
    return eval(code, safe_globals, local_scope)
def safe_eval_condition(expr: str, namespace: Dict[str, Any]) -> bool:
    """
    Evaluate ``expr`` via safe_eval() and coerce the outcome to ``bool``.

    Args:
        expr: Python expression string
        namespace: dict mapping variable names to values

    Returns:
        Truthiness of the evaluated expression.
    """
    return bool(safe_eval(expr, namespace))
def extract_variable_names(expr: str) -> Set[str]:
    """
    Extract all variable names referenced in an expression.

    This is useful for validation: checking that all referenced
    variables exist in the namespace before runtime.

    Names listed in SAFE_FUNCTIONS and the literals ``True`` / ``False`` /
    ``None`` are never reported. The expression is only parsed here; it is
    not checked for safety.

    Args:
        expr: Python expression string

    Returns:
        Set of variable names referenced in the expression

    Raises:
        SyntaxError: if expression is not valid Python
    """
    try:
        tree = ast.parse(expr, mode="eval")
    except (SyntaxError, ValueError) as e:
        # compile() is documented to raise ValueError for source containing
        # null bytes; report it through the documented SyntaxError, keeping
        # the parser's error as the explicit cause.
        raise SyntaxError(f"Invalid expression: {e}") from e

    # Function names are provided via builtins, so they are not variables.
    reserved = SAFE_FUNCTIONS | {"True", "False", "None"}
    return {
        node.id
        for node in ast.walk(tree)
        if isinstance(node, ast.Name)
        and isinstance(node.ctx, ast.Load)
        and node.id not in reserved
    }
def generate_python_code(expr: str, namespace_var: str = "pv") -> str:
    """
    Rewrite an expression so its variables are read from a dict.

    Every variable reference ``x`` becomes ``<namespace_var>.get('x', 0)``;
    names in SAFE_FUNCTIONS and ``True`` / ``False`` / ``None`` are left
    untouched. The expression is validated before it is rewritten.
    This is used by the compiler to generate HIL logic code.

    Args:
        expr: The expression string
        namespace_var: Name of the dict variable containing values

    Returns:
        Python code string that evaluates the expression

    Example:
        >>> generate_python_code("x + y * 2", "physical_values")
        "physical_values.get('x', 0) + physical_values.get('y', 0) * 2"
    """
    # Reject unsafe input before producing any code from it.
    validate_expression(expr)

    literal_names = {"True", "False", "None"}

    class _DictLookupRewriter(ast.NodeTransformer):
        """Swap each variable Name for a ``<namespace_var>.get(name, 0)`` call."""

        def visit_Name(self, node: ast.Name) -> ast.AST:
            reserved = node.id in SAFE_FUNCTIONS or node.id in literal_names
            if reserved or not isinstance(node.ctx, ast.Load):
                # Function names, literal names and non-read contexts stay put.
                return node

            getter = ast.Attribute(
                value=ast.Name(id=namespace_var, ctx=ast.Load()),
                attr='get',
                ctx=ast.Load(),
            )
            # Missing keys fall back to 0 in the generated code.
            lookup_args = [ast.Constant(value=node.id), ast.Constant(value=0)]
            return ast.Call(func=getter, args=lookup_args, keywords=[])

    parsed = ast.parse(expr, mode='eval')
    rewritten = _DictLookupRewriter().visit(parsed)
    # Freshly built nodes carry no line/column info; fill it in.
    ast.fix_missing_locations(rewritten)

    # Emit source for the expression itself, not the Expression wrapper.
    return ast.unparse(rewritten.body)