Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ dev = [
"jupyterlab>=4.4.4",
"forallpeople>=2.7.1",
"black>=25.1.0",
"numpy>=2.4.4",
]
5 changes: 4 additions & 1 deletion src/handcalcs/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
"greek_exclusions": [],
"param_columns": 3,
"preferred_string_formatter": "L",
"array_truncate_threshold": -1,
"array_truncate_end": 3,
"matrix_environment": "bmatrix",
"custom_symbols": {},
"custom_brackets": {}
}
}
224 changes: 203 additions & 21 deletions src/handcalcs/handcalcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import os
import pathlib
import re
from typing import Any, Union, Optional, Tuple, List
from typing import Any, Sequence, Union, Optional, Tuple, List
import pyparsing as pp

from handcalcs.constants import GREEK_UPPER, GREEK_LOWER
Expand Down Expand Up @@ -820,7 +820,19 @@ def convert_conditional(line, calculated_results, **config_options):

@convert_line.register(ParameterLine)
def convert_parameter(line, calculated_results, **config_options):
# split_parameter_line stores the raw computed Python value (e.g. a numpy
# array) as the third token of line.line. The symbolic processing pipeline
# operates entirely on string tokens; non-string objects cause crashes in
# functions like swap_for_greek that call str() on every token.
# Solution: extract the value before symbolic processing and restore it
# afterward. round_and_render_parameter then calls latex_repr on the raw
# value with the correct cell-level precision.
raw_value = None
if len(line.line) == 3 and not isinstance(line.line[-1], (str, deque)):
raw_value = line.line.pop()
line.line = swap_symbolic_calcs(line.line, calculated_results, **config_options)
if raw_value is not None:
line.line.append(raw_value)
return line


