Summary
When a graph contains an int64-comparison → bool → float mask chain (e.g. ((ix0 >= 0) & (ix0 < W)).to(dtype) — the standard in-bounds mask of a gather-based bilinear sampler), executing that subgraph corrupts an unrelated, still-live intermediate tensor elsewhere in the graph. The victim tensor is provably computed correctly first (another consumer that runs before the sampler sees exact values), then reads back as garbage/NaN afterwards — including when it is a declared graph output. Looks like a buffer-liveness/aliasing bug in buffer planning. Deterministic; reproduces with cpu_only() (and on GPU); prog.optimize() not required; inserting clone()/contiguous() barriers does not protect the victim.
Replacing only the mask arithmetic with an equivalent float-only formulation makes the whole graph bit-exact, so the trigger is specifically the int64-comparison bool chain, not the gather or the surrounding structure.
Environment
- coreai-torch 0.4.0, coreai-core 1.0.0b1 (cp312), torch 2.11.0
- macOS 27.0 (build 26A5353q), M4 Max
Minimal repro (self-contained, deterministic)
The structure is the deformable-attention pattern: t = LayerNorm(...) feeding both a sampler query (cq = t + qpos) and a residual (out = LayerNorm2(t + sampler_out)); the sampler computes query-driven sampling locations, in-bounds masks, and a gather. Tapping all intermediates as outputs shows cq exact while t itself reads back NaN:
"""Self-contained repro: an int64-comparison bool chain inside a sampling subgraph
makes the runtime clobber a LIVE LayerNorm output consumed later by a residual.
No external deps beyond torch + coreai stack. Deterministic.
python standalone_clobber.py # bool-chain masks -> t corrupted (cos ~0)
python standalone_clobber.py float # float-arithmetic masks -> exact (cos 1.0)
"""
import asyncio
import shutil
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from coreai_torch import TorchConverter, get_decomp_table
import coreai.runtime as rt

FLOAT_MASKS = len(sys.argv) > 1 and sys.argv[1] == "float"
H = W = 36
C = 256
Q = 300
HEADS = 8
HDIM = C // HEADS


class Sampler(nn.Module):
    """Bilinear-style gather sampling driven by query-predicted locations
    (the deformable-attention pattern, single level, one point)."""

    def __init__(self):
        super().__init__()
        self.value_proj = nn.Linear(C, C)
        self.loc = nn.Linear(C, HEADS * 2)
        self.out = nn.Linear(C, C)

    def forward(self, q, memory):
        B = q.shape[0]
        value = self.value_proj(memory).view(B, H * W, HEADS, HDIM)
        value = value.permute(0, 2, 3, 1).reshape(B * HEADS, HDIM, H, W)
        loc = self.loc(q).view(B, Q, HEADS, 2).permute(0, 2, 1, 3).reshape(B * HEADS, Q, 2)
        ix = (loc[..., 0] + 1) * W / 2 - 0.5  # [B*HEADS, Q]
        iy = (loc[..., 1] + 1) * H / 2 - 0.5
        ix0 = ix.floor().long()
        iy0 = iy.floor().long()
        if FLOAT_MASKS:
            ix0f = ix.floor()
            iy0f = iy.floor()
            ix0c = ix0f.clamp(0, W - 1)
            iy0c = iy0f.clamp(0, H - 1)
            mask = (1.0 - (ix0f - ix0c).abs().clamp(max=1.0)) * (
                1.0 - (iy0f - iy0c).abs().clamp(max=1.0))
        else:
            # int64 comparison -> bool -> float chain: THE TRIGGER
            in_x = ((ix0 >= 0) & (ix0 < W)).to(value.dtype)
            in_y = ((iy0 >= 0) & (iy0 < H)).to(value.dtype)
            mask = in_x * in_y
        idx = (iy0.clamp(0, H - 1) * W + ix0.clamp(0, W - 1))
        gathered = value.flatten(2).gather(
            2, idx.unsqueeze(1).expand(B * HEADS, HDIM, Q))  # [B*H, HDIM, Q]
        sampled = gathered * mask.unsqueeze(1)
        return self.out(sampled.view(B, HEADS, HDIM, Q).permute(0, 3, 1, 2).reshape(B, Q, C))


class Repro(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(C)
        self.norm2 = nn.LayerNorm(C)
        self.sampler = Sampler()

    def forward(self, tgt, qpos, memory):
        t = self.norm1(tgt + tgt)
        cq = t + qpos  # consumer 1 of t (correct)
        t2 = self.sampler(cq, memory)
        out = self.norm2(t + t2)  # consumer 2 of t (reads clobbered data)
        return t, cq, t2, out


def main():
    torch.manual_seed(0)
    mod = Repro().eval()
    inputs = (torch.rand(1, Q, C), torch.rand(1, Q, C), torch.rand(1, H * W, C))
    names = ["tgt", "qpos", "memory"]
    out_names = ["t", "cq", "t2", "out"]
    with torch.no_grad():
        refs = mod(*inputs)
    ep = torch.export.export(mod, inputs).run_decompositions(get_decomp_table())
    prog = TorchConverter().add_exported_program(
        exported_program=ep, input_names=names, output_names=out_names).to_coreai()
    prog.optimize()
    path = Path("/tmp/standalone_clobber.aimodel")
    shutil.rmtree(path, ignore_errors=True)
    prog.save_asset(path, rt.AIModelAssetMetadata())

    async def run():
        m = await rt.AIModel.load(path, rt.SpecializationOptions.cpu_only())
        fn = m.load_function("main")
        return await fn({n: rt.NDArray(t.numpy()) for n, t in zip(names, inputs)})

    got = asyncio.run(run())
    for n, ref in zip(out_names, refs):
        g = got[n].numpy()
        r = ref.numpy()
        a, b = r.reshape(-1).astype(np.float64), g.reshape(-1).astype(np.float64)
        cos = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
        print(f" {n:4s} cos={cos:.6f} max|d|={np.abs(r - g).max():.3e}")


if __name__ == "__main__":
    main()
Output with the bool-chain masks (default):
t cos=nan max|d|=nan <- LayerNorm output, ALSO a graph output, clobbered
cq cos=1.000000 max|d|=4.8e-07 <- cq = t + qpos consumed t BEFORE the sampler ran: t was computed right
t2 cos=1.000000 max|d|=4.2e-07
out cos=nan max|d|=nan <- residual re-read of t after the sampler
Output with python standalone_clobber.py float (identical math, float-arithmetic masks):
t cos=1.000000 max|d|=4.8e-07
cq cos=1.000000 max|d|=4.8e-07
t2 cos=1.000000 max|d|=4.2e-07
out cos=1.000000 max|d|=7.2e-07
In larger graphs (RF-DETR's full decoder) the same trigger corrupts the tensor to plausible-looking garbage rather than NaN, which makes it nastier to catch — output cosine ~0.65 with no error raised.
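Because the corruption can surface as plausible values rather than NaN, a comparison against the PyTorch reference has to treat both NaN and low cosine as explicit failures. A minimal sketch over flat Python lists follows. `compare` and its thresholds are hypothetical names and values chosen for illustration; they are not part of the coreai API.

```python
import math


def compare(ref, got, cos_min=0.999, atol=1e-4):
    """Return (ok, cos, max_abs_diff) for two equal-length float sequences.

    NaN anywhere in `got` is reported as a failure up front, because
    ordinary comparisons against NaN evaluate to False and are easy to
    get wrong in a threshold check.
    """
    if any(math.isnan(g) for g in got):
        return False, math.nan, math.nan
    dot = sum(r * g for r, g in zip(ref, got))
    norm = math.sqrt(sum(r * r for r in ref)) * math.sqrt(sum(g * g for g in got))
    cos = dot / (norm + 1e-12)
    max_abs = max(abs(r - g) for r, g in zip(ref, got))
    return (cos >= cos_min and max_abs <= atol), cos, max_abs


ref = [0.5, -1.0, 2.0, 0.25]
print(compare(ref, ref)[0])                               # identical data passes
print(compare(ref, [0.5, -1.0, float("nan"), 0.25])[0])   # NaN fails
print(compare(ref, [2.0, 0.25, 0.5, -1.0])[0])            # shuffled values fail
```

Checking every tapped intermediate this way, not only the final output, is what localizes the fault to `t` in the repro.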
Expected
A live tensor (especially a declared output) must survive execution of unrelated subgraphs.
Workaround
Compute in-bounds masks in float arithmetic only — for integer-valued floats,
1 - (x - x.clamp(lo, hi)).abs().clamp(max=1) is an exact 0/1 mask — and cast to int only at the gather index.
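The identity behind the workaround can be checked without the runtime. The sketch below is a dependency-free, scalar version. The helper names `clamp`, `float_mask`, and `bool_mask` are illustrative and not part of torch or coreai. It assumes the inputs are integer-valued floats (as produced by `floor()`), which is the only case where the float expression is an exact 0/1 mask.

```python
def clamp(x, lo, hi):
    """Scalar stand-in for Tensor.clamp(lo, hi)."""
    return max(lo, min(x, hi))


def float_mask(x, lo, hi):
    """Float-only in-bounds mask: 1 - min(|x - clamp(x, lo, hi)|, 1)."""
    return 1.0 - min(abs(x - clamp(x, lo, hi)), 1.0)


def bool_mask(x, lo, hi):
    """Reference mask via integer comparison, then cast to float."""
    return float(lo <= int(x) <= hi)


W = 36  # grid width used in the repro
# Integer-valued floats, including out-of-range values on both sides.
samples = [float(v) for v in range(-5, W + 5)]
mismatches = [x for x in samples
              if float_mask(x, 0, W - 1) != bool_mask(x, 0, W - 1)]
print("mismatches:", mismatches)
```

For non-integer inputs the float expression returns fractional values (for example, `float_mask(-0.5, 0, 35)` is 0.5), so the `floor()` has to be applied before the mask is computed, as in the repro's float branch.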