What I want to contribute
Add numerically stable ATen-to-Core resolver entries for softplus, mish, and logsumexp. These operations compute exp(x) internally, which overflows in fp16 once x exceeds about 11.09: the IEEE 754 half-precision maximum is 65504, and ln(65504) ≈ 11.09.
The existing replace_log_softmax in _aten_to_core.py (line 2541) already implements the stable max-shift decomposition. The proposed additions follow the same pattern.
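The overflow threshold quoted above can be checked with a few lines of standard-library Python. This is only an arithmetic sanity check in double precision, not a model of any particular hardware; the helper name `exp_fits_in_fp16` is made up for illustration.

```python
import math

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value

# exp(x) exceeds FP16_MAX exactly when x > ln(FP16_MAX).
threshold = math.log(FP16_MAX)


def exp_fits_in_fp16(x: float) -> bool:
    """Return True if exp(x), computed in double precision, is <= FP16_MAX."""
    try:
        return math.exp(x) <= FP16_MAX
    except OverflowError:
        # math.exp raises for inputs that overflow even a 64-bit float.
        return False


print(f"ln({FP16_MAX:.0f}) = {threshold:.4f}")
print("exp(11.0) fits:", exp_fits_in_fp16(11.0))
print("exp(11.1) fits:", exp_fits_in_fp16(11.1))
```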
Proposed decompositions
| Operation | Naive (unsafe) | Stable (proposed) |
|---|---|---|
| softplus | `log(1 + exp(x))` | `max(x, 0) + log(1 + exp(-abs(x)))` |
| mish | `x * tanh(log(1 + exp(x)))` | `x * tanh(stable_softplus(x))` |
| logsumexp | `log(sum(exp(x)))` | `max(x) + log(sum(exp(x - max(x))))` |
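The softplus decomposition only ever evaluates exp(-abs(x)), which lies in (0, 1] mathematically (it may round down to 0 in low precision, which is harmless), so the exponential cannot overflow in any precision. The log term is then bounded in [0, ln 2], and the result is at most max(x, 0) + ln 2.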
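To make the table concrete, here is a minimal pure-Python sketch that compares the naive and stable forms under emulated fp16 rounding. It is not the proposed resolver code: the `fp16` helper (a round-trip through `struct`'s half-precision format) and all function names are assumptions chosen for illustration, inputs are assumed to be finite, and real hardware may round intermediate values differently.

```python
import math
import struct


def fp16(v: float) -> float:
    """Round v to the nearest IEEE 754 half-precision value.

    Values too large for fp16 become +/-inf, mimicking overflow.
    """
    if math.isinf(v) or math.isnan(v):
        return v
    try:
        return struct.unpack("<e", struct.pack("<e", v))[0]
    except OverflowError:
        return math.copysign(math.inf, v)


def naive_softplus(x: float) -> float:
    # exp(x) is rounded to fp16 first, so it becomes inf for x > ~11.09.
    return fp16(math.log(fp16(1.0 + fp16(math.exp(x)))))


def stable_softplus(x: float) -> float:
    # exp(-|x|) is at most 1, so this intermediate cannot overflow.
    t = fp16(math.exp(-abs(x)))
    return fp16(max(x, 0.0) + fp16(math.log(fp16(1.0 + t))))


def stable_mish(x: float) -> float:
    return fp16(x * fp16(math.tanh(stable_softplus(x))))


def naive_logsumexp(xs: list[float]) -> float:
    total = 0.0
    for v in xs:
        total = fp16(total + fp16(math.exp(v)))
    return fp16(math.log(total))


def stable_logsumexp(xs: list[float]) -> float:
    # Shift by the maximum so every exponent is <= 0; the largest
    # element contributes exp(0) = 1, so the sum is at least 1.
    m = max(xs)
    total = 0.0
    for v in xs:
        total = fp16(total + fp16(math.exp(v - m)))
    return fp16(m + fp16(math.log(total)))


print(naive_softplus(12.0), stable_softplus(12.0))
print(naive_logsumexp([12.0, 12.0]), stable_logsumexp([12.0, 12.0]))
```

With x = 12 the naive forms overflow to inf in this emulation, while the stable forms stay finite and close to the double-precision values.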
Context
These are the same decompositions being contributed to coremltools in PRs #2725, #2726, and #2727. The underlying fp16 overflow issue is documented in the Orion paper (arXiv:2603.06728) and affects any model that uses these activations on the Apple Neural Engine.
Open question
Before writing code, I need to verify whether torch.export with run_decompositions() preserves these ops as ATen entries or decomposes them into primitives. If they are decomposed, a remove_decomps entry or a graph rewrite pass would be needed instead.
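One way to answer this could be to export a tiny probe module and list which of the three ops still appear as call_function nodes before and after run_decompositions(). The sketch below is an assumption about how that check might look, not part of this project: `surviving_ops`, `check_with_torch`, and `Probe` are hypothetical names, and the substring match on the target name is deliberately crude (it would also match, for example, a backward variant of an op).

```python
def surviving_ops(graph_nodes, names=("softplus", "mish", "logsumexp")):
    """Return the subset of `names` that appear in call_function targets.

    Works on any iterable of objects with `.op` and `.target` attributes,
    such as the nodes of a torch.fx graph.
    """
    found = set()
    for node in graph_nodes:
        if node.op != "call_function":
            continue
        target = str(node.target)  # e.g. "aten.softplus.default"
        for name in names:
            if name in target:
                found.add(name)
    return found


def check_with_torch():
    """Export a probe module and report the ops before/after decomposition."""
    # Imported lazily so surviving_ops stays usable without torch installed.
    import torch
    import torch.nn.functional as F

    class Probe(torch.nn.Module):
        def forward(self, x):
            lse = torch.logsumexp(x, dim=-1, keepdim=True)
            return F.softplus(x) + F.mish(x) + lse

    ep = torch.export.export(Probe(), (torch.randn(2, 8),))
    before = surviving_ops(ep.graph.nodes)
    after = surviving_ops(ep.run_decompositions().graph.nodes)
    return before, after
```

Calling `check_with_torch()` in an environment with PyTorch installed returns two sets; any name missing from the second set was decomposed into primitives, which is the case where a resolver entry alone would not be enough.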
Scope
- Add 3 resolver functions in _aten_to_core.py (~50 lines total)
- Add corresponding dispatch table entries
- Add tests in tests/ops/test_ops.py following the existing pattern (e.g., test_hardswish, test_log_softmax)