
ANECompiler crash (ANECPlistInterface::addOpToNetwork EXC_BAD_ACCESS) at AIModel.load when slice_update begin/end indices are runtime values #6

@scndls


Compile crash (ANECompiler ANECPlistInterface::addOpToNetwork EXC_BAD_ACCESS) when coreai.slice_update indices are runtime values

Summary

A fully static-shape stateful model whose KV-cache write offset is computed from an input's
values (instead of being a compile-time constant) exports and saves fine, then segfaults
inside ANECompiler while loading/compiling the .aimodel:

thread #13, queue = 'MPSGraphExecutable_queue', stop reason = EXC_BAD_ACCESS (code=1, address=0x12a)
  frame #0: ANECompiler`mlir::anecir::ANECPlistInterface::addOpToNetwork(
      mlir::Operation*, mlir::anecir::ANECIRNetwork*, mlir::anecir::ANECIRWeightSerializer&) + 31400

The trigger is precisely whether the begin/end operands of the slice update
(coreai_models' coreai::mutable_slice_update → coreai.slice_update) are constants:

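| begin/end of the state slice update                                | Result                                |
|--------------------------------------------------------------------|---------------------------------------|
| compile-time constants                                             | runs (correct outputs)                |
| computed from an input tensor's values (e.g. position_ids[0, :1])  | EXC_BAD_ACCESS in ANECompiler at load |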

Everything else in the graph is identical and fully static — same shapes, same mask input,
same SDPA. Only the slice indices change from constant to runtime Value.

Why this matters

This is the third member of a bug family that currently blocks the natural export paths for
hybrid DeltaNet models (Qwen3.5/3.6, Qwen3-Next): dynamic context dims trip #1 (SDPA
externalize re-export) and #2 (MPSGraph FloatType::getWidth() null-deref), so the model
must be exported fully static. But a static stateful decode function then needs its KV
write position as a runtime value, which this crash forbids. (torch.export and the
coreai-torch conversion handle it fine; the custom-op path even documents "dynamic
begin/end indices".)

Workaround we use: a sliding-window KV cache — every call shifts the cached window left by
the (static) query length and appends the new chunk at the end, so all slice indices stay
constant, and a host-built attention mask encodes validity/causality. Works, but costs a
full cache rewrite per step.
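A minimal pure-Python sketch of the sliding-window update described above, tracking absolute token positions in place of K/V tensors (None marks an empty slot). The helper names `roll_append` and `abs_position` are hypothetical; the index relation mirrors the comment in the repro's `chunk_inputs` ("cache index jj holds absolute pos off+Q-MAX_CTX+jj"). It illustrates why both slice bounds stay constant: every call reads `[Q, MAX_CTX)` and writes `[0, MAX_CTX)`.

```python
MAX_CTX, Q = 16, 4


def roll_append(window, chunk):
    """Drop the oldest len(chunk) entries and append `chunk` at the end.

    Reads window[Q:MAX_CTX] and writes the full [0, MAX_CTX) range, so both
    slice bounds are identical on every call, whatever the step.
    """
    if len(window) != MAX_CTX or len(chunk) != Q:
        raise ValueError("window and chunk must have their static lengths")
    return window[Q:] + list(chunk)


def abs_position(step, j):
    """Absolute token position held at window index j after call `step`.

    Negative values mean the slot has not been filled yet.
    """
    return step * Q + Q - MAX_CTX + j


# Step through three chunks and check the index relation after each call.
window = [None] * MAX_CTX
for step in range(3):
    chunk = list(range(step * Q, step * Q + Q))
    window = roll_append(window, chunk)
    for j, pos in enumerate(window):
        expected = abs_position(step, j)
        assert pos == (expected if expected >= 0 else None)
```

The cost noted above is visible here: `roll_append` rebuilds the whole window on every call, instead of writing only the Q new entries.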

Environment

  • coreai-torch 0.4.0, coreai-models @ b1cb71b
  • torch 2.9.0, Python 3.11.15
  • macOS 27.0 (build 26A5353q), arm64 (Apple M1 Max, 32 GB)

Repro

Self-contained script below: one GQA attention layer with a 1-slot KV-cache state, exported
fully static (dynamic_shapes=None, SDPA decomposed), executed twice with persistent state:

  • MODE=runtime: write offset off = position_ids.reshape(-1).narrow(0, 0, 1);
    begin = cat([0, 0, 0, off, 0]), end = cat([1, B, NKV, off+Q, HD]) → crash at AIModel.load.
  • MODE=const: off = zeros(1) (graph otherwise identical) → loads and runs.
  • MODE=roll: the constant-index sliding-window workaround → loads, runs, and matches the
    eager reference across stepped calls (rel err ≤ 5e-4).
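The runtime and const modes differ only in where `off` comes from. The plain-int sketch below (helper names `slice_bounds` and `extents` are hypothetical) mirrors the repro's `torch.cat` construction of begin/end and checks that, for every in-range offset, the written region has exactly the update's static shape `(1, B, NKV, Q, HD)`. Only the region's origin is data-dependent. It says nothing about how ANECompiler lowers the op; it only illustrates that the crashing graph has no dynamic shapes.

```python
B, Q, MAX_CTX, NKV, HD = 1, 4, 16, 2, 16


def slice_bounds(off):
    """begin/end for writing one Q-token chunk into slot 0 at sequence offset `off`.

    Plain-int mirror of torch.cat([zero, zero, zero, off, zero]) and
    torch.cat([1, B, NKV, off + Q, HD]) in the repro.
    """
    begin = (0, 0, 0, off, 0)
    end = (1, B, NKV, off + Q, HD)
    return begin, end


def extents(begin, end):
    """Per-axis size of the half-open region [begin, end)."""
    return tuple(e - b for b, e in zip(begin, end))


UPDATE_SHAPE = (1, B, NKV, Q, HD)  # shape of k.unsqueeze(0) in the repro

# MODE=const uses off = 0; MODE=runtime uses off = first position id of the chunk.
for step in range(MAX_CTX // Q):
    position_ids = list(range(step * Q, step * Q + Q))
    begin, end = slice_bounds(position_ids[0])
    # The region size never changes; only its origin along the sequence axis moves.
    assert extents(begin, end) == UPDATE_SHAPE
    assert end[3] <= MAX_CTX
```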
# crash
MODE=runtime uv run python toy_static_step.py
# control / workaround
MODE=const   uv run python toy_static_step.py
MODE=roll    uv run python toy_static_step.py
toy_static_step.py
"""Repro: fully STATIC-shape stateful export, stepping through generation (persistent
KV state). A static decode function can't take its KV write position from a tensor's
SHAPE (everything is fixed), so the offset has to come from an input's VALUES — and
that is exactly what crashes.

  MODE=runtime  write offset = position_ids[0, :1] as a runtime VALUE feeding
                mutable_slice_update begin/end. Exports and saves fine, then CRASHES
                at AIModel.load: EXC_BAD_ACCESS in ANECompiler
                anecir::ANECPlistInterface::addOpToNetwork.
  MODE=const    offset baked to constant 0 (semantically wrong past chunk 0; this is
                a compile-only control). Graph otherwise identical. Loads and runs,
                which isolates the trigger to the non-constant slice indices.
  MODE=roll     the workaround: sliding-window cache. Every call shifts the cache
                left by Q and appends the new chunk at the END, so all slice indices
                stay CONSTANT. Newest token always at cache index MAX_CTX-1. A
                host-built mask encodes causality + validity for the sliding layout;
                RoPE stays absolute via position VALUES (runtime gather of cos/sin
                works fine). Loads, runs, matches eager across stepped calls.

Two sequential calls against persistent runtime state (chunk 0 = tokens 0..Q-1,
chunk 1 = tokens Q..2Q-1), each compared to the eager reference.

Env: MODE=roll|runtime|const (default roll), MASK=bool|float (default bool)

Run from a coreai-models checkout:

  MODE=runtime uv run python toy_static_step.py   # ANECompiler crash at load
  MODE=const   uv run python toy_static_step.py   # control: loads + runs
  MODE=roll    uv run python toy_static_step.py   # workaround: loads + runs + matches
"""

import asyncio
import os
import sys
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

from coreai.runtime import AIModel, NDArray
import coreai_models.export.macos as _macos
from coreai_models.export.macos import export_to_coreai
from coreai_models.primitives._ops import mutable_slice_update
from coreai_models.primitives.macos.rope import initialize_rope
from coreai_models.primitives.macos.sdpa import SDPA

# Decompose SDPA (dodges apple/coreai-torch#1; also the only path verified to take a mask).
if os.environ.get("EXT_SDPA", "0") != "1":
    _macos._EXTERNALIZE_SPECS = [
        s for s in _macos._EXTERNALIZE_SPECS
        if getattr(s, "composite_op_name", "") != "scaled_dot_product_attention"
    ]

B, Q, MAX_CTX = 1, 4, 16
HID, NH, NKV, HD = 64, 4, 2, 16
MODE = os.environ.get("MODE", "roll")
MASK_KIND = os.environ.get("MASK", "bool")
DTYPE = torch.float16


def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    if n_rep == 1:
        return x
    b, h, s, d = x.shape
    return x[:, :, None, :, :].expand(b, h, n_rep, s, d).reshape(b, h * n_rep, s, d)


def _const_idx(*vals: int) -> torch.Tensor:
    return torch.tensor(vals, dtype=torch.int32)


class ToyAttn(nn.Module):
    """One GQA attention layer; KV stack has 1 slot."""

    def __init__(self) -> None:
        super().__init__()
        self.q_proj = nn.Linear(HID, NH * HD, bias=False)
        self.k_proj = nn.Linear(HID, NKV * HD, bias=False)
        self.v_proj = nn.Linear(HID, NKV * HD, bias=False)
        self.o_proj = nn.Linear(NH * HD, HID, bias=False)
        self.rope = initialize_rope(dims=HD, base=10000.0)
        self.sdpa = SDPA(is_causal=False)  # mask passed explicitly

    def forward(self, x, position_ids, attn_mask, k_cache, v_cache):
        b, s, _ = x.shape  # static
        q = self.q_proj(x).reshape(b, s, NH, HD).permute(0, 2, 1, 3)
        k = self.k_proj(x).reshape(b, s, NKV, HD).permute(0, 2, 1, 3)
        v = self.v_proj(x).reshape(b, s, NKV, HD).permute(0, 2, 1, 3)
        q = self.rope(q, position_ids=position_ids)
        k = self.rope(k, position_ids=position_ids)

        if MODE == "roll":
            # Sliding window: shift the slot left by Q, append the chunk at the END.
            # All slice indices constant; attention reads the freshly built window.
            old_k = k_cache.narrow(0, 0, 1).squeeze(0)  # (B, NKV, MAX_CTX, HD)
            old_v = v_cache.narrow(0, 0, 1).squeeze(0)
            win_k = torch.cat([old_k.narrow(2, Q, MAX_CTX - Q), k], dim=2)
            win_v = torch.cat([old_v.narrow(2, Q, MAX_CTX - Q), v], dim=2)
            begin = _const_idx(0, 0, 0, 0, 0)
            end = _const_idx(1, B, NKV, MAX_CTX, HD)
            mutable_slice_update(x=k_cache, update=win_k.unsqueeze(0), begin=begin, end=end)
            mutable_slice_update(x=v_cache, update=win_v.unsqueeze(0), begin=begin, end=end)
            k_all, v_all = win_k, win_v
        else:
            if MODE == "const":  # compile-only control; wrong results past chunk 0
                off = torch.zeros(1, dtype=torch.int32)
            else:  # "runtime": offset as a value -> ANECompiler crash (this issue)
                off = position_ids.reshape(-1).narrow(0, 0, 1).to(torch.int32)
            zero = torch.zeros(1, dtype=torch.int32)
            begin = torch.cat([zero, zero, zero, off, zero])
            end = torch.cat([_const_idx(1), _const_idx(B), _const_idx(NKV),
                             off + Q, _const_idx(HD)])
            mutable_slice_update(x=k_cache, update=k.unsqueeze(0), begin=begin, end=end)
            mutable_slice_update(x=v_cache, update=v.unsqueeze(0), begin=begin, end=end)
            k_all = k_cache.narrow(0, 0, 1).squeeze(0)
            v_all = v_cache.narrow(0, 0, 1).squeeze(0)

        k_all = _repeat_kv(k_all, NH // NKV)
        v_all = _repeat_kv(v_all, NH // NKV)
        mask = attn_mask if MASK_KIND == "bool" else attn_mask > 0.5
        out = self.sdpa(q, k_all, v_all, attn_mask=mask)
        out = out.permute(0, 2, 1, 3).reshape(b, s, NH * HD)
        return self.o_proj(out)


def chunk_inputs(step: int, torch_gen: torch.Generator):
    """Inputs for chunk `step` (tokens step*Q .. step*Q+Q-1)."""
    off = step * Q
    x = torch.randn(B, Q, HID, generator=torch_gen).to(DTYPE)
    position_ids = torch.arange(off, off + Q, dtype=torch.int32).unsqueeze(0)
    i = torch.arange(Q).unsqueeze(1)
    j = torch.arange(MAX_CTX).unsqueeze(0)
    if MODE == "roll":
        # After the in-call roll, cache index jj holds absolute pos off+Q-MAX_CTX+jj.
        # Query i (abs off+i) attends iff 0 <= abs(jj) <= off+i.
        abs_j = off + Q - MAX_CTX + j
        mask_bool = ((abs_j >= 0) & (abs_j <= off + i)).reshape(1, 1, Q, MAX_CTX)
    else:
        mask_bool = (j <= off + i).reshape(1, 1, Q, MAX_CTX)
    attn_mask = mask_bool if MASK_KIND == "bool" else mask_bool.to(DTYPE)
    return {"x": x, "position_ids": position_ids, "attn_mask": attn_mask}


def main() -> int:
    torch.manual_seed(0)
    gen = torch.Generator().manual_seed(42)
    model = ToyAttn().eval().to(DTYPE)
    print(f"=== static stepping toy: MODE={MODE} Q={Q} MAX_CTX={MAX_CTX} mask={MASK_KIND} ===",
          flush=True)

    state0 = {
        "k_cache": torch.zeros(1, B, NKV, MAX_CTX, HD, dtype=DTYPE),
        "v_cache": torch.zeros(1, B, NKV, MAX_CTX, HD, dtype=DTYPE),
    }
    chunks = [chunk_inputs(s, gen) for s in range(2)]

    # Eager reference: sequential calls mutating the same cache tensors.
    ref_state = {k: v.clone() for k, v in state0.items()}
    refs = []
    with torch.no_grad():
        for c in chunks:
            refs.append(model(**c, **ref_state))
    print("eager reference OK", flush=True)

    print("exporting (fully static)...", flush=True)
    trace = {**{k: v.clone() for k, v in chunks[0].items()},
             **{k: v.clone() for k, v in state0.items()}}
    prog = export_to_coreai(
        model, trace,
        input_names=("x", "position_ids", "attn_mask"),
        output_names=("out",),
        state_names=("k_cache", "v_cache"),
        dynamic_shapes=None,
    )
    prog.optimize()
    asset = Path(tempfile.mkdtemp(prefix="toy_step_")) / "m.aimodel"
    prog.save_asset(asset)
    print("EXPORT+OPTIMIZE+SAVE OK", flush=True)

    async def run():
        m = await AIModel.load(asset)
        fn = m.load_function(m.function_names[0])
        print("LOADED", flush=True)
        st = {k: NDArray(v.clone()) for k, v in state0.items()}
        outs = []
        for c in chunks:
            r = await fn(inputs={k: NDArray(v.clone()) for k, v in c.items()}, state=st)
            outs.append(torch.from_numpy(r["out"].numpy()))
        return outs

    outs = asyncio.run(run())

    ok = True
    for i, (got, ref) in enumerate(zip(outs, refs)):
        rel = ((got.float() - ref.float()).norm() / ref.float().norm()).item()
        print(f"chunk {i} (offset {i * Q}): rel err {rel:.4f}", flush=True)
        ok = ok and rel < 2e-2
    print(f"\nRESULT: {'GO' if ok else 'NO-GO'} (MODE={MODE})", flush=True)
    return 0 if ok else 1


if __name__ == "__main__":
    sys.exit(main())
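The host-built masks in `chunk_inputs` can be sanity-checked without torch or the runtime. This pure-Python restatement of both mask formulas (helper names `roll_mask`, `linear_mask`, and `attended` are hypothetical) checks that, while the whole history fits in MAX_CTX, the sliding-window layout and the linear layout let each query attend exactly the same absolute positions. Once the history exceeds MAX_CTX, the roll layout drops the oldest positions, which is inherent to a sliding window.

```python
Q, MAX_CTX = 4, 16


def roll_mask(step):
    """Boolean mask for MODE=roll, same formula as chunk_inputs."""
    off = step * Q
    return [[0 <= off + Q - MAX_CTX + j <= off + i for j in range(MAX_CTX)]
            for i in range(Q)]


def linear_mask(step):
    """Boolean mask for MODE=runtime/const, where cache index j holds position j."""
    off = step * Q
    return [[j <= off + i for j in range(MAX_CTX)] for i in range(Q)]


def attended(mask_row, pos_of_index):
    """Absolute positions a query row attends to, given an index -> position map."""
    return {pos_of_index(j) for j, keep in enumerate(mask_row) if keep}


# Compare both layouts for every step whose history still fits in the window.
for step in range(MAX_CTX // Q):
    off = step * Q
    for i in range(Q):
        roll = attended(roll_mask(step)[i], lambda j: off + Q - MAX_CTX + j)
        lin = attended(linear_mask(step)[i], lambda j: j)
        # Query i of this chunk sits at absolute position off + i (causal).
        assert roll == lin == set(range(off + i + 1))
```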

Expected

Either compile coreai.slice_update with non-constant begin/end (the op accepts Value
operands), or reject the program with a diagnostic at conversion/optimize time instead of
crashing in ANECPlistInterface::addOpToNetwork during compilation.
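The second option amounts to a pre-compile check. The sketch below is purely illustrative and runs over a hypothetical toy IR (`Op`, `UnsupportedOpError`, `check_slice_update_indices`, and the placeholder op names "const" and "concat" are all invented here; this is not CoreAI's IR or API). It only shows the shape of such a diagnostic: flag any `coreai.slice_update` whose begin/end come from a non-constant producer, before the graph reaches the compiler.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Toy IR op: a name, named operand producers, and a constant flag (hypothetical)."""
    name: str
    operands: dict = field(default_factory=dict)
    is_constant: bool = False


class UnsupportedOpError(Exception):
    """Raised instead of letting an unsupported op reach the backend compiler."""


def check_slice_update_indices(ops):
    """Reject slice_update ops whose begin/end are not compile-time constants."""
    for op in ops:
        if op.name != "coreai.slice_update":
            continue
        for key in ("begin", "end"):
            src = op.operands.get(key)
            if src is None:
                raise UnsupportedOpError(f"{op.name}: missing '{key}' operand")
            if not src.is_constant:
                raise UnsupportedOpError(
                    f"{op.name}: '{key}' is produced by non-constant op "
                    f"'{src.name}'; this backend needs compile-time indices"
                )


def diagnose(ops):
    """Return the diagnostic message, or None if the graph passes the check."""
    try:
        check_slice_update_indices(ops)
    except UnsupportedOpError as err:
        return str(err)
    return None


const_idx = Op("const", is_constant=True)  # placeholder: literal index tensor
runtime_idx = Op("concat")                 # placeholder: cat([.., off, ..])
ok_graph = [Op("coreai.slice_update", {"begin": const_idx, "end": const_idx})]
bad_graph = [Op("coreai.slice_update", {"begin": runtime_idx, "end": runtime_idx})]
```

With this check, `diagnose(bad_graph)` returns a readable message instead of the program segfaulting later; `diagnose(ok_graph)` returns None.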