Expand Down Expand Up @@ -1210,37 +1222,105 @@ def render_latex_str(
return outgoing


def _is_2d_array(item) -> bool:
"""Returns True if item is a 2D sequence (list-of-lists, 2D numpy array, etc.)."""
if isinstance(item, (str, dict)) or not hasattr(item, "__len__"):
return False
try:
if len(item) == 0:
return False
except TypeError:
# Objects like numpy 0-d arrays and scalar pint Quantities define __len__
# but raise TypeError when called; treat them as non-sequences.
return False
if hasattr(item, "ndim"): # numpy / numpy-like: use ndim for reliable detection
return item.ndim == 2
first = item[0]
return hasattr(first, "__len__") and not isinstance(first, (str, dict))


def set_truncation(threshold=None, end=None) -> None:
"""Set array / matrix truncation options, accepting ``int`` or ``(row, col)`` tuples.

Unlike ``global_config.set_option``, this helper is not restricted to the
type stored in the config, so you can freely switch between a plain ``int``
and a per-dimension tuple without hitting a type mismatch error.

Parameters
----------
threshold : int or (int, int), optional
Maximum elements (1-D) or ``(row_limit, col_limit)`` (2-D) before
truncation. ``-1`` disables that dimension.
end : int or (int, int), optional
Number of tail elements shown after the ellipsis.

Examples
--------
>>> set_truncation(5) # scalar — same limit for rows and columns
>>> set_truncation((5, 3)) # tuple — 5 rows, 3 columns
>>> set_truncation((-1, 3)) # columns only (rows off)
>>> set_truncation(-1) # disable all truncation
"""
if threshold is not None:
global_config._config["array_truncate_threshold"] = threshold
if end is not None:
global_config._config["array_truncate_end"] = end


def _unpack_dim(value, dim: int) -> int:
"""Return the scalar for *dim* from either an ``int`` or a 2-tuple.

*value* is ``int`` or ``(row_value, col_value)``. *dim* is 0 for rows,
1 for columns.
"""
if isinstance(value, (tuple, list)) and len(value) == 2:
return int(value[dim])
return int(value)


def _extract_array_units(
item: Sequence[Any],
use_scientific_notation: bool,
precision: int,
preferred_formatter: str,
) -> tuple:
"""Strip pint units from a flat sequence, returning (magnitudes, unit_latex or None)."""
if len(item) > 0 and all(hasattr(v, "units") for v in item):
first_unit = item[0].units
if all(getattr(v, "units", None) == first_unit for v in item):
magnitudes = [v.magnitude if hasattr(v, "magnitude") else v for v in item]
unit_str = latex_repr(
first_unit, use_scientific_notation, precision, preferred_formatter
)
return magnitudes, unit_str
return item, None


def latex_repr(
item: Any, use_scientific_notation: bool, precision: int, preferred_formatter: str
) -> str:
"""
Return a str if the object, 'item', has a special repr method
for rendering itself in latex. If not, returns str(result).
"""
# Check for arrays
if hasattr(item, "__len__") and not isinstance(item, (str, dict)):
comma_space = ",\\ "
# Sympy checked first: sympy.Matrix has __len__ and would fall into the array
# path otherwise, producing an incorrect flat sequence instead of matrix notation.
if hasattr(item, "__sympy__"):
return render_sympy(round_sympy(item, precision, use_scientific_notation))

# 2D arrays rendered as a LaTeX matrix environment (no newlines — MathJax-safe).
if _is_2d_array(item):
try:
array = (
"["
+ comma_space.join(
[
latex_repr(
v, use_scientific_notation, precision, preferred_formatter
)
for v in item
]
)
+ "]"
)
rendered_string = array
return rendered_string
return render_matrix(item, use_scientific_notation, precision, preferred_formatter)
except TypeError:
pass

# Check for sympy objects
if hasattr(item, "__sympy__"):
return render_sympy(round_sympy(item, precision, use_scientific_notation))
# 1D arrays rendered as a bracketed sequence.
if hasattr(item, "__len__") and not isinstance(item, (str, dict)):
try:
return render_array(item, use_scientific_notation, precision, preferred_formatter)
except TypeError:
pass

# Check for scientific notation strings
if isinstance(item, str) and test_for_scientific_float(item):
Expand Down Expand Up @@ -1286,6 +1366,108 @@ def latex_repr(
return rendered_string.replace("$", "")


def render_array(
item: Sequence[Any],
use_scientific_notation: bool,
precision: int,
preferred_formatter: str,
) -> str:
"""Render a 1D list-like value as a bracketed LaTeX sequence."""
threshold = _unpack_dim(global_config._config.get("array_truncate_threshold", -1), 0)
tail = _unpack_dim(global_config._config.get("array_truncate_end", 3), 0)

display_items, unit_str = _extract_array_units(
item, use_scientific_notation, precision, preferred_formatter
)

def render(v):
return latex_repr(v, use_scientific_notation, precision, preferred_formatter)

if threshold != -1 and len(display_items) > threshold:
tail_count = min(tail, threshold)
head_count = threshold - tail_count
head_rendered = [render(v) for v in display_items[:head_count]]
tail_rendered = [render(v) for v in display_items[-tail_count:]] if tail_count > 0 else []
rendered = head_rendered + ["\\ldots"] + tail_rendered
else:
rendered = [render(v) for v in display_items]

result = "[" + ",\\ ".join(rendered) + "]"
if unit_str:
result += " \\cdot " + unit_str
return result


def render_matrix(
item: Sequence[Any],
use_scientific_notation: bool,
precision: int,
preferred_formatter: str,
) -> str:
"""Render a 2D array-like value as a LaTeX matrix environment.

Environment: ``matrix_environment`` config (default ``bmatrix``).

Truncation uses ``array_truncate_threshold`` and ``array_truncate_end``.
Both accept either a scalar (same limit for rows and columns) or a 2-tuple
``(row_limit, col_limit)`` for independent per-dimension control::

global_config.set_option("array_truncate_threshold", 5) # rows and columns
global_config._config["array_truncate_threshold"] = (5, 3) # rows=5, cols=3

Truncated rows are replaced with a ``\\vdots`` / ``\\ddots`` row;
truncated columns are replaced with ``\\cdots``. Output contains no
literal newlines so it embeds safely inside MathJax / Jupyter's aligned env.
"""
env = global_config._config.get("matrix_environment", "bmatrix")
threshold_raw = global_config._config.get("array_truncate_threshold", -1)
tail_raw = global_config._config.get("array_truncate_end", 3)
row_threshold = _unpack_dim(threshold_raw, 0)
col_threshold = _unpack_dim(threshold_raw, 1)
row_tail = _unpack_dim(tail_raw, 0)
col_tail = _unpack_dim(tail_raw, 1)

def render(v):
return latex_repr(v, use_scientific_notation, precision, preferred_formatter)

def render_row(row):
cells = [render(v) for v in row]
if col_threshold != -1 and len(cells) > col_threshold:
c_tail = min(col_tail, col_threshold)
c_head = col_threshold - c_tail
head = cells[:c_head]
tail = cells[-c_tail:] if c_tail > 0 else []
cells = head + ["\\cdots"] + tail
return " & ".join(cells)

rows = list(item)
ncols = len(rows[0]) if rows else 1

# Build the ellipsis row used when rows are truncated.
# Use \ddots at the column-ellipsis position, \vdots elsewhere.
if col_threshold != -1 and ncols > col_threshold:
c_tail = min(col_tail, col_threshold)
c_head = col_threshold - c_tail
ellipsis_row = " & ".join(
["\\vdots"] * c_head + ["\\ddots"] + (["\\vdots"] * c_tail if c_tail > 0 else [])
)
else:
ellipsis_row = " & ".join(["\\vdots"] * ncols)

if row_threshold != -1 and len(rows) > row_threshold:
r_tail = min(row_tail, row_threshold)
r_head = row_threshold - r_tail
head_rows = [render_row(r) for r in rows[:r_head]]
tail_rows = [render_row(r) for r in rows[-r_tail:]] if r_tail > 0 else []
rendered_rows = head_rows + [ellipsis_row] + tail_rows
else:
rendered_rows = [render_row(r) for r in rows]

# " \\ " is the LaTeX matrix row separator; no literal \n so MathJax-safe.
body = " \\\\ ".join(rendered_rows)
return f"\\begin{{{env}}}{body}\\end{{{env}}}"


def round_sympy(elem: Any, precision: int, use_scientific_notation: bool) -> Any:
"""
Returns the Sympy expression 'elem' rounded to 'precision'
Expand Down
Loading