
Channel-reduce logsumexp on Apple Neural Engine has a hard fp16 overflow at x ≈ 7.63 (output drops to 0) #2690

@ChinChangYang

Description

Bug. Channel-reduce torch.logsumexp(x, dim=1, keepdim=True) exhibits a hard, single-step output collapse on Apple Neural Engine in fp16: at x ≈ 7.6313 the NE output drops from ≈ 11.10 (= log(32) + 7.63) to 0.0 across one grid point — not gradual precision loss. The cliff position matches log(65504 / C) for a C-element reduction (C = 32 here), pointing at fp16 overflow of Σ exp(x) inside the NE lowering of the MIL reduce_log_sum_exp op (i.e., the lowering does not apply max-shift before the exponential).

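Affected. Models that route reduce_log_sum_exp to NE in fp16 whenever Σ exp(x) over the reduction exceeds the fp16 maximum (65504). For C equal inputs this happens once x > log(65504 / C). In general, a largest element above log(65504 / C) is necessary but not sufficient, and a largest element above log(65504) is always sufficient. CPU and GPU compute units are unaffected. The cliff occurs at a moderate input where the correct result (≈ 11.1) is far inside fp16 range, so no range or approximation argument justifies the collapse.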
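A quick way to check whether a given reduction is exposed is to evaluate the naive sum in float64 and compare it with the fp16 maximum. The sketch below is a host-side NumPy check, not Core ML code, and `naive_sum_overflows_fp16` is a hypothetical helper name. It also shows that a large maximum alone does not decide the outcome; the whole sum does.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0


def naive_sum_overflows_fp16(x):
    """True if sum(exp(x)), computed without max-shift, exceeds the fp16 maximum.

    The sum is evaluated in float64 so the check itself cannot overflow.
    """
    return float(np.sum(np.exp(np.asarray(x, dtype=np.float64)))) > FP16_MAX


# 32 equal entries just above log(65504 / 32) ~= 7.6241: the sum overflows.
print(naive_sum_overflows_fp16(np.full(32, 7.7)))
# Larger maximum (9.0 > 7.6241) but only one large entry: the sum stays in range.
print(naive_sum_overflows_fp16([9.0] + [-10.0] * 31))
```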

Discovered while auditing PyTorch activations for the same NE-fp16 cliff signature first reported for softplus on this hardware. The lsexp cliff follows the same dynamic-range failure pattern as the softplus cliff but lands at a different x value. The two ops also do not share an internal kernel: algebraically equivalent formulations cliff at different x values. See Actual behavior for the side-by-side cross-op data table.

Steps to Reproduce

Save the following script as repro_logsumexp_ne_cliff.py and run with python repro_logsumexp_ne_cliff.py on an Apple-silicon Mac.

"""Minimal repro: channel-reduce logsumexp has a hard fp16 cliff on Apple Neural Engine.

Reproduces a hard, single-step output collapse at x ≈ 7.6313.
Requires macOS with Apple Neural Engine (M1/M2/M3/M4).

Run: python repro_logsumexp_ne_cliff.py
"""
import os
import tempfile

import coremltools as ct
import numpy as np
import torch
import torch.nn as nn
from coremltools.models.compute_device import MLNeuralEngineComputeDevice
from coremltools.models.compute_plan import MLComputePlan


SPATIAL = 8
CHANNELS = 32
FLAT_DIM = SPATIAL * SPATIAL * CHANNELS  # 8 * 8 * 32 = 2048


class M(nn.Module):
    """conv -> lsexp(channel-reduce, broadcast) -> flatten -> linear (pick-element).

    Topology chosen to attract NE routing for reduce_log_sum_exp: Conv2d(1->C, k=3,
    padding=same) followed by a Linear head with NE-friendly shapes. Smaller
    topologies compile but reduce_log_sum_exp routes to CPU.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, CHANNELS, kernel_size=3, padding="same")
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(FLAT_DIM, 16)
        with torch.no_grad():
            # Delta conv: conv output[k,i,j] = input[0,i,j] (kernel center only)
            self.conv.weight.zero_()
            self.conv.weight[:, 0, 1, 1] = 1.0
            self.conv.bias.zero_()
            # Pick-element head: fc.out[k] = flat[k] for k in 0..15
            self.fc.weight.zero_()
            self.fc.weight.fill_diagonal_(1.0)
            self.fc.bias.zero_()

    def forward(self, x):
        c = self.conv(x)
        # All 32 channels equal after delta conv -> lsexp(x_rep_C) = x + log(C).
        # expand_as broadcasts the 1-channel reduction back to 32 channels so the
        # downstream Linear sees the same (1, 32, 8, 8) shape softplus would.
        y = torch.logsumexp(c, dim=1, keepdim=True).expand_as(c)
        return self.fc(self.flatten(y))


def _op_kind(operator_name):
    """Strip 'iosXX.' / 'macOSXX.' namespace prefix from a MIL op_type."""
    return operator_name.split(".", 1)[1] if "." in operator_name else operator_name


def main():
    model = M().eval()

    # CPU sanity: forward computes lsexp(x_test repeated C times) = x_test + log(C).
    x_test = 2.5
    test_input = torch.full((1, 1, SPATIAL, SPATIAL), x_test, dtype=torch.float32)
    with torch.no_grad():
        cpu_out = model(test_input).numpy().flatten()
    lse_expected = float(x_test + np.log(float(CHANNELS)))
    assert np.allclose(cpu_out, lse_expected, atol=1e-5), (
        f"CPU sanity failed: cpu_out[:4]={cpu_out[:4]} != "
        f"x+log(C)={lse_expected}"
    )

    with tempfile.TemporaryDirectory() as d:
        traced = torch.jit.trace(model, test_input)
        mlm = ct.convert(
            traced,
            convert_to="mlprogram",
            inputs=[ct.TensorType(name="x", shape=test_input.shape)],
            minimum_deployment_target=ct.target.macOS14,
            compute_precision=ct.precision.FLOAT16,
        )
        pkg = os.path.join(d, "m.mlpackage")
        mlm.save(pkg)

        loaded = ct.models.MLModel(pkg, compute_units=ct.ComputeUnit.CPU_AND_NE)

        # Routing check — assert reduce_log_sum_exp dispatched to NE.
        plan = MLComputePlan.load_from_path(
            loaded.get_compiled_model_path(),
            compute_units=ct.ComputeUnit.CPU_AND_NE,
        )
        (fn,) = plan.model_structure.program.functions.values()
        ops = list(fn.block.operations)
        lse_ops = [op for op in ops if _op_kind(op.operator_name) == "reduce_log_sum_exp"]
        assert len(lse_ops) == 1, (
            f"expected exactly 1 reduce_log_sum_exp op, got {len(lse_ops)}"
        )
        usage = plan.get_compute_device_usage_for_mlprogram_operation(lse_ops[0])
        device_name = (
            type(usage.preferred_compute_device).__name__ if usage else "unknown"
        )
        assert usage is not None and isinstance(
            usage.preferred_compute_device, MLNeuralEngineComputeDevice
        ), f"reduce_log_sum_exp routed to {device_name}; this repro requires NE"

        # Sweep x in [-15, 15] @ 2000 points, capture lsexp output at each x.
        out_name = loaded.get_spec().description.output[0].name
        N = 2000
        xs = np.linspace(-15.0, 15.0, N, dtype=np.float32)
        ys = np.empty(N, dtype=np.float32)
        for i, xi in enumerate(xs):
            inp = np.full((1, 1, SPATIAL, SPATIAL), float(xi), dtype=np.float32)
            ys[i] = loaded.predict({"x": inp})[out_name].flat[0]

        # Reference: lsexp(x repeated C times) = x + log(C) (closed form for delta input).
        lse_ref = xs + np.log(float(CHANNELS))
        # Cliff: NE output collapses to ~0 while fp32 ref is large.
        cliff_idx = np.where((lse_ref > 5.0) & (ys < 1.0))[0]
        if cliff_idx.size:
            i = cliff_idx[0]
            print(
                f"CLIFF: x={xs[i]:.4f}  ne_out={ys[i]:.4f}  fp32_ref={lse_ref[i]:.4f}"
            )
        else:
            print("No cliff observed — please check NE actually engaged.")


if __name__ == "__main__":
    main()

Observed output (Apple M3 Max; see System environment):

CLIFF: x=7.6313  ne_out=0.0000  fp32_ref=11.0971

Expected behavior

torch.logsumexp(x, dim=1, keepdim=True) over a C-element reduction equals log(Σ_i exp(x_i)). When all reduction-axis entries equal x (the delta-replicated regime the repro produces), Σ_i exp(x_i) = C · exp(x), so lsexp = x + log(C) exactly, for every x. For x ∈ [7, 8] with C = 32, the fp32 reference values are ≈ 10.47–11.47, comfortably inside fp16 range. The output should therefore be a smooth, monotonic (in fact linear) function of x across the full input range.
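Because every term is identical, the closed form is an exact identity rather than a large-x approximation. A float64 NumPy check over the repro's sweep grid (host-side, independent of Core ML; `replicated_lse_error` is a hypothetical helper) should report a maximum error at float64 rounding level:

```python
import numpy as np

C = 32


def replicated_lse_error(xs, c=C):
    """Max |log(sum(exp(x repeated c times))) - (x + log c)| over xs, in float64."""
    errs = [abs(np.log(np.sum(np.exp(np.full(c, x)))) - (x + np.log(c))) for x in xs]
    return max(errs)


xs = np.linspace(-15.0, 15.0, 2000)  # same grid as the repro sweep
print(replicated_lse_error(xs))
```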

Actual behavior

Output drops from ≈ 11.10 to 0.0 in a single grid step at x ≈ 7.6313 (2000-point linear sweep over [-15, 15], so step size = 30 / 1999 ≈ 0.015). The transition is hard: adjacent grid points show ≈ 11.10 → 0.00. Increasing sweep resolution does not soften the transition; it only locates it more precisely.

The cliff sits exactly where naive fp16 overflow of Σ exp(x) predicts. With all 32 channels equal to x, the sum is 32 · exp(x), which exceeds the fp16 maximum (65504) when x > log(65504 / 32) = log(2047) ≈ 7.6241. The observed cliff is the first sweep point above that boundary: the gap (7.6313 − 7.6241 = 0.0072) is smaller than one sweep step. This indicates that NE's lowering of MIL reduce_log_sum_exp does not apply max-shift before the exponential. With max-shift, Σ exp(x − x_max) would be bounded by C = 32 (and its log by log(C) ≈ 3.47), so it could never overflow fp16.
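The threshold arithmetic can be reproduced on the host. The sketch below emulates a lowering without max-shift by storing every intermediate as NumPy float16; it illustrates the suspected failure mode under that assumption and is not NE code (`naive_cliff` and `naive_lse_fp16` are hypothetical names). NumPy may accumulate internally at higher precision, but the stored fp16 sum is where the overflow appears. The emulation yields inf at the overflow, whereas NE was observed to return 0.0.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0


def naive_cliff(c):
    """x above which c equal terms, c * exp(x), exceed the fp16 maximum."""
    return float(np.log(FP16_MAX / c))


def naive_lse_fp16(x, c):
    """log(sum(exp(x))) over c copies of x with NO max-shift, stored as fp16.

    Hypothetical emulation of the suspected NE lowering; not Core ML code.
    """
    with np.errstate(over="ignore"):
        e = np.exp(np.full(c, x, dtype=np.float16))  # per-element exp, stored as fp16
        s = e.sum(dtype=np.float16)                   # sum stored as fp16
        return float(np.log(s))


print(round(naive_cliff(32), 4))  # 7.6241
print(naive_lse_fp16(7.60, 32))   # finite, close to 7.60 + log(32)
print(naive_lse_fp16(7.65, 32))   # inf: the fp16 sum overflowed
```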


Cross-op data: same fp16 dynamic-range family, different internal kernels. The lsexp cliff sits in a known NE-fp16 family with softplus, but the kernels are demonstrably distinct:

| Op | Form (delta-replicated input) | NE-fp16 cliff x (observed) | Naive fp16 overflow boundary | Source |
|---|---|---|---|---|
| softplus | log(1 + exp(x)) | ≈ 10.395 (fine-sweep transition 10.394 → 10.395) | log(2¹⁵) ≈ 10.397; points at a 2¹⁵ internal precision, not the full 2¹⁶ | existing softplus issue |
| reduce_log_sum_exp | lsexp(x_replicated_32, dim=1) (this report) | 7.6313 | log(65504 / 32) = log(2047) ≈ 7.6241; full-2¹⁶ overflow on Σ_C exp(x) | this issue |
| reduce_log_sum_exp | lsexp(stack([x, 0]), dim=0) | 11.0980 | log(65504) ≈ 11.0899; full-2¹⁶ overflow on the dominant exp(x) term | follow-up scan; see Related |

Two takeaways:

  1. reduce_log_sum_exp's NE lowering skips max-shift for both tested reduction sizes: each cliffs at the naive log(65504 / C_eff) boundary, where C_eff = Σ_i exp(x_i − x_max) is effectively the number of terms at the maximum (32 for the replicated case, ≈ 1 for stack([x, 0])).
  2. NE's softplus and NE's reduce_log_sum_exp do not share an internal kernel: algebraically equivalent expressions (softplus(x) vs lsexp(stack([x, 0]), dim=0)) cliff at 10.395 and 11.098 respectively. Both are broken, but at different x values, pointing at distinct internal precisions (2¹⁵ vs 2¹⁶).
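The C_eff model can be checked against the sweep grid itself. Assuming both lsexp scans used the repro's 2000-point linspace over [-15, 15] (the reported values are consistent with that), the first grid point strictly above log(65504 / C_eff) should reproduce each observed cliff. The softplus row is left out because its boundary does not follow this model. A NumPy sketch with a hypothetical helper `predicted_cliff`:

```python
import numpy as np

FP16_MAX = 65504.0
xs = np.linspace(-15.0, 15.0, 2000)  # same grid as the repro sweep (float64 here)


def predicted_cliff(c_eff):
    """First sweep grid point strictly above the naive boundary log(65504 / c_eff)."""
    boundary = np.log(FP16_MAX / c_eff)
    return float(xs[np.argmax(xs > boundary)])


# (C_eff, observed cliff x) pairs from the cross-op table above.
for c_eff, observed in ((32, 7.6313), (1, 11.0980)):
    print(c_eff, round(predicted_cliff(c_eff), 4), observed)
```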

System environment

  • macOS: 26.3.1
  • Hardware: Apple M3 Max
  • coremltools: 9.0
  • PyTorch: 2.7.0
  • NumPy: 2.4.4
  • Python: 3.13.13

Workaround

A drop-in replacement that applies an explicit max-shift inside the model before calling torch.logsumexp eliminates the cliff while keeping every elementary op on NE:

import torch


def logsumexp_safe(x, dim=1, keepdim=True):
    # Standard textbook max-shift: after subtracting the per-slice maximum,
    # every exp() term lies in (0, 1], so the sum cannot overflow fp16.
    # The lsexp op itself is unchanged; only its inputs are bounded.
    xm = x.amax(dim=dim, keepdim=True)
    shifted = torch.logsumexp(x - xm, dim=dim, keepdim=True) + xm
    return shifted if keepdim else shifted.squeeze(dim)

All four MIL ops in the workaround (reduce_max, sub, reduce_log_sum_exp, add) route to NE on this environment. The cliff at x ≈ 7.63 disappears across the full [-15, 15] sweep:

[Image: NE output after the workaround across the [-15, 15] sweep]

No measurable accuracy loss. Unlike the softplus safe identity, the max-shift form is mathematically exact for any reduction whose maximum element is finite: lsexp(x − x_max) + x_max is algebraically identical to lsexp(x), and the shifted inputs x − x_max ∈ (−∞, 0] keep exp(x − x_max) ∈ (0, 1], well within fp16 range. After the shift, the full sweep curve overlaps the fp32 reference across the entire range.

Why this works. NE's reduce_log_sum_exp lowering appears to compute log(Σ exp(x)) directly without max-shift, so Σ exp(x) overflows fp16 once it exceeds 65504 (for C equal inputs, once C · exp(x) > 65504). Shifted inputs are capped at 0, so Σ exp(x − x_max) ≤ C = 32 ≪ 65504 regardless of input magnitude. The maximum term contributes exp(0) = 1, so the sum is also ≥ 1 and the log never sees 0. The op kernel itself is unchanged; only the inputs it sees are bounded.
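The bound can be checked with a host-side float16 emulation of the shifted computation. This is a sketch under the assumption that storing every intermediate as NumPy float16 is a fair stand-in for NE's fp16 arithmetic; `shifted_lse_fp16` is a hypothetical helper, not Core ML code. Inputs well beyond the naive C = 32 cliff stay finite and close to x + log(C).

```python
import numpy as np


def shifted_lse_fp16(x):
    """Max-shifted logsumexp over a 1-D input, every intermediate stored as fp16.

    Hypothetical emulation of the workaround's arithmetic; not Core ML code.
    """
    x = np.asarray(x, dtype=np.float16)
    xm = x.max()                   # reduce_max
    e = np.exp(x - xm)             # sub + exp: each term in (0, 1]
    s = e.sum(dtype=np.float16)    # 1 <= s <= len(x), far below 65504
    return float(np.log(s) + xm)   # log, then add the maximum back


C = 32
for x in (7.65, 11.5, 15.0):       # all beyond the naive C = 32 cliff at ~7.6241
    print(x, shifted_lse_fp16(np.full(C, x)), x + np.log(C))
```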

Alternative: pure decomposition. If your deployment must avoid the reduce_log_sum_exp op entirely (e.g., op-allowlist constraints), the 6-primitive decomposition (x - x_max).exp().sum(dim, keepdim=True).log() + x_max, with x_max = x.amax(dim, keepdim=True), was also verified on this environment: no cliff, and all six ops (reduce_max, sub, exp, reduce_sum, log, add) route to NE. The hybrid above is the lighter-touch fix (3 added MIL ops vs 5) and is the recommended form unless reduce_log_sum_exp itself is disallowed.
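For reference, the six primitives map one-to-one onto NumPy reductions. This float64 sketch uses NumPy rather than PyTorch, and `lse_decomposed` is a hypothetical helper. It mirrors the dim/keepdim handling on the repro's (1, 32, 8, 8) delta-replicated layout and checks the result against the closed form x + log(32).

```python
import numpy as np


def lse_decomposed(x, axis=1, keepdims=True):
    """Max-shifted logsumexp along `axis`: reduce_max, sub, exp, reduce_sum, log, add."""
    xm = np.max(x, axis=axis, keepdims=True)               # reduce_max
    s = np.sum(np.exp(x - xm), axis=axis, keepdims=True)   # sub, exp, reduce_sum
    out = np.log(s) + xm                                   # log, add
    return out if keepdims else np.squeeze(out, axis=axis)


# Delta-replicated input like the repro: every channel equals the spatial value.
x = np.broadcast_to(np.linspace(-15.0, 15.0, 64).reshape(1, 1, 8, 8), (1, 32, 8, 8))
y = lse_decomposed(x)
print(y.shape, np.allclose(y, x[:, :1] + np.log(32)))  # (1, 1, 8, 8) True
```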

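A converter-side fix could lower MIL reduce_log_sum_exp (or torch.logsumexp) to this max-shifted form. Because the compute unit is selected when the model is loaded rather than at conversion, the converter likely cannot restrict the change to NE; since the shifted form is exact on every backend, applying it unconditionally should be safe. Applying the standard logsumexp identity inside the NE kernel itself would be the alternative. We have not surveyed all reduce_log_sum_exp call sites; a maintainer should confirm scope.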

Related

  • Existing softplus NE-fp16 cliff issue — same dynamic-range failure family on NE; provides the 2¹⁵-vs-2¹⁶ asymmetry shown in Actual behavior. Link once filed.
  • Reduce mish error by an alternative without softplus op #2618 — a separate PR proposing a fix for softplus precision on NE. Same family of NE-fp16 dynamic-range issue; an attribution-data comment supporting its premise has also been drafted.
  • Stack-of-2 lsexp variant (additional context). A second NE-routed reduce_log_sum_exp form was probed in follow-up scans: torch.logsumexp(torch.stack([x, zeros], dim=0), dim=0), which is algebraically equivalent to softplus(x). It cliffs at x ≈ 11.0980, matching log(65504) ≈ 11.0899 (effective C = 1: one stack element is pinned at 0, so Σ exp is dominated by the single exp(x) term). The same in-model max-shift workaround applies. This variant is excluded from the primary repro to keep the bug narrative single-claim, but maintainers investigating the no-max-shift mechanism may want both data points: a 32-element reduction cliffing at log(65504 / 32) ≈ 7.62 and a 1-effective-element reduction cliffing at log(65504 / 1) ≈ 11.09 jointly confirm the log(65504 / C_eff) model.
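The stack-of-2 case fits the same picture with C_eff = 1: the pinned zero adds exp(0) = 1 to the sum, which is negligible next to exp(x) near 65504. A host-side sketch, assuming NumPy float16 storage without max-shift approximates the NE behaviour (`naive_stack2_lse_fp16` is a hypothetical helper, not Core ML code):

```python
import numpy as np


def naive_stack2_lse_fp16(x):
    """log(exp(x) + exp(0)) without max-shift, intermediates stored as fp16.

    Hypothetical emulation; algebraically equal to softplus(x) = log(1 + exp(x)).
    """
    with np.errstate(over="ignore"):
        e = np.exp(np.array([x, 0.0], dtype=np.float16))
        return float(np.log(e.sum(dtype=np.float16)))


print(naive_stack2_lse_fp16(11.08))  # finite, close to softplus(11.08)
print(naive_stack2_lse_fp16(11.10))  # inf: exp(x) alone exceeds the fp16 range
```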

Drafted by Claude Opus 4.7; reviewed, verified, and edited by me.

Metadata

Labels: bug (Unexpected behaviour that should be corrected)
