GEMM support in waveasm LLVM path#1288

Open
Hardcode84 wants to merge 19 commits into
iree-org:mainfrom
Hardcode84:llvm-asm-backend-gemm-v2

@Hardcode84 Hardcode84 commented Apr 9, 2026

Contributor

Summary

  • Enable the water-to-WaveASM GEMM path by plumbing dynamic buffer strides through the runtime and host wrapper, preserving wrapper argument locations, and keeping structured loops available for waveasm-translate.
  • Extend LLVM-to-WaveASM lowering for GEMM-related ops and tighten correctness around GEP byte offsets, LDS byte accounting, signed div/rem lowering, address-space handling, and scf.for loop control.
  • Add coverage with focused WaveASM lit tests for supported and unsupported lowering cases, plus a CDNA4-gated end-to-end GEMM test.

Notes

  • scf.for lowering still carries a TODO for full zero-trip semantics because waveasm.loop is currently do-while shaped.
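The zero-trip gap in the note above can be illustrated with a small trip-count model. This is a sketch only: the function names are hypothetical and nothing here reflects the actual waveasm.loop lowering. It shows that a do-while shaped loop runs its body once even when the lower bound already meets the upper bound, and that an explicit guard restores scf.for behaviour.

```python
def for_loop_trips(lb: int, ub: int, step: int) -> int:
    """Trip count when the condition is checked before each iteration (scf.for style)."""
    trips = 0
    iv = lb
    while iv < ub:
        trips += 1
        iv += step
    return trips


def do_while_trips(lb: int, ub: int, step: int) -> int:
    """Trip count when the body runs first and the condition is checked afterwards."""
    trips = 0
    iv = lb
    while True:
        trips += 1
        iv += step
        if not iv < ub:
            break
    return trips


def guarded_do_while_trips(lb: int, ub: int, step: int) -> int:
    """Do-while loop wrapped in an entry guard, which recovers zero-trip behaviour."""
    if not lb < ub:
        return 0
    return do_while_trips(lb, ub, step)
```

With lb == ub, for_loop_trips returns 0 while do_while_trips returns 1; the two agree whenever at least one iteration is expected.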

Hardcode84 and others added 10 commits April 9, 2026 15:05
…e helper

emit_host_func was building arg_types from linear_bindings only (buffers +
scalars + symbols) while the kernel function also carried trailing stride
index args added by emit_func when dynamic_strides is active.  The count
mismatch caused a segfault during canonicalization because old kernel
arguments were left with dangling uses after erase.

- Match stride arg count in emit_host_func so the gpu.func wrapper has
  the correct signature.
- Add wave_get_stride runtime function (mirrors wave_get_dim but calls
  tensor.stride() instead of tensor.size()).
- Pass stride values through gpu.launch_func kernel operands.
- Add GEMM waveasm e2e test skeleton.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Extend the waveasm-translate-from-llvm pass to handle all LLVM dialect
ops generated by the GEMM kernel, including:

- LDS: llvm.mlir.addressof for globals, ptr<3> GEPs, ds_read/ds_write
- Arithmetic: llvm.sdiv/srem (power-of-2 -> shift/mask)
- Barriers: llvm.fence (no-op), rocdl.s.barrier -> s_barrier
- MFMA: rocdl.mfma.f32.16x16x16f16 -> v_mfma_f32_16x16x16_f16
- Vector: llvm.shufflevector (single-element extract)
- Constants: dense vector splat (MFMA accumulator init)
- Control flow: scf.for -> waveasm.loop with do-while semantics

Also update the water+waveasm lowering pipeline to preserve scf.for
(remove convert-scf-to-cf) and use alloca-to-global for LDS, matching
what the translation pass expects.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
- translate-from-llvm-mfma.mlir: MFMA, dense vector constant, shufflevector extract
- translate-from-llvm-scf-for.mlir: scf.for -> waveasm.loop with condition
- translate-from-llvm-lds.mlir: addressof, ptr<3> GEP, ds_read/ds_write
- translate-from-llvm-barrier.mlir: rocdl.barrier, rocdl.s.barrier, llvm.fence
- translate-from-llvm-sdiv-srem.mlir: power-of-2 sdiv/srem -> shift/mask

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
inferResultType and handleCastOp always produced 1-wide register types
regardless of the LLVM operation's actual bit width. This made
arith.trunc a no-op for i64->i32 truncations (both sides were vreg),
so legalization removed it as a passthrough before the sext it guarded
was expanded to its true 2-wide form -- leaking vreg<2,2> values into
LDS address chains and producing invalid ds_read/ds_write assembly.

Fix inferResultType to propagate max operand width and handleCastOp to
derive width from the LLVM result type. Now sext i32->i64 produces
vreg<2,2>, trunc i64->i32 produces vreg, and legalization handles
both correctly without the trunc being silently eliminated.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
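As a rough illustration of the width rule in the commit message above, the sketch below maps a bit width to a count of 32-bit registers and takes the maximum over operands. The helper names are hypothetical, and the 32-bit register size is an assumption of this sketch rather than something taken from the translator's code.

```python
def reg_width_for_bits(bits: int) -> int:
    """Number of 32-bit registers needed to hold a value of the given bit width."""
    return max(1, (bits + 31) // 32)


def infer_result_width(operand_bits: list[int]) -> int:
    """Result width is the widest operand, mirroring 'propagate max operand width'."""
    return max(reg_width_for_bits(b) for b in operand_bits)


# sext i32 -> i64 widens from 1 register to 2; trunc i64 -> i32 narrows back to 1,
# so the trunc is no longer a same-width passthrough.
sext_widths = (reg_width_for_bits(32), reg_width_for_bits(64))
trunc_widths = (reg_width_for_bits(64), reg_width_for_bits(32))
```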
The waveasm backend only supports gfx950 for now.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
The LDS base value set was redundant -- the LLVM load/store address
operand already carries the pointer type with the address space.
Check getLLVMAddrSpace(op.getAddr()) == 3 instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Summing GEP indices is wrong -- each index operates at a different
level of the type hierarchy and must be multiplied by the element size
at that level. The all-zero multi-index case (the only one we hit in
practice) was already handled correctly. Remove the bogus fallback and
return nullopt for unhandled multi-index GEPs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
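To see why summing GEP indices is wrong, here is a minimal model of the byte-offset computation over nested array types. The Scalar and Array classes and gep_byte_offset are hypothetical stand-ins for illustration, not the translator's data structures. Each index is scaled by the size of the type at its own level.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Scalar:
    size: int  # size in bytes


@dataclass(frozen=True)
class Array:
    count: int
    elem: "Scalar | Array"


def size_of(ty) -> int:
    """Total size in bytes of a scalar or (possibly nested) array type."""
    if isinstance(ty, Scalar):
        return ty.size
    return ty.count * size_of(ty.elem)


def gep_byte_offset(elem_ty, indices: list[int]) -> int:
    """Byte offset of a GEP: the first index steps over whole elem_ty objects,
    each later index steps into the array at the current level."""
    offset = indices[0] * size_of(elem_ty)
    ty = elem_ty
    for idx in indices[1:]:
        if not isinstance(ty, Array):
            raise ValueError("index steps into a non-aggregate type")
        ty = ty.elem
        offset += idx * size_of(ty)
    return offset
```

For a 4 x 8 array of 2-byte elements, indices [0, 1, 2] give 1 * 16 + 2 * 2 = 20 bytes, whereas the index sum is 3. The all-zero case gives 0 either way, which is why it was already handled correctly.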
The GEP offset computation was treating all indices as byte offsets
regardless of element type. Rename to computeGEPByteOffset and
multiply dynamic indices by getGEPElementBytes(). Replace local
getConstantInt with getConstantIntValue from StaticValueUtils. Use
explicit types instead of auto where the type is not obvious.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
- Buffer GEP with constant attr index (not a dynamic Value).
- GEP on unsupported address space (e.g. ptr<5>).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Reject unsupported non-zero structural GEPs and preserve signed power-of-2 div/rem semantics for negative inputs. Keep scf.for loop control scalar and add regression tests for the guarded cases.

Made-with: Cursor
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
@Hardcode84 Hardcode84 force-pushed the llvm-asm-backend-gemm-v2 branch from 20f1b87 to 593790f Compare April 9, 2026 20:03
Compute LDS/global sizes in bytes and reject unsupported GEP element types so the translator fails instead of silently misaddressing memory. Restore source locations for original wrapper args, add lit coverage for the new behavior, and leave a TODO for the remaining scf.for zero-trip semantic gap.

Made-with: Cursor
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Use concrete local types in the new lowering paths so the translator is easier to audit and maintain without changing behavior.

Made-with: Cursor
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
@Hardcode84 Hardcode84 changed the title WIP: GEMM support in waveasm LLVM path GEMM support in waveasm LLVM path Apr 9, 2026
@Hardcode84 Hardcode84 requested review from ftynse and harsh-nod April 9, 2026 21:04
Comment thread wave_lang/kernel/wave/water.py Outdated
Comment on lines +393 to +401
def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any]]:
    nonlocal mlir_asm
    # Erase the last '}' closing the module, append the transform, re-close.
    last_close = mlir_asm.rfind("}")
    if last_close != -1:
        mlir_asm = mlir_asm[:last_close]
    mlir_asm += transform
    mlir_asm += "}\n"
    return ("transform-interpreter", {"entry-point": entry_point})

Contributor

A fragile, in-place textual modification of a non-local variable is a terrible idea. We have APIs to modify IR properly, and I think we have enough bindings to build the script below rather than hardcode a string.

Contributor Author

done
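One way to see the fragility the reviewer points at: the last '}' in the text is not necessarily the module's closing brace. The sketch below restates the quoted logic as a pure function (the name append_before_last_brace is hypothetical) so it can be exercised on a contrived input. It is not the replacement that was eventually committed.

```python
def append_before_last_brace(mlir_asm: str, transform: str) -> str:
    """Same string surgery as the reviewed snippet, but returning a new string
    instead of mutating a non-local variable."""
    last_close = mlir_asm.rfind("}")
    if last_close != -1:
        mlir_asm = mlir_asm[:last_close]
    return mlir_asm + transform + "}\n"


TRANSFORM = "  // transform script goes here\n"

# Well-formed input: the transform lands inside the module.
simple = append_before_last_brace("module {\n}\n", TRANSFORM)

# Contrived input: a '}' appears after the module's closing brace, so the
# transform is appended outside the module and the trailing text is cut.
tricky = append_before_last_brace("module {\n}\n// trailing } note\n", TRANSFORM)
```

In the second case the result still begins with the complete, already-closed module, so the appended text is no longer nested inside it.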

Comment on lines +325 to +333
# Match stride args added by emit_func when dynamic_strides is active.
stride_arg_count = 0
if self.options.dynamic_strides:
    stride_arg_count = sum(
        max(0, len(b.kernel_buffer_type.symbolic_shape) - 1)
        for b in self.root_sig.sig.kernel_buffer_bindings
    )
if stride_arg_count > 0:
    arg_types += [IndexType.get()] * stride_arg_count

Contributor

Can this logic be factored out into a function also called from emit_func so we don't duplicate and risk divergence?

Contributor Author

done
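A shared helper along the lines the reviewer suggests could look roughly like this. The sketch uses a hypothetical BufferBinding stand-in with a flat symbolic_shape field; the real bindings nest it under kernel_buffer_type, and the refactor actually committed is not shown in this thread.

```python
from dataclasses import dataclass


@dataclass
class BufferBinding:
    """Hypothetical stand-in for a kernel buffer binding."""
    symbolic_shape: tuple


def stride_arg_count(bindings, dynamic_strides: bool) -> int:
    """Number of trailing stride index args: rank - 1 per buffer, when enabled."""
    if not dynamic_strides:
        return 0
    return sum(max(0, len(b.symbolic_shape) - 1) for b in bindings)


# Three rank-2 buffers, as in a GEMM: A[M, K], B[N, K], C[M, N].
gemm_bindings = [
    BufferBinding(("M", "K")),
    BufferBinding(("N", "K")),
    BufferBinding(("M", "N")),
]
```

Calling the same function from both emit_func and emit_host_func would keep the two signatures from drifting apart, which is the mismatch the first commit fixed.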

# Preserve source locations for original args; synthesized stride args
# do not have an originating kernel SSA value.
locs = [a.location for a in kernel_func.body.blocks[0].arguments]
locs += [Location.unknown()] * (len(arg_types) - len(locs))

Contributor

IMO they can have something like Location.name("stride #42 for argument", arg.loc) so we clearly see what they relate to in the debug build.

Contributor Author

done

get_stride_func_symbol,
[arg, dim],
)
stride = arith_d.index_cast(IndexType.get(), stride)

Contributor

Should we just pass it as i32 and not bother with index?

Contributor Author

We can, probably, but it will require more changes across the codebase (StreamExecutable.define_entrypoint and other places), so let's keep it as index for now.

Comment on lines +213 to +222
static int64_t getLLVMTypeBytes(Type ty) {
  if (ty.isIntOrFloat())
    return ty.getIntOrFloatBitWidth() / 8;
  if (VectorType vecTy = dyn_cast<VectorType>(ty))
    return vecTy.getNumElements() *
           vecTy.getElementType().getIntOrFloatBitWidth() / 8;
  if (LLVM::LLVMArrayType arrTy = dyn_cast<LLVM::LLVMArrayType>(ty))
    return getLLVMTypeBytes(arrTy.getElementType()) * arrTy.getNumElements();
  return 0;
}

Contributor

Don't we have data layout support for this? It may in particular be consistent with things like sub-byte element types and non-power-of-two vector lengths.

Contributor Author

Not sure DataLayout is currently plumbed
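The sub-byte concern can be made concrete with a small arithmetic sketch. naive_bytes mirrors the integer division in the quoted helper; padded_bytes rounds the total bit count up to whole bytes. That rounding is only one possible convention and is an assumption here: a real data layout may pad per element or to a larger alignment.

```python
def naive_bytes(num_elements: int, element_bits: int) -> int:
    """Size as computed by multiplying and then truncating division by 8."""
    return num_elements * element_bits // 8


def padded_bytes(num_elements: int, element_bits: int) -> int:
    """Size with the total bit count rounded up to a whole number of bytes."""
    return (num_elements * element_bits + 7) // 8


# A scalar i4 truncates to 0 bytes, and vector<3xi4> (12 bits) to 1 byte,
# while byte-multiple types such as vector<4xf16> are unaffected.
examples = {
    "i4": (naive_bytes(1, 4), padded_bytes(1, 4)),
    "vector<3xi4>": (naive_bytes(3, 4), padded_bytes(3, 4)),
    "vector<4xf16>": (naive_bytes(4, 16), padded_bytes(4, 16)),
}
```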

Comment on lines +968 to +971
// Signed remainder by a positive power-of-2 constant:
// r = x & (divisor - 1)
// if (x < 0 && r != 0) r -= divisor
// This keeps the remainder's sign consistent with LLVM srem.

Contributor

Ditto
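The remainder formula in the quoted comment can be checked against truncating division with a short model. srem_pow2 follows the comment directly; sdiv_pow2 uses the usual bias-then-shift form, which is an assumption of this sketch since the corresponding division code is not quoted in the thread. Python integers are unbounded, so this models the arithmetic rather than 32-bit wraparound.

```python
def trunc_div(x: int, d: int) -> int:
    """Division rounding toward zero, as LLVM sdiv does."""
    q = abs(x) // abs(d)
    return q if (x < 0) == (d < 0) else -q


def srem_pow2(x: int, divisor: int) -> int:
    """Signed remainder by a positive power-of-2 divisor using mask and fix-up."""
    r = x & (divisor - 1)
    if x < 0 and r != 0:
        r -= divisor  # keep the remainder's sign equal to the dividend's
    return r


def sdiv_pow2(x: int, divisor: int) -> int:
    """Signed division by a positive power-of-2 divisor using bias and shift."""
    shift = divisor.bit_length() - 1
    bias = divisor - 1 if x < 0 else 0  # bias negative inputs to round toward zero
    return (x + bias) >> shift
```

For x = -7 and divisor 4, the mask yields 1, the fix-up turns it into -3, and the biased shift yields -1, matching -7 = 4 * (-1) + (-3).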

// SCF for/yield handler
//===----------------------------------------------------------------------===//

/// Forward declaration for recursive op translation.

Contributor

This is anti-documentation. It will show up in, e.g., autocomplete, and provide negative value compared to no documentation at all.

if (isVGPRType(v.getType()))
  return V_READFIRSTLANE_B32::create(builder, loc, sregTy, v).getResult();
op->emitOpError(name) << " must lower to an SGPR or VGPR i32";
return failure();

Contributor

Why not return Value()?

Contributor Author

An overall comment on FailureOr<Value> vs Value: I saw the agent was doing it and intentionally didn't stop it. In upstream MLIR there are some functions that never return an empty Value and some that do, whose results need to be checked (or worse, historically, some functions were switched from never returning null to returning it), which is a footgun. FailureOr gives a clear separation at the type-system level.

SmallVector<Value> initArgs;
initArgs.push_back(*lbScalar);
for (Value arg : op.getInitArgs()) {
  FailureOr<Value> resolved = resolve(arg, ctx);

Contributor

Fly-by, there is rarely a need to wrap a nullable object into FailureOr, so I suspect resolve should be fixed.

Comment thread waveasm/lib/Transforms/TranslateFromLLVMDialect.cpp Outdated
Move signed div/rem strength reduction into a legalization pass, keep LDS globals distinct, and fix dynamic stride ordering so translation and runtime dispatch agree on kernel ABI.

Made-with: Cursor
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
@Hardcode84 Hardcode84 force-pushed the llvm-asm-backend-gemm-v2 branch from a85b75e to 4d11a22 Compare April 10, 2026 18:19
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
@Hardcode84 Hardcode84 requested a review from panditsa April 10, 2026 18:36
Normalize commutative bitwise legalization so immediate-first operands still satisfy the concrete SALU operand constraints, and keep the case covered by a focused regression test.

Made-with: Cursor
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
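The operand normalization described in this commit can be sketched as a swap on commutative ops. Everything below is hypothetical: the Imm and Reg classes, the op names, and in particular the assumption that the constraint wants the register first and the immediate second. The commit message does not spell out the exact SALU operand rule.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Imm:
    value: int


@dataclass(frozen=True)
class Reg:
    name: str


# Bitwise ops whose operands can be swapped without changing the result.
COMMUTATIVE = {"and", "or", "xor"}


def normalize_operands(op: str, lhs, rhs):
    """Move an immediate out of the first slot when the op is commutative."""
    if op in COMMUTATIVE and isinstance(lhs, Imm) and not isinstance(rhs, Imm):
        return rhs, lhs
    return lhs, rhs
```

Non-commutative ops such as shifts are left untouched, since swapping their operands would change the result.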
@Hardcode84 Hardcode84 force-pushed the llvm-asm-backend-gemm-v2 branch from f8fd152 to aec6903 Compare April 10, 2026 22:22