265 lines
7.6 KiB
Python
265 lines
7.6 KiB
Python
"""
|
|
Safe expression evaluation using AST parsing.
|
|
|
|
This module provides safe_eval() for evaluating expressions from control plans.
|
|
Only a whitelist of safe AST nodes and function calls are allowed.
|
|
|
|
Security model:
|
|
- Parse expression with ast.parse(mode='eval')
|
|
- Walk AST and verify all nodes are in the whitelist
|
|
- If valid, compile and eval with a restricted builtins table as globals and the caller's namespace as locals
|
|
|
|
Allowed:
|
|
- Constants: numbers, strings, booleans, None
|
|
- Names: variable references (resolved from provided namespace)
|
|
- BinOp: +, -, *, /, //, %, **
|
|
- UnaryOp: -, +, not
|
|
- BoolOp: and, or
|
|
- Compare: ==, !=, <, <=, >, >=, in, not in
|
|
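- Call: only allowlisted functions (min, max, abs, int, float, bool, clamp), called by bare name with positional arguments only (keyword and *-unpacked arguments are rejected)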
|
|
- IfExp: ternary (x if cond else y)
|
|
|
|
Forbidden:
|
|
- Attribute access (obj.attr)
|
|
- Subscript (obj[key])
|
|
- Lambda
|
|
- Comprehensions (list, dict, set, generator)
|
|
- Import
|
|
- Call to non-allowlisted functions
|
|
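- Assignment expressions (:=)

Note: validation is structural only. It does not bound evaluation cost, so an
allowed expression such as ``10 ** 10 ** 10`` can still exhaust CPU or memory.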
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
from typing import Any, Dict, Optional, Set
|
|
|
|
|
|
class UnsafeExpressionError(Exception):
    """Signal that an expression uses a construct outside the safe allowlist.

    Raised by ``validate_expression`` (and therefore by ``safe_eval``) when a
    parsed expression contains a disallowed AST node type or a call to a
    function that is not allowlisted.
    """
|
|
|
|
|
|
# Names allowed as the callee of a Call node; every other call is rejected.
# ``clamp`` is not a Python builtin: safe_eval maps it to the module's _clamp.
SAFE_FUNCTIONS: Set[str] = {"abs", "bool", "clamp", "float", "int", "max", "min"}
|
|
|
|
# Exact AST node types permitted anywhere in a parsed expression.
# validate_expression tests ``type(node)`` membership, so anything absent here
# (Attribute, Subscript, Lambda, comprehensions, keyword, Starred, NamedExpr,
# ...) is rejected.
SAFE_NODES: Set[type] = {
    # Expression structure: eval-mode root, literals, names (read context),
    # arithmetic/unary/boolean ops, comparisons, calls (further restricted to
    # SAFE_FUNCTIONS) and the ternary ``x if cond else y``.
    ast.Expression, ast.Constant, ast.Name, ast.Load,
    ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare, ast.Call, ast.IfExp,
    # Binary operators: +  -  *  /  //  %  **
    ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
    # Unary operators: -  +  not
    ast.USub, ast.UAdd, ast.Not,
    # Boolean operators: and  or
    ast.And, ast.Or,
    # Comparison operators: ==  !=  <  <=  >  >=  in  not in
    ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.In, ast.NotIn,
}
|
|
|
|
|
|
def _clamp(x: float, lo: float, hi: float) -> float:
    """Clamp x to the range [lo, hi].

    The lower bound is checked first, so when ``lo > hi`` a value below ``lo``
    yields ``lo``. A value that compares False against both bounds (e.g. NaN)
    is returned unchanged. Exposed to expressions under the name ``clamp``.
    """
    if x < lo:
        return lo
    if x > hi:
        return hi
    return x
|
|
|
|
|
|
def validate_expression(expr: str) -> None:
    """
    Validate that an expression is safe to evaluate.

    The expression is parsed in ``eval`` mode and every node of the resulting
    AST must have a type listed in ``SAFE_NODES``. ``Call`` nodes are further
    restricted: the callee must be a bare name (no ``obj.method(...)``) that
    appears in ``SAFE_FUNCTIONS``. Keyword and ``*``-unpacked arguments are
    rejected as well, because ``ast.keyword`` / ``ast.Starred`` are not in
    ``SAFE_NODES``.

    The check is purely structural: it does not bound the cost of evaluation,
    so an allowed expression such as ``10 ** 10 ** 10`` can still consume
    excessive CPU time or memory when evaluated.

    Args:
        expr: Python expression source text.

    Raises:
        UnsafeExpressionError: if expression contains unsafe constructs
        SyntaxError: if expression is not valid Python; the parser's original
            error is chained as ``__cause__``.
    """
    try:
        tree = ast.parse(expr, mode='eval')
    except SyntaxError as e:
        # Chain explicitly so the parser's lineno/offset details stay reachable.
        raise SyntaxError(f"Invalid expression: {e}") from e

    for node in ast.walk(tree):
        node_type = type(node)

        # Exact-type allowlist: anything not listed is rejected outright.
        if node_type not in SAFE_NODES:
            raise UnsafeExpressionError(
                f"Unsafe AST node type: {node_type.__name__} in expression: {expr}"
            )

        # Calls additionally need a bare-name callee from the function allowlist.
        if isinstance(node, ast.Call):
            # A non-Name callee would be e.g. an attribute or a call result.
            if not isinstance(node.func, ast.Name):
                raise UnsafeExpressionError(
                    f"Unsafe function call (not a simple name) in expression: {expr}"
                )

            func_name = node.func.id
            if func_name not in SAFE_FUNCTIONS:
                raise UnsafeExpressionError(
                    f"Unsafe function call: {func_name} in expression: {expr}. "
                    f"Allowed: {', '.join(sorted(SAFE_FUNCTIONS))}"
                )
|
|
|
|
|
|
def safe_eval(expr: str, namespace: Dict[str, Any]) -> Any:
    """
    Evaluate an allowlisted expression against a caller-supplied namespace.

    The expression must first pass ``validate_expression``. It is then
    compiled and evaluated with ``namespace`` as locals and a globals dict
    whose ``__builtins__`` is replaced by a minimal table: ``min``, ``max``,
    ``abs``, ``int``, ``float``, ``bool``, ``clamp`` plus ``True``, ``False``
    and ``None``. Free names are therefore looked up in ``namespace`` first
    and then in that table only.

    Args:
        expr: Python expression string
        namespace: dict mapping variable names to values

    Returns:
        Whatever the expression evaluates to.

    Raises:
        UnsafeExpressionError: if expression contains unsafe constructs
        SyntaxError: if expression is not valid Python
        NameError: if expression references undefined variables
        Exception: for runtime errors (division by zero, etc.)
    """
    validate_expression(expr)

    # Overriding __builtins__ hides every real builtin not listed here.
    allowed_builtins: Dict[str, Any] = {
        "min": min,
        "max": max,
        "abs": abs,
        "int": int,
        "float": float,
        "bool": bool,
        "True": True,
        "False": False,
        "None": None,
        "clamp": _clamp,
    }
    restricted_globals: Dict[str, Any] = {"__builtins__": allowed_builtins}

    compiled = compile(expr, "<control_plan>", "eval")
    result = eval(compiled, restricted_globals, namespace)
    return result
|
|
|
|
|
|
def safe_eval_condition(expr: str, namespace: Dict[str, Any]) -> bool:
    """
    Evaluate ``expr`` via ``safe_eval`` and coerce the outcome to ``bool``.

    Accepts the same arguments and raises the same exceptions as ``safe_eval``;
    only the return value is normalised with ``bool()``.
    """
    return bool(safe_eval(expr, namespace))
|
|
|
|
|
|
def extract_variable_names(expr: str) -> Set[str]:
    """
    Extract all variable names referenced in an expression.

    This is useful for validation: checking that all referenced
    variables exist in the namespace before runtime. Names listed in
    ``SAFE_FUNCTIONS`` and the names ``True``/``False``/``None`` are excluded,
    since they do not come from the caller's namespace.

    This function only parses the expression; it does not check that the
    expression is safe (use ``validate_expression`` for that).

    Returns:
        Set of variable names referenced in the expression

    Raises:
        SyntaxError: if expression is not valid Python; the parser's original
            error is chained as ``__cause__``.
    """
    try:
        tree = ast.parse(expr, mode='eval')
    except SyntaxError as e:
        # Chain explicitly so the parser's lineno/offset details stay reachable.
        raise SyntaxError(f"Invalid expression: {e}") from e

    # Built once rather than per node. True/False/None parse as ast.Constant on
    # modern Python; they are excluded here defensively.
    reserved = SAFE_FUNCTIONS | {"True", "False", "None"}

    return {
        node.id
        for node in ast.walk(tree)
        if isinstance(node, ast.Name)
        and isinstance(node.ctx, ast.Load)
        and node.id not in reserved
    }
|
|
|
|
|
|
def generate_python_code(expr: str, namespace_var: str = "pv", default: Any = 0) -> str:
    """
    Generate Python code for an expression, with variables read from a dict.

    This is used by the compiler to generate HIL logic code. The expression is
    validated with ``validate_expression`` first. Every variable reference
    ``x`` is then rewritten to ``<namespace_var>.get('x', <default>)``. Names
    listed in ``SAFE_FUNCTIONS`` (normally the callees of allowlisted calls,
    including ``clamp``) are left untouched. The generated code must therefore
    run where those names are defined.

    Args:
        expr: The expression string
        namespace_var: Name of the dict variable containing values. It is
            inserted into the generated code as-is and is not validated.
        default: Value used when a variable is missing from the dict. Defaults
            to ``0``, the previously hard-coded fallback. It is embedded as a
            literal, so it should be a number, string, bool or ``None``.

    Returns:
        Python code string that evaluates the expression

    Raises:
        UnsafeExpressionError: if the expression contains unsafe constructs
        SyntaxError: if the expression is not valid Python

    Example:
        >>> generate_python_code("x + y * 2", "physical_values")
        "physical_values.get('x', 0) + physical_values.get('y', 0) * 2"
        >>> generate_python_code("x > 1", "pv", default=None)
        "pv.get('x', None) > 1"
    """
    # First validate the expression
    validate_expression(expr)

    # Parse and transform
    tree = ast.parse(expr, mode='eval')

    class NameTransformer(ast.NodeTransformer):
        """Rewrite variable reads into ``namespace_var.get(name, default)`` calls."""

        def visit_Name(self, node: ast.Name) -> ast.AST:
            # Allowlisted function names (and legacy True/False/None names) stay as-is.
            if node.id in SAFE_FUNCTIONS or node.id in {"True", "False", "None"}:
                return node

            # Transform: x -> <namespace_var>.get('x', <default>)
            if isinstance(node.ctx, ast.Load):
                return ast.Call(
                    func=ast.Attribute(
                        value=ast.Name(id=namespace_var, ctx=ast.Load()),
                        attr='get',
                        ctx=ast.Load(),
                    ),
                    args=[
                        ast.Constant(value=node.id),
                        ast.Constant(value=default),
                    ],
                    keywords=[],
                )
            return node

    new_tree = NameTransformer().visit(tree)
    # Synthesised nodes lack lineno/col_offset; fill them in before unparsing.
    ast.fix_missing_locations(new_tree)

    # Unparse only the expression body, not the Expression wrapper.
    return ast.unparse(new_tree.body)
|