Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 49 additions & 53 deletions freeride/formula.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,41 @@
"""
Formula module using sympy
"""
"""Parse linear and quadratic formula strings."""

import re
import ast
import operator
import re

from freeride.exceptions import FormulaParseError

_ALLOWED_BINARY_OPERATORS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
}


def _eval_ast_node(node, allowed_names, values):
"""Recursively evaluate an AST node with restricted operations."""
if isinstance(node, ast.Expression):
return _eval_ast_node(node.body, allowed_names, values)

if isinstance(node, ast.BinOp):
if isinstance(node.op, (ast.Add, ast.Sub, ast.Mult)):
left = _eval_ast_node(node.left, allowed_names, values)
right = _eval_ast_node(node.right, allowed_names, values)
if isinstance(node.op, ast.Add):
return left + right
if isinstance(node.op, ast.Sub):
return left - right
if isinstance(node.op, ast.Mult):
return left * right

if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
result = _eval_ast_node(node.body, allowed_names, values)
elif isinstance(node, ast.BinOp) and type(node.op) in _ALLOWED_BINARY_OPERATORS:
left = _eval_ast_node(node.left, allowed_names, values)
right = _eval_ast_node(node.right, allowed_names, values)
result = _ALLOWED_BINARY_OPERATORS[type(node.op)](left, right)
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
operand = _eval_ast_node(node.operand, allowed_names, values)
return operand if isinstance(node.op, ast.UAdd) else -operand

if isinstance(node, ast.Num): # Python <3.8 compatibility
return node.n

if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
return node.value

if isinstance(node, ast.Name):
result = operand if isinstance(node.op, ast.UAdd) else -operand
elif isinstance(node, ast.Num): # Python <3.8 compatibility
result = node.n
elif isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
result = node.value
elif isinstance(node, ast.Name):
if node.id not in allowed_names:
raise FormulaParseError(f"Invalid variable '{node.id}' in equation")
return values.get(node.id, 0.0)

raise FormulaParseError("Invalid syntax in equation")
result = values.get(node.id, 0.0)
else:
raise FormulaParseError("Invalid syntax in equation")

return result


def _formula(equation: str):
Expand Down Expand Up @@ -121,34 +117,34 @@ def _safe_eval(expr: str, **values: float) -> float:
if "y" in expr:
raise FormulaParseError("y must be expressed solely in terms of x")
intercept = _safe_eval(expr, x=0, y=0)
y1 = _safe_eval(expr, x=1, y=0)
slope = y1 - intercept
y_at_one = _safe_eval(expr, x=1, y=0)
slope = y_at_one - intercept
return float(intercept), float(slope)

# x = f(y)
if lhs == "x" or rhs == "x":
expr = rhs if lhs == "x" else lhs
if "x" in expr:
raise FormulaParseError("x must be expressed solely in terms of y")
b = _safe_eval(expr, x=0, y=0)
x1 = _safe_eval(expr, x=0, y=1)
m = x1 - b
if m == 0:
x_intercept_at_zero = _safe_eval(expr, x=0, y=0)
x_at_one = _safe_eval(expr, x=0, y=1)
inverse_slope = x_at_one - x_intercept_at_zero
if inverse_slope == 0:
raise FormulaParseError("Zero slope invalid for x=f(y)")
intercept = -b / m
slope = 1 / m
intercept = -x_intercept_at_zero / inverse_slope
slope = 1 / inverse_slope
return float(intercept), float(slope)

# General form ax + by = c
expr = f"{lhs}-({rhs})"
base = _safe_eval(expr, x=0, y=0)
a = _safe_eval(expr, x=1, y=0) - base
b = _safe_eval(expr, x=0, y=1) - base
c = -base
if b == 0:
x_coefficient = _safe_eval(expr, x=1, y=0) - base
y_coefficient = _safe_eval(expr, x=0, y=1) - base
constant = -base
if y_coefficient == 0:
raise FormulaParseError("Equation does not define y as a function of x")
intercept = c / b
slope = -a / b
intercept = constant / y_coefficient
slope = -x_coefficient / y_coefficient
return float(intercept), float(slope)


Expand Down Expand Up @@ -218,29 +214,29 @@ def _quadratic_formula(equation: str):
)

# Use regex to find terms
terms = re.findall(
(r"([+-]?(?:\d*\.)?\d*x²|[+-]?(?:\d*\.)?\d*x|" r"[+-]?(?:\d*\.)?\d+)"),
left_side,
)
term_pattern = r"[+-]?(?:\d*\.)?\d*x²|[+-]?(?:\d*\.)?\d*x|[+-]?(?:\d*\.)?\d+"
terms = re.findall(term_pattern, left_side)

a, b, c = 0, 0, 0
quadratic_coef = 0.0
linear_coef = 0.0
constant = 0.0

for term in terms:
if "x²" in term:
coef = term.replace("x²", "")
a = (
quadratic_coef = (
float(coef)
if coef and coef not in ("+", "-")
else (1 if coef in ("", "+") else -1)
)
elif "x" in term:
coef = term.replace("x", "")
b = (
linear_coef = (
float(coef)
if coef and coef not in ("+", "-")
else (1 if coef in ("", "+") else -1)
)
elif term:
c += float(term)
constant += float(term)

return a, b, c
return quadratic_coef, linear_coef, constant
Loading