""" Safe expression evaluation using AST parsing. This module provides safe_eval() for evaluating expressions from control plans. Only a whitelist of safe AST nodes and function calls are allowed. Security model: - Parse expression with ast.parse(mode='eval') - Walk AST and verify all nodes are in the whitelist - If valid, compile and eval with restricted locals/globals Allowed: - Constants: numbers, strings, booleans, None - Names: variable references (resolved from provided namespace) - BinOp: +, -, *, /, //, %, ** - UnaryOp: -, +, not - BoolOp: and, or - Compare: ==, !=, <, <=, >, >=, in, not in - Call: only allowlisted functions (min, max, abs, int, float, bool, clamp) - IfExp: ternary (x if cond else y) Forbidden: - Attribute access (obj.attr) - Subscript (obj[key]) - Lambda - Comprehensions (list, dict, set, generator) - Import - Call to non-allowlisted functions - Assignment expressions (:=) """ from __future__ import annotations import ast from typing import Any, Dict, Optional, Set class UnsafeExpressionError(Exception): """Raised when an expression contains unsafe AST nodes.""" pass # Allowlisted function names that can be called in expressions SAFE_FUNCTIONS: Set[str] = { "min", "max", "abs", "int", "float", "bool", "clamp", # Custom clamp function provided in builtins } # Allowlisted AST node types SAFE_NODES: Set[type] = { ast.Expression, # Top-level for mode='eval' ast.Constant, # Literals (numbers, strings, booleans, None) ast.Name, # Variable references ast.Load, # Load context for names ast.BinOp, # Binary operations ast.UnaryOp, # Unary operations ast.BoolOp, # Boolean operations (and, or) ast.Compare, # Comparisons ast.Call, # Function calls (restricted to SAFE_FUNCTIONS) ast.IfExp, # Ternary: x if cond else y # Operators ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.USub, # Unary minus ast.UAdd, # Unary plus ast.Not, # not ast.And, ast.Or, ast.Eq, ast.NotEq, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.In, ast.NotIn, } def _clamp(x: float, lo: float, hi: float) -> float: """Clamp x to the range [lo, hi].""" return lo if x < lo else hi if x > hi else x def validate_expression(expr: str) -> None: """ Validate that an expression is safe to evaluate. Raises: UnsafeExpressionError: if expression contains unsafe constructs SyntaxError: if expression is not valid Python """ try: tree = ast.parse(expr, mode='eval') except SyntaxError as e: raise SyntaxError(f"Invalid expression: {e}") for node in ast.walk(tree): node_type = type(node) # Check if node type is allowed if node_type not in SAFE_NODES: raise UnsafeExpressionError( f"Unsafe AST node type: {node_type.__name__} in expression: {expr}" ) # Special check for Call nodes: only allow safe functions if isinstance(node, ast.Call): # Function must be a simple Name (no attribute access) if not isinstance(node.func, ast.Name): raise UnsafeExpressionError( f"Unsafe function call (not a simple name) in expression: {expr}" ) func_name = node.func.id if func_name not in SAFE_FUNCTIONS: raise UnsafeExpressionError( f"Unsafe function call: {func_name} in expression: {expr}. " f"Allowed: {', '.join(sorted(SAFE_FUNCTIONS))}" ) def safe_eval(expr: str, namespace: Dict[str, Any]) -> Any: """ Safely evaluate an expression with the given namespace. Args: expr: Python expression string namespace: dict mapping variable names to values Returns: The result of evaluating the expression Raises: UnsafeExpressionError: if expression contains unsafe constructs SyntaxError: if expression is not valid Python NameError: if expression references undefined variables Exception: for runtime errors (division by zero, etc.) """ # Validate expression safety validate_expression(expr) # Build safe globals with only our clamp function and builtins safe_globals: Dict[str, Any] = { "__builtins__": { "min": min, "max": max, "abs": abs, "int": int, "float": float, "bool": bool, "True": True, "False": False, "None": None, "clamp": _clamp, } } # Compile and evaluate code = compile(expr, "", "eval") return eval(code, safe_globals, namespace) def safe_eval_condition(expr: str, namespace: Dict[str, Any]) -> bool: """ Safely evaluate a boolean condition. Same as safe_eval but ensures result is converted to bool. """ result = safe_eval(expr, namespace) return bool(result) def extract_variable_names(expr: str) -> Set[str]: """ Extract all variable names referenced in an expression. This is useful for validation: checking that all referenced variables exist in the namespace before runtime. Returns: Set of variable names referenced in the expression Raises: SyntaxError: if expression is not valid Python """ try: tree = ast.parse(expr, mode='eval') except SyntaxError as e: raise SyntaxError(f"Invalid expression: {e}") names: Set[str] = set() for node in ast.walk(tree): if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load): # Skip function names (they're provided in builtins) if node.id not in SAFE_FUNCTIONS and node.id not in {"True", "False", "None"}: names.add(node.id) return names def generate_python_code(expr: str, namespace_var: str = "pv") -> str: """ Generate Python code for an expression, with variables read from a dict. This is used by the compiler to generate HIL logic code. Args: expr: The expression string namespace_var: Name of the dict variable containing values Returns: Python code string that evaluates the expression Example: >>> generate_python_code("x + y * 2", "physical_values") "physical_values.get('x', 0) + physical_values.get('y', 0) * 2" """ # First validate the expression validate_expression(expr) # Parse and transform tree = ast.parse(expr, mode='eval') class NameTransformer(ast.NodeTransformer): """Transform Name nodes to dict.get() calls.""" def visit_Name(self, node: ast.Name) -> ast.AST: # Skip function names and builtins if node.id in SAFE_FUNCTIONS or node.id in {"True", "False", "None"}: return node # Transform: x -> pv.get('x', 0) if isinstance(node.ctx, ast.Load): return ast.Call( func=ast.Attribute( value=ast.Name(id=namespace_var, ctx=ast.Load()), attr='get', ctx=ast.Load() ), args=[ ast.Constant(value=node.id), ast.Constant(value=0) ], keywords=[] ) return node # Transform the tree transformer = NameTransformer() new_tree = transformer.visit(tree) ast.fix_missing_locations(new_tree) # Convert back to code return ast.unparse(new_tree.body)