Skip to content

Commit bf72368

Browse files
Refactor _safe_eval to extract AST evaluation logic
- Extract _eval_ast_node as a module-level function for better testability - Simplify _safe_eval by removing nested function definition - Improve code readability and maintainability - Preserve all security benefits of AST-based evaluation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 4e4f171 commit bf72368

1 file changed

Lines changed: 38 additions & 35 deletions

File tree

freeride/formula.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,40 @@
88
from freeride.exceptions import FormulaParseError
99

1010

11+
def _eval_ast_node(node, allowed_names, values):
12+
"""Recursively evaluate an AST node with restricted operations."""
13+
if isinstance(node, ast.Expression):
14+
return _eval_ast_node(node.body, allowed_names, values)
15+
16+
if isinstance(node, ast.BinOp):
17+
if isinstance(node.op, (ast.Add, ast.Sub, ast.Mult)):
18+
left = _eval_ast_node(node.left, allowed_names, values)
19+
right = _eval_ast_node(node.right, allowed_names, values)
20+
if isinstance(node.op, ast.Add):
21+
return left + right
22+
if isinstance(node.op, ast.Sub):
23+
return left - right
24+
if isinstance(node.op, ast.Mult):
25+
return left * right
26+
27+
if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
28+
operand = _eval_ast_node(node.operand, allowed_names, values)
29+
return operand if isinstance(node.op, ast.UAdd) else -operand
30+
31+
if isinstance(node, ast.Num): # Python <3.8 compatibility
32+
return node.n
33+
34+
if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
35+
return node.value
36+
37+
if isinstance(node, ast.Name):
38+
if node.id not in allowed_names:
39+
raise FormulaParseError(f"Invalid variable '{node.id}' in equation")
40+
return values.get(node.id, 0.0)
41+
42+
raise FormulaParseError("Invalid syntax in equation")
43+
44+
1145
def _formula(equation: str):
1246
"""
1347
Parse a linear equation string and return an Affine object.
@@ -68,49 +102,18 @@ def _formula(equation: str):
68102
lhs, rhs = equation.split("=")
69103

70104
def _safe_eval(expr: str, **values: float) -> float:
71-
"""Safely evaluate arithmetic expressions for ``x`` and ``y``."""
72-
105+
"""Safely evaluate arithmetic expressions for x and y."""
73106
if not re.fullmatch(r"[0-9xy+\-*.()]+", expr):
74107
raise FormulaParseError("Invalid characters in equation")
75108

76109
try:
77-
node = ast.parse(expr, mode="eval")
110+
ast_tree = ast.parse(expr, mode="eval")
78111
except SyntaxError as exc:
79112
raise FormulaParseError("Invalid syntax in equation") from exc
80113

81114
allowed_names = {"x", "y"}
82-
83-
def _eval(n):
84-
if isinstance(n, ast.Expression):
85-
return _eval(n.body)
86-
if isinstance(n, ast.BinOp) and isinstance(
87-
n.op, (ast.Add, ast.Sub, ast.Mult)
88-
):
89-
left = _eval(n.left)
90-
right = _eval(n.right)
91-
if isinstance(n.op, ast.Add):
92-
return left + right
93-
if isinstance(n.op, ast.Sub):
94-
return left - right
95-
if isinstance(n.op, ast.Mult):
96-
return left * right
97-
if isinstance(n, ast.UnaryOp) and isinstance(n.op, (ast.UAdd, ast.USub)):
98-
val = _eval(n.operand)
99-
return val if isinstance(n.op, ast.UAdd) else -val
100-
if isinstance(n, ast.Num): # pragma: no cover - for Python <3.8
101-
return n.n
102-
if isinstance(n, ast.Constant): # for Python >=3.8
103-
if isinstance(n.value, (int, float)):
104-
return n.value
105-
if isinstance(n, ast.Name):
106-
if n.id not in allowed_names:
107-
raise FormulaParseError(
108-
f"Invalid variable '{n.id}' in equation"
109-
)
110-
return values.get(n.id, 0.0)
111-
raise FormulaParseError("Invalid syntax in equation")
112-
113-
return float(_eval(node))
115+
result = _eval_ast_node(ast_tree, allowed_names, values)
116+
return float(result)
114117

115118
# If equation is given explicitly as y = f(x)
116119
if lhs == "y" or rhs == "y":

0 commit comments

Comments
 (0)