Contribution Proposal: Add stable softplus, mish, and logsumexp conversion support #5

@Ashutosh0x

What I want to contribute

Add numerically stable ATen-to-Core resolver entries for softplus, mish, and logsumexp. These operations use exp(x) internally, which overflows in fp16 when x > ~11.09 (the IEEE 754 half-precision maximum is 65504, and ln(65504) ≈ 11.09).
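The overflow threshold can be reproduced outside the converter. The following is a minimal NumPy sketch, not code from this repository; `naive_softplus_fp16` is a hypothetical helper name used only for illustration.

```python
import numpy as np

# Largest finite half-precision value and the input at which exp(x) exceeds it.
FP16_MAX = float(np.finfo(np.float16).max)   # 65504.0
OVERFLOW_THRESHOLD = float(np.log(FP16_MAX))  # roughly 11.09


def naive_softplus_fp16(x):
    """Naive softplus, log(1 + exp(x)), evaluated entirely in float16."""
    x = np.asarray(x, dtype=np.float16)
    # Silence the overflow warning so the inf result can be inspected directly.
    with np.errstate(over="ignore"):
        return np.log(np.float16(1) + np.exp(x))


if __name__ == "__main__":
    print("threshold:", OVERFLOW_THRESHOLD)
    print("x = 11:", naive_softplus_fp16([11.0]))
    print("x = 12:", naive_softplus_fp16([12.0]))
```

Inputs just below the threshold stay finite, while inputs above it push exp(x) past 65504 and the result becomes inf.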

The existing replace_log_softmax in _aten_to_core.py (line 2541) already implements the stable max-shift decomposition. The proposed additions follow the same pattern.

Proposed decompositions

| Operation | Naive (unsafe) | Stable (proposed) |
|---|---|---|
| softplus | log(1 + exp(x)) | max(x, 0) + log(1 + exp(-abs(x))) |
| mish | x * tanh(log(1 + exp(x))) | x * tanh(stable_softplus(x)) |
| logsumexp | log(sum(exp(x))) | max(x) + log(sum(exp(x - max(x)))) |

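The softplus decomposition keeps exp(-abs(x)) within (0, 1], so the exponential cannot overflow at any precision. mish inherits this through stable_softplus, and in logsumexp every exponent x - max(x) is at most 0, so each exp term is also at most 1.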
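To make the proposed decompositions concrete, here is a minimal NumPy sketch of the three stable forms, evaluated in float16. It illustrates the math only and is not the resolver code: the function names are hypothetical, and the actual entries in _aten_to_core.py would emit Core ops rather than call NumPy.

```python
import numpy as np


def stable_softplus(x):
    """max(x, 0) + log(1 + exp(-abs(x))); the exp argument is never positive."""
    x = np.asarray(x)
    zero = x.dtype.type(0)
    one = x.dtype.type(1)
    return np.maximum(x, zero) + np.log(one + np.exp(-np.abs(x)))


def stable_mish(x):
    """x * tanh(stable_softplus(x))."""
    x = np.asarray(x)
    return x * np.tanh(stable_softplus(x))


def stable_logsumexp(x, axis=-1):
    """max(x) + log(sum(exp(x - max(x)))) along the given axis."""
    x = np.asarray(x)
    m = np.max(x, axis=axis, keepdims=True)
    # Every exponent is <= 0, so each exp term lies in (0, 1].
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


if __name__ == "__main__":
    x = np.array([-20.0, -1.0, 0.0, 1.0, 12.0, 30.0], dtype=np.float16)
    print("softplus: ", stable_softplus(x))
    print("mish:     ", stable_mish(x))
    print("logsumexp:", stable_logsumexp(np.array([10.0, 11.0, 12.0], dtype=np.float16)))
```

Inputs such as 12 and 30, which overflow the naive forms in fp16, stay finite here and agree with a float64 reference to within half-precision rounding.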

Context

These are the same decompositions being contributed to coremltools in PRs #2725, #2726, and #2727. The underlying fp16 overflow issue is documented in the Orion paper (arXiv:2603.06728) and affects any model using these activations on Apple Neural Engine.

Open question

Before writing code, I need to verify whether torch.export with run_decompositions() preserves these ops as ATen entries or decomposes them into primitives. If they are decomposed, a remove_decomps entry or a graph rewrite pass would be needed instead.

Scope

  • Add 3 resolver functions in _aten_to_core.py (~50 lines total)
  • Add corresponding dispatch table entries
  • Add tests in tests/ops/test_ops.py following the existing pattern (e.g., test_hardswish, test_log_softmax)
