Skip to content

[Codegen] Use default read semantics for LinalgExt scatter#24504

Open
moomoohorse321 wants to merge 1 commit into
iree-org:mainfrom
YuWei-CH:haor2/scatter-init-bufferization-read
Open

[Codegen] Use default read semantics for LinalgExt scatter#24504
moomoohorse321 wants to merge 1 commit into
iree-org:mainfrom
YuWei-CH:haor2/scatter-init-bufferization-read

Conversation

@moomoohorse321
Copy link
Copy Markdown

@moomoohorse321 moomoohorse321 commented May 16, 2026

The original impl. will introduce critical but very subtle bugs. We identified the bug when running Qwen-3.5, with decode-prefill-decode pattern. Because of the problem detailed below, the 2nd decode are completely wrong.

Problem

iree_linalg_ext.scatter was special-cased in the LinalgExt bufferization
external model as not reading its DPS init operand which is not correct.

Fix

Remove the ScatterOp-specific no-read override and use the default
DestinationStyleOpInterface bufferization read semantics.

After this change, out-of-place bufferization preserves the init value before
scatter updates it:

  alloc tmp[32][1][512]
  copy original -> tmp // which was previously dropped.
  scatter writes 11 rows into tmp

Minimal Repro

A simple overwrite scatter still needs to preserve elements that are not updated:

result = original // this was incorrectly dropped if it's not fixed with the config shown in test cases attached
for i in updated_indices:
  result[i] = update[i]

Example shape:

original: 32 x 1 x 512
updates : 11 x 1 x 512
indices : 11 rows
result : 32 x 1 x 512

Scatter writes only 11 rows. The remaining 21 rows must come from original.

Before this change, bufferization thinks outs(%original) as a destination
placeholder instead of a read. So it stores tmp[32][1][512] to output

Analysis

I propose that we remove the override because scatter op should be bufferized as reads except for some small corner cases. Overriding the scatter op does not preserve correctness.

Let's analyze the following cases for scatter:

  • updates is read: scatter must read the update value.
  • indices is read: scatter must read indices to know where to write.
  • mask, if present, is read: a false mask suppresses the update and preserves
    the original value.
  • DPS init / original is read:
    • elements not hit by indices must be inherited from original;
    • masked-off updates must preserve original;
    • the combiner region may use the old value, e.g. yield update + old;
    • unique_indices(false) combine/reduction-style scatter also needs the
      old/current value.

The only case where init may not need to be read is a separately proven
full-overwrite optimization: indices cover the entire result, mask cannot
suppress updates, and the combiner does not use the old value. That should be
handled by a dedicated copy-elision/full-overwrite analysis, not by the default
bufferization semantics.

Tests

Added LLVMGPU bufferization tests for:

  • overwrite scatter preserving original;
  • combiner scatter reading original;
  • masked scatter preserving original.

This removes the ScatterOp-specific no-read override in the LinalgExt bufferization external model and uses the default DestinationStyleOpInterface read semantics instead.

Signed-off-by: Yuwei Sun <yuweis2@illinois.edu>
Signed-off-by: Hao Ren <rhao8608@gmail.com>
Co-authored-by: Yuwei Sun <yuweis2@illinois.edu>
@krzysz00
Copy link
Copy Markdown
Contributor

I think it might be a good idea to create that followup copy elision pass as I suspect this might have performance implications.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants