Runtime clobbers an unrelated live tensor when an int64-comparison bool mask chain executes (deformable-attention sampler pattern) #11

@john-rocky

Summary

When a graph contains an int64-comparison → bool → float mask chain (e.g. ((ix0 >= 0) & (ix0 < W)).to(dtype), the standard in-bounds mask of a gather-based bilinear sampler), executing that subgraph corrupts an unrelated, still-live intermediate tensor elsewhere in the graph. The victim tensor is demonstrably computed correctly first (another consumer that runs before the sampler sees exact values), but reads back as garbage/NaN afterwards, even when it is a declared graph output. This looks like a buffer-liveness/aliasing bug in the runtime's buffer planning. The failure is deterministic: it reproduces with cpu_only() (and on GPU), does not require prog.optimize(), and inserting clone()/contiguous() barriers does not protect the victim.

Replacing only the mask arithmetic with an equivalent float-only formulation makes the whole graph bit-exact, so the trigger is specifically the int64-comparison bool chain, not the gather or the surrounding structure.

Environment

  • coreai-torch 0.4.0, coreai-core 1.0.0b1 (cp312), torch 2.11.0
  • macOS 27.0 (build 26A5353q), M4 Max

Minimal repro (self-contained, deterministic)

The structure is the deformable-attention pattern: t = LayerNorm(...) feeding both a sampler query (cq = t + qpos) and a residual (out = LayerNorm2(t + sampler_out)); the sampler computes query-driven sampling locations, in-bounds masks, and a gather. Tapping all intermediates as outputs shows cq exact while t itself reads back NaN:

"""Self-contained repro: an int64-comparison bool chain inside a sampling subgraph
makes the runtime clobber a LIVE LayerNorm output consumed later by a residual.

No external deps beyond torch + coreai stack. Deterministic.

  python standalone_clobber.py          # bool-chain masks -> t corrupted (cos ~0)
  python standalone_clobber.py float    # float-arithmetic masks -> exact (cos 1.0)
"""
import asyncio
import shutil
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from coreai_torch import TorchConverter, get_decomp_table
import coreai.runtime as rt

FLOAT_MASKS = len(sys.argv) > 1 and sys.argv[1] == "float"

H = W = 36
C = 256
Q = 300
HEADS = 8
HDIM = C // HEADS


class Sampler(nn.Module):
    """Bilinear-style gather sampling driven by query-predicted locations
    (the deformable-attention pattern, single level, one point)."""

    def __init__(self):
        super().__init__()
        self.value_proj = nn.Linear(C, C)
        self.loc = nn.Linear(C, HEADS * 2)
        self.out = nn.Linear(C, C)

    def forward(self, q, memory):
        B = q.shape[0]
        value = self.value_proj(memory).view(B, H * W, HEADS, HDIM)
        value = value.permute(0, 2, 3, 1).reshape(B * HEADS, HDIM, H, W)
        loc = self.loc(q).view(B, Q, HEADS, 2).permute(0, 2, 1, 3).reshape(B * HEADS, Q, 2)
        ix = (loc[..., 0] + 1) * W / 2 - 0.5     # [B*HEADS, Q]
        iy = (loc[..., 1] + 1) * H / 2 - 0.5

        ix0 = ix.floor().long()
        iy0 = iy.floor().long()
        if FLOAT_MASKS:
            ix0f = ix.floor()
            iy0f = iy.floor()
            ix0c = ix0f.clamp(0, W - 1)
            iy0c = iy0f.clamp(0, H - 1)
            mask = (1.0 - (ix0f - ix0c).abs().clamp(max=1.0)) * (
                1.0 - (iy0f - iy0c).abs().clamp(max=1.0))
        else:
            # int64 comparison -> bool -> float chain: THE TRIGGER
            in_x = ((ix0 >= 0) & (ix0 < W)).to(value.dtype)
            in_y = ((iy0 >= 0) & (iy0 < H)).to(value.dtype)
            mask = in_x * in_y
        idx = (iy0.clamp(0, H - 1) * W + ix0.clamp(0, W - 1))
        gathered = value.flatten(2).gather(
            2, idx.unsqueeze(1).expand(B * HEADS, HDIM, Q))      # [B*HEADS, HDIM, Q]
        sampled = gathered * mask.unsqueeze(1)
        return self.out(sampled.view(B, HEADS, HDIM, Q).permute(0, 3, 1, 2).reshape(B, Q, C))


class Repro(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(C)
        self.norm2 = nn.LayerNorm(C)
        self.sampler = Sampler()

    def forward(self, tgt, qpos, memory):
        t = self.norm1(tgt + tgt)
        cq = t + qpos                      # consumer 1 of t (correct)
        t2 = self.sampler(cq, memory)
        out = self.norm2(t + t2)           # consumer 2 of t (reads clobbered data)
        return t, cq, t2, out


def main():
    torch.manual_seed(0)
    mod = Repro().eval()
    inputs = (torch.rand(1, Q, C), torch.rand(1, Q, C), torch.rand(1, H * W, C))
    names = ["tgt", "qpos", "memory"]
    out_names = ["t", "cq", "t2", "out"]
    with torch.no_grad():
        refs = mod(*inputs)

    ep = torch.export.export(mod, inputs).run_decompositions(get_decomp_table())
    prog = TorchConverter().add_exported_program(
        exported_program=ep, input_names=names, output_names=out_names).to_coreai()
    prog.optimize()  # optional: the corruption reproduces without this call too
    path = Path("/tmp/standalone_clobber.aimodel")
    shutil.rmtree(path, ignore_errors=True)
    prog.save_asset(path, rt.AIModelAssetMetadata())

    async def run():
        m = await rt.AIModel.load(path, rt.SpecializationOptions.cpu_only())
        fn = m.load_function("main")
        return await fn({n: rt.NDArray(t.numpy()) for n, t in zip(names, inputs)})

    got = asyncio.run(run())
    for n, ref in zip(out_names, refs):
        g = got[n].numpy()
        r = ref.numpy()
        a, b = r.reshape(-1).astype(np.float64), g.reshape(-1).astype(np.float64)
        cos = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
        print(f"  {n:4s} cos={cos:.6f} max|d|={np.abs(r - g).max():.3e}")


if __name__ == "__main__":
    main()

Output with the bool-chain masks (default):

  t    cos=nan      max|d|=nan       <- LayerNorm output, ALSO a graph output, clobbered
  cq   cos=1.000000 max|d|=4.8e-07   <- cq = t + qpos consumed t BEFORE the sampler ran: t was computed right
  t2   cos=1.000000 max|d|=4.2e-07
  out  cos=nan      max|d|=nan       <- residual re-read of t after the sampler

Output with python standalone_clobber.py float (identical math, float-arithmetic masks):

  t    cos=1.000000 max|d|=4.8e-07
  cq   cos=1.000000 max|d|=4.8e-07
  t2   cos=1.000000 max|d|=4.2e-07
  out  cos=1.000000 max|d|=7.2e-07

In larger graphs (RF-DETR's full decoder) the same trigger corrupts the tensor into plausible-looking garbage rather than NaN, which makes it much harder to catch: the output cosine similarity drops to ~0.65 and no error is raised.
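Because the corruption can be silent, it may help to gate converted models on a per-output numeric check rather than on the absence of errors. A minimal sketch in plain Python, assuming each output has already been flattened to a list of floats (e.g. via .reshape(-1).tolist()); the helper name flag_suspect_outputs and the 0.999 threshold are hypothetical choices, not part of coreai:

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length float sequences (NaN propagates)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb + 1e-12)


def flag_suspect_outputs(refs, got, min_cos=0.999):
    """Return names of outputs that are non-finite or drift from the reference.

    refs, got: dict mapping output name -> flat list of floats.
    min_cos: hypothetical acceptance threshold on cosine similarity.
    """
    suspects = []
    for name, ref in refs.items():
        out = got[name]
        if not all(math.isfinite(v) for v in out):
            suspects.append(name)  # NaN/Inf, like `t` and `out` in the repro
        elif cosine(ref, out) < min_cos:
            suspects.append(name)  # finite but wrong, like the RF-DETR case
    return suspects


refs = {"t": [1.0, 2.0, 3.0], "cq": [0.5, 0.5, 1.0]}
got = {"t": [float("nan"), 2.0, 3.0], "cq": [0.5, 0.5, 1.0]}
print(flag_suspect_outputs(refs, got))
```

The NaN check runs first so that a non-finite output is never passed through the cosine computation; the threshold would need tuning per model and precision.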

Expected

A live tensor (especially a declared output) must survive execution of unrelated subgraphs.
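The expected behaviour is the usual buffer-planning invariant: two tensors may share a buffer only if their live ranges do not overlap, and graph outputs stay live until the end. The sketch below checks that invariant on a toy, linearised schedule shaped like the repro. The schedule, the collapsing of the sampler into two steps, and both buffer assignments are hypothetical illustrations of what a clobber of t would look like, not coreai's actual planner or its output:

```python
from itertools import combinations


def live_ranges(schedule, outputs):
    """Map tensor -> (def_step, last_use_step) for a linear op schedule.

    schedule: list of (produced_tensor, [consumed_tensors]) in execution order.
    Graph inputs are not planned; graph outputs stay live to the last step.
    """
    ranges = {}
    for step, (produced, consumed) in enumerate(schedule):
        ranges[produced] = [step, step]
        for name in consumed:
            if name in ranges:
                ranges[name][1] = step
    end = len(schedule) - 1
    for name in outputs:
        ranges[name][1] = end
    return {k: tuple(v) for k, v in ranges.items()}


def aliasing_conflicts(ranges, assignment):
    """Return pairs of tensors that share a buffer while both are live."""
    conflicts = []
    for a, b in combinations(sorted(ranges), 2):
        if assignment.get(a) is None or assignment.get(a) != assignment.get(b):
            continue
        (s1, e1), (s2, e2) = ranges[a], ranges[b]
        if s1 <= e2 and s2 <= e1:  # closed intervals overlap
            conflicts.append((a, b))
    return conflicts


# Hypothetical, simplified schedule mirroring Repro.forward.
SCHEDULE = [
    ("t", ["tgt"]),              # step 0: t = norm1(tgt + tgt)
    ("cq", ["t", "qpos"]),       # step 1: first consumer of t
    ("mask", ["cq"]),            # step 2: int64 compare -> bool -> float chain
    ("t2", ["mask", "memory"]),  # step 3: gather + output projection
    ("out", ["t", "t2"]),        # step 4: residual re-reads t
]
OUTPUTS = ["t", "cq", "t2", "out"]

ranges = live_ranges(SCHEDULE, OUTPUTS)
bad = {"t": 0, "mask": 0, "cq": 1, "t2": 2, "out": 3}   # mask reuses t's buffer
good = {"t": 0, "cq": 1, "mask": 2, "t2": 3, "out": 2}  # out reuses dead mask buffer
print(aliasing_conflicts(ranges, bad), aliasing_conflicts(ranges, good))
```

In this toy model the "bad" assignment is exactly the symptom reported above: t is live from step 0 through step 4, so any sampler-internal tensor placed in its buffer overwrites it after cq has already read the correct values.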

Workaround

Compute the in-bounds masks in float arithmetic only, and cast to int only for the gather index. For integer-valued floats, 1 - (x - x.clamp(lo, hi)).abs().clamp(max=1) is an exact 0/1 mask.
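The claim that the float-only mask is exact for integer-valued floats can be checked exhaustively in plain Python. A minimal sketch that mirrors the repro's coordinate mapping for one axis (ix = (loc + 1) * W / 2 - 0.5, then floor); the helper names are hypothetical, and min/max stand in for torch's clamp:

```python
import math

W = 36  # same extent as the repro's H = W = 36


def clamp(x, lo, hi):
    return max(lo, min(hi, x))


def bool_chain_mask(ix0, size):
    """Original formulation: integer comparisons -> bool -> float."""
    return float((ix0 >= 0) and (ix0 < size))


def float_only_mask(ix0f, size):
    """Workaround: 1 - min(|x - clamp(x, 0, size - 1)|, 1), all in float."""
    return 1.0 - min(abs(ix0f - clamp(ix0f, 0.0, size - 1.0)), 1.0)


def pixel_coord(loc, size):
    """Normalized location -> pixel coordinate, as in Sampler.forward."""
    return (loc + 1) * size / 2 - 0.5


# Sweep locations well outside [-1, 1] so both out-of-bounds sides are exercised.
mismatches = []
for k in range(-4000, 4001):
    loc = k / 1000.0
    ix0 = math.floor(pixel_coord(loc, W))  # math.floor returns an int
    if bool_chain_mask(ix0, W) != float_only_mask(float(ix0), W):
        mismatches.append(loc)

print(len(mismatches))
```

Every swept location yields identical masks, consistent with the issue's point that the two variants compute the same values; the sketch only checks the arithmetic and says nothing about how the runtime executes either chain.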
