externalize: SDPA submodule re-export drops the upper bound on the key-length dim with a static query + dynamic KV context #1

@scndls

Summary

When a model is exported with a static query length but a dynamic KV-context length
(the standard prefill-chunk / decode shape: a fixed-size query attending to a growing cached
context, which is exactly the convention in coreai_models/models/macos/gpt_oss.py), the externalize
pass re-exports the SDPA submodule standalone and reconstructs its key sequence dim with
no upper bound. torch.export then rejects it:

RuntimeError: Internal error: failed to export submodule 'sdpa_061e31ac': Constraints violated (d_20)!
  - Not all values of d_20 = L['key'].size()[2] in the specified range satisfy the generated guard
    12 <= L['key'].size()[2] and L['key'].size()[2] <= IntInfinity()
Suggested fixes:
  d_20 = Dim('d_20', min=12)
This is a coreai-torch bug. Please report it.

The whole-model export bounds the context length cleanly (via the position_ids / KV-cache
Dims, both max=cap), but the standalone SDPA submodule re-export does not inherit those
bounds: its key-length symbol comes back as [query_len, +inf), which is not a valid Dim.

The shipped models don't hit this because they keep the query dynamic as well, so query and
key share a single bounded symbol. The bug only surfaces when query_len is static (a fixed
prefill chunk / single decode step) while the context is dynamic — which is what you need for
hybrid linear-attention models (e.g. Qwen3.5 / Qwen3-Next Gated DeltaNet), where the query must
be a static chunk so the recurrence's scf.while lowers, while attention layers still read a
dynamic KV context.

Environment

  • coreai-torch 0.4.0
  • coreai-models @ b1cb71b
  • torch 2.9.0, Python 3.11.15
  • macOS 27.0 (build 26A5353q), arm64 (Apple M1 Max, 32 GB)

Minimal repro

A single attention layer (SDPA + RoPE + RMSNorm + KVCache, all from coreai_models),
exported with a static query and a dynamic KV context. No exotic ops.

import torch, torch.nn as nn
from coreai_models.export.macos import export_to_coreai
from coreai_models.primitives.macos.cache import KVCache
from coreai_models.primitives.macos.rms_norm import RMSNorm
from coreai_models.primitives.macos.rope import initialize_rope
from coreai_models.primitives.macos.sdpa import SDPA

HID, NH, NKV, HD = 256, 8, 2, 64

def repeat_kv(x, n):
    b, h, s, d = x.shape
    return x[:, :, None, :, :].expand(b, h, n, s, d).reshape(b, h * n, s, d)

class AttnLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(HID, NH * HD, bias=False)
        self.k_proj = nn.Linear(HID, NKV * HD, bias=False)
        self.v_proj = nn.Linear(HID, NKV * HD, bias=False)
        self.o_proj = nn.Linear(NH * HD, HID, bias=False)
        self.q_norm = RMSNorm(HD, eps=1e-6); self.k_norm = RMSNorm(HD, eps=1e-6)
        self.rope = initialize_rope(dims=HD, base=1e5)
        self.sdpa = SDPA(is_causal=True)

    def forward(self, x, position_ids, k_cache, v_cache):
        b, query_len, _ = x.shape
        cache = KVCache(k_cache, v_cache)
        sequence_length = position_ids.shape[-1]
        torch._check_is_size(sequence_length)
        offset = sequence_length - query_len
        torch._check_is_size(offset)
        q = self.q_proj(x).reshape(b, query_len, NH, HD).permute(0, 2, 1, 3)
        k = self.k_proj(x).reshape(b, query_len, NKV, HD).permute(0, 2, 1, 3)
        v = self.v_proj(x).reshape(b, query_len, NKV, HD).permute(0, 2, 1, 3)
        q = self.q_norm(q); k = self.k_norm(k)
        rp = position_ids.narrow(-1, offset, query_len)
        q = self.rope(q, position_ids=rp); k = self.rope(k, position_ids=rp)
        k, v = cache.update_and_fetch(0, offset, k, v, seq_len=sequence_length, query_len=query_len)
        k = repeat_kv(k, NH // NKV); v = repeat_kv(v, NH // NKV)
        out = self.sdpa(q, k, v).permute(0, 2, 1, 3).reshape(b, query_len, NH * HD)
        return self.o_proj(out)

torch.manual_seed(0)
s, ctx, cap = 12, 20, 64                      # static query=12, dynamic context (trace 20), cap 64
model = AttnLayer().to(torch.float16).eval()
inputs = {
    "x": torch.randn(1, s, HID, dtype=torch.float16),
    "position_ids": torch.arange(ctx, dtype=torch.int32).unsqueeze(0),
    "k_cache": torch.zeros(1, 1, NKV, cap, HD, dtype=torch.float16),
    "v_cache": torch.zeros(1, 1, NKV, cap, HD, dtype=torch.float16),
}
dynamic_shapes = {
    "x": None,                                                       # STATIC query
    "position_ids": {1: torch.export.Dim("ctx",  min=s + 1, max=cap)},  # DYNAMIC context
    "k_cache":      {3: torch.export.Dim("kseq", min=s + 1, max=cap)},  # DYNAMIC cache seq
    "v_cache":      {3: torch.export.Dim("vseq", min=s + 1, max=cap)},
}
prog = export_to_coreai(model, inputs, input_names=("x", "position_ids"),
                        output_names=("out",), state_names=("k_cache", "v_cache"),
                        dynamic_shapes=dynamic_shapes)   # <-- raises the d_20 error above

Expected vs actual

  • Expected: the externalized SDPA submodule's key-length dim is reconstructed with the
    same [query_len, cap] bounds the whole-model export already proved, and export succeeds.
  • Actual: the standalone re-export gives the key dim [query_len, +inf) and torch.export
    raises ConstraintViolationError (d_20).
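
To make the expected bounds concrete, here is a small pure-Python sketch of the shape arithmetic in the repro. It rests on an assumption not confirmed against the coreai_models source: that cache.update_and_fetch returns keys whose sequence length equals seq_len (that is, offset + query_len). The helper name key_len_range is hypothetical and exists only for this illustration.

```python
def key_len_range(query_len: int, ctx_min: int, ctx_max: int) -> tuple[int, int]:
    """Enumerate every allowed context length and return the (min, max) key length.

    Assumption: the fetched key length equals offset + query_len, where
    offset = context_len - query_len, as computed in the repro's forward().
    """
    lengths = []
    for context_len in range(ctx_min, ctx_max + 1):
        offset = context_len - query_len
        assert offset >= 0  # mirrors torch._check_is_size(offset) in the repro
        lengths.append(offset + query_len)
    return min(lengths), max(lengths)


s, cap = 12, 64  # same values as the repro
lo, hi = key_len_range(query_len=s, ctx_min=s + 1, ctx_max=cap)
print(lo, hi)  # 13 64
```

Under that assumption, every reachable key length falls inside [query_len, cap], so a bounded Dim for the key dim should be recoverable from what the whole-model export already knows.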

Notes / workaround

  • Dropping SDPA from the externalize list (so it decomposes to primitive ops) avoids this
    path and the model exports + loads. (We hit a separate MPSGraph issue downstream in larger
    hybrid models, but that's unrelated to this report.)
  • Adding torch._check(key.size(-2) <= cap) and torch._check(key.size(-2) >= query_len) on the
    exact SDPA inputs in the parent forward does not propagate into the standalone submodule
    re-export.
  • _dim_for_sym in _utils.py reads var_to_range for the reconstructed Dim, but the
    unbounded range originates from the submodule's fake-input symbol, so patching _dim_for_sym
    alone does not help — the fix likely needs the submodule re-export to carry the parent's
    symbol ranges (or re-narrow the key to a bounded dim before SDPA).
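
The "carry the parent's symbol ranges" idea from the last note could look roughly like the sketch below. This is not coreai-torch code: SymRange and inherit_parent_range are hypothetical names, and the real pass would have to operate on the exporter's own symbol-range data rather than on a standalone dataclass. It only illustrates intersecting the submodule's unbounded range with the range the parent export already proved.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class SymRange:
    """Closed integer range; upper may be math.inf for an unbounded symbol."""
    lower: int
    upper: float

    def is_bounded(self) -> bool:
        return self.upper != math.inf


def inherit_parent_range(sub: SymRange, parent: SymRange) -> SymRange:
    """Intersect the submodule's range with the range proved by the parent export."""
    lower = max(sub.lower, parent.lower)
    upper = min(sub.upper, parent.upper)
    if lower > upper:
        raise ValueError(f"empty range: [{lower}, {upper}]")
    return SymRange(lower, upper)


sub = SymRange(12, math.inf)   # what the standalone re-export reportedly sees
parent = SymRange(13, 64)      # ctx Dim from the repro: min=s + 1, max=cap
fixed = inherit_parent_range(sub, parent)
print(fixed)  # SymRange(lower=13, upper=64)
```

Intersecting rather than overwriting keeps any tighter lower bound the submodule trace discovered on its own, while still restoring the finite upper bound.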
