GEMM support in waveasm LLVM path #1288
…e helper emit_host_func was building arg_types from linear_bindings only (buffers + scalars + symbols) while the kernel function also carried trailing stride index args added by emit_func when dynamic_strides is active. The count mismatch caused a segfault during canonicalization because old kernel arguments were left with dangling uses after erase.
- Match stride arg count in emit_host_func so the gpu.func wrapper has the correct signature.
- Add wave_get_stride runtime function (mirrors wave_get_dim but calls tensor.stride() instead of tensor.size()).
- Pass stride values through gpu.launch_func kernel operands.
- Add GEMM waveasm e2e test skeleton.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Extend the waveasm-translate-from-llvm pass to handle all LLVM dialect ops generated by the GEMM kernel, including:
- LDS: llvm.mlir.addressof for globals, ptr<3> GEPs, ds_read/ds_write
- Arithmetic: llvm.sdiv/srem (power-of-2 -> shift/mask)
- Barriers: llvm.fence (no-op), rocdl.s.barrier -> s_barrier
- MFMA: rocdl.mfma.f32.16x16x16f16 -> v_mfma_f32_16x16x16_f16
- Vector: llvm.shufflevector (single-element extract)
- Constants: dense vector splat (MFMA accumulator init)
- Control flow: scf.for -> waveasm.loop with do-while semantics
Also update the water+waveasm lowering pipeline to preserve scf.for (remove convert-scf-to-cf) and use alloca-to-global for LDS, matching what the translation pass expects.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
- translate-from-llvm-mfma.mlir: MFMA, dense vector constant, shufflevector extract
- translate-from-llvm-scf-for.mlir: scf.for -> waveasm.loop with condition
- translate-from-llvm-lds.mlir: addressof, ptr<3> GEP, ds_read/ds_write
- translate-from-llvm-barrier.mlir: rocdl.barrier, rocdl.s.barrier, llvm.fence
- translate-from-llvm-sdiv-srem.mlir: power-of-2 sdiv/srem -> shift/mask
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
inferResultType and handleCastOp always produced 1-wide register types regardless of the LLVM operation's actual bit width. This made arith.trunc a no-op for i64->i32 truncations (both sides were vreg), so legalization removed it as a passthrough before the sext it guarded was expanded to its true 2-wide form -- leaking vreg<2,2> values into LDS address chains and producing invalid ds_read/ds_write assembly. Fix inferResultType to propagate max operand width and handleCastOp to derive width from the LLVM result type. Now sext i32->i64 produces vreg<2,2>, trunc i64->i32 produces vreg, and legalization handles both correctly without the trunc being silently eliminated. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
The waveasm backend only supports gfx950 for now. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
The LDS base value set was redundant -- the LLVM load/store address operand already carries the pointer type with the address space. Check getLLVMAddrSpace(op.getAddr()) == 3 instead. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Summing GEP indices is wrong -- each index operates at a different level of the type hierarchy and must be multiplied by the element size at that level. The all-zero multi-index case (the only one we hit in practice) was already handled correctly. Remove the bogus fallback and return nullopt for unhandled multi-index GEPs. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
The GEP offset computation was treating all indices as byte offsets regardless of element type. Rename to computeGEPByteOffset and multiply dynamic indices by getGEPElementBytes(). Replace local getConstantInt with getConstantIntValue from StaticValueUtils. Use explicit types instead of auto where the type is not obvious. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
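The two GEP fixes above come down to one rule: each index is scaled by the byte size of the type at its own level of the hierarchy, so indices cannot simply be summed. The following is a minimal Python sketch of that rule over a toy type model. The names `Scalar`, `Array`, `type_bytes`, and `gep_byte_offset` are hypothetical and are not the translator's API; the real pass works on MLIR LLVM dialect types and, per the commit message, returns nullopt for multi-index GEPs it does not handle.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


@dataclass(frozen=True)
class Scalar:
    bits: int  # e.g. 16 for f16, 32 for i32 (toy model, hypothetical)


@dataclass(frozen=True)
class Array:
    elem: "Ty"
    count: int


Ty = Union[Scalar, Array]


def type_bytes(ty: Ty) -> Optional[int]:
    """Byte size of a type, or None when it is not a whole number of bytes."""
    if isinstance(ty, Scalar):
        return ty.bits // 8 if ty.bits % 8 == 0 else None
    inner = type_bytes(ty.elem)
    return None if inner is None else inner * ty.count


def gep_byte_offset(source_elem: Ty, indices: Sequence[int]) -> Optional[int]:
    """Byte offset of a GEP with constant indices, or None if unsupported."""
    if not indices:
        return 0
    size = type_bytes(source_elem)
    if size is None:
        return None
    # The first index steps over whole objects of the source element type.
    offset = indices[0] * size
    cur = source_elem
    for idx in indices[1:]:
        # Each later index steps into the aggregate one level down.
        if not isinstance(cur, Array):
            return None
        cur = cur.elem
        size = type_bytes(cur)
        if size is None:
            return None
        offset += idx * size
    return offset
```

For a 64-element f16 array, indices `[1, 2]` address byte 132 (one whole 128-byte array plus two 2-byte elements), whereas summing the indices would give 3.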
- Buffer GEP with constant attr index (not a dynamic Value).
- GEP on unsupported address space (e.g. ptr<5>).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Reject unsupported non-zero structural GEPs and preserve signed power-of-2 div/rem semantics for negative inputs. Keep scf.for loop control scalar and add regression tests for the guarded cases. Made-with: Cursor Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
20f1b87 to 593790f
Compute LDS/global sizes in bytes and reject unsupported GEP element types so the translator fails instead of silently misaddressing memory. Restore source locations for original wrapper args, add lit coverage for the new behavior, and leave a TODO for the remaining scf.for zero-trip semantic gap. Made-with: Cursor Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Use concrete local types in the new lowering paths so the translator is easier to audit and maintain without changing behavior. Made-with: Cursor Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
    def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any]]:
        nonlocal mlir_asm
        # Erase the last '}' closing the module, append the transform, re-close.
        last_close = mlir_asm.rfind("}")
        if last_close != -1:
            mlir_asm = mlir_asm[:last_close]
        mlir_asm += transform
        mlir_asm += "}\n"
        return ("transform-interpreter", {"entry-point": entry_point})
Doing a fragile in-place textual modification of a non-local variable is a terrible idea. We have an API to modify the IR properly. And I think we have enough bindings to build the script below rather than hardcode a string.
    # Match stride args added by emit_func when dynamic_strides is active.
    stride_arg_count = 0
    if self.options.dynamic_strides:
        stride_arg_count = sum(
            max(0, len(b.kernel_buffer_type.symbolic_shape) - 1)
            for b in self.root_sig.sig.kernel_buffer_bindings
        )
    if stride_arg_count > 0:
        arg_types += [IndexType.get()] * stride_arg_count
Can this logic be factored out into a function also called from emit_func so we don't duplicate and risk divergence?
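One way to read this suggestion is a single helper that owns the counting rule, sketched below in Python. The name `count_dynamic_stride_args` and its signature are hypothetical; the sketch takes plain shape sequences instead of the real binding objects and only assumes the rule visible in the snippet (rank minus one stride args per buffer, clamped at zero, and none when dynamic strides are off).

```python
from typing import Iterable, Sequence


def count_dynamic_stride_args(
    buffer_shapes: Iterable[Sequence[object]], dynamic_strides: bool
) -> int:
    """Number of trailing stride index args for a set of kernel buffers.

    Hypothetical helper: each buffer contributes (rank - 1) stride args,
    mirroring the expression in the reviewed snippet.
    """
    if not dynamic_strides:
        return 0
    return sum(max(0, len(shape) - 1) for shape in buffer_shapes)


# Both emitters would then derive the count from the same place, e.g.:
#   n = count_dynamic_stride_args(shapes, options.dynamic_strides)
#   arg_types += [IndexType.get()] * n
```

With three rank-2 GEMM operands this yields three stride args, and any change to the rule would reach the kernel function and its host wrapper together.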
    # Preserve source locations for original args; synthesized stride args
    # do not have an originating kernel SSA value.
    locs = [a.location for a in kernel_func.body.blocks[0].arguments]
    locs += [Location.unknown()] * (len(arg_types) - len(locs))
IMO they can have something like Location.name("stride #42 for argument", arg.loc) so we clearly see what they relate to in the debug build.
        get_stride_func_symbol,
        [arg, dim],
    )
    stride = arith_d.index_cast(IndexType.get(), stride)
Should we just pass it as i32 and not bother with index?
We probably can, but it will require more changes across the codebase (StreamExecutable.define_entrypoint and other places); let's keep it as index for now.
    static int64_t getLLVMTypeBytes(Type ty) {
      if (ty.isIntOrFloat())
        return ty.getIntOrFloatBitWidth() / 8;
      if (VectorType vecTy = dyn_cast<VectorType>(ty))
        return vecTy.getNumElements() *
               vecTy.getElementType().getIntOrFloatBitWidth() / 8;
      if (LLVM::LLVMArrayType arrTy = dyn_cast<LLVM::LLVMArrayType>(ty))
        return getLLVMTypeBytes(arrTy.getElementType()) * arrTy.getNumElements();
      return 0;
    }
Don't we have data layout support for this? It may in particular be consistent with things like sub-byte element types and non-power-of-two vector lengths.
Not sure DataLayout is currently plumbed
    // Signed remainder by a positive power-of-2 constant:
    //   r = x & (divisor - 1)
    //   if (x < 0 && r != 0) r -= divisor
    // This keeps the remainder's sign consistent with LLVM srem.
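The comment above can be checked with a small Python model. `srem_pow2` follows the comment line by line. `sdiv_pow2` is an added assumption: it uses the standard bias-before-shift trick for truncating division and is not confirmed to be what the pass emits. Python integers behave like infinite-width two's complement under `&` and `>>`, so the model matches fixed-width hardware only for values that fit the register width.

```python
def srem_pow2(x: int, divisor: int) -> int:
    """Remainder with the sign of the dividend, for a positive power-of-2 divisor."""
    assert divisor > 0 and divisor & (divisor - 1) == 0
    r = x & (divisor - 1)      # mask gives a non-negative residue
    if x < 0 and r != 0:
        r -= divisor           # pull the residue back to the dividend's sign
    return r


def sdiv_pow2(x: int, divisor: int) -> int:
    """Quotient rounded toward zero (assumed bias trick, not taken from the PR)."""
    assert divisor > 0 and divisor & (divisor - 1) == 0
    shift = divisor.bit_length() - 1
    bias = divisor - 1 if x < 0 else 0   # without the bias, >> rounds toward -inf
    return (x + bias) >> shift


def reference_srem(x: int, d: int) -> int:
    """Truncating remainder as in LLVM srem, built from non-negative operands."""
    return -((-x) % d) if x < 0 else x % d


def reference_sdiv(x: int, d: int) -> int:
    """Truncating quotient as in LLVM sdiv, built from non-negative operands."""
    return -((-x) // d) if x < 0 else x // d
```

For example, `srem_pow2(-7, 4)` masks to 1 and then subtracts 4 to give -3, where a bare mask would have returned 1.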
    // SCF for/yield handler
    //===----------------------------------------------------------------------===//

    /// Forward declaration for recursive op translation.
This is anti-documentation. It will show up in, e.g., autocomplete, and provide negative value compared to no documentation at all.
    if (isVGPRType(v.getType()))
      return V_READFIRSTLANE_B32::create(builder, loc, sregTy, v).getResult();
    op->emitOpError(name) << " must lower to an SGPR or VGPR i32";
    return failure();
Overall comment on FailureOr<Value> vs Value: I saw the agent was doing it and intentionally didn't stop it. In upstream MLIR some functions never return an empty Value and some do, so their results need to be checked (or worse, historically, some functions were switched from never returning null to returning it), which is a footgun. FailureOr gives a clear separation at the type-system level.
    SmallVector<Value> initArgs;
    initArgs.push_back(*lbScalar);
    for (Value arg : op.getInitArgs()) {
      FailureOr<Value> resolved = resolve(arg, ctx);
Fly-by, there is rarely a need to wrap a nullable object into FailureOr, so I suspect resolve should be fixed.
Move signed div/rem strength reduction into a legalization pass, keep LDS globals distinct, and fix dynamic stride ordering so translation and runtime dispatch agree on kernel ABI. Made-with: Cursor Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
a85b75e to 4d11a22
Normalize commutative bitwise legalization so immediate-first operands still satisfy the concrete SALU operand constraints, and keep the case covered by a focused regression test. Made-with: Cursor Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
f8fd152 to aec6903
Summary
waveasm-translate.
scf.for loop control.
Notes
scf.for lowering still carries a TODO for full zero-trip semantics because waveasm.loop is currently do-while shaped.
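The zero-trip gap can be shown with a small Python model of the two loop shapes. The function names are hypothetical, and the guarded variant is only one common way to close the gap; it is not something this PR implements.

```python
from typing import Callable, TypeVar

T = TypeVar("T")
Body = Callable[[int, T], T]


def for_loop(lb: int, ub: int, step: int, body: Body, init: T) -> T:
    """scf.for-style semantics: the condition is tested before every iteration."""
    acc, i = init, lb
    while i < ub:
        acc = body(i, acc)
        i += step
    return acc


def do_while_loop(lb: int, ub: int, step: int, body: Body, init: T) -> T:
    """Do-while shape: the body always runs once before the first test."""
    acc, i = init, lb
    while True:
        acc = body(i, acc)
        i += step
        if not i < ub:
            break
    return acc


def guarded_do_while(lb: int, ub: int, step: int, body: Body, init: T) -> T:
    """One possible fix (an assumption): skip the loop when it has zero trips."""
    if not lb < ub:
        return init
    return do_while_loop(lb, ub, step, body, init)
```

With `lb == ub` the do-while form still executes the body once, while `for_loop` and `guarded_do_while` both return `init` unchanged; for any loop with at least one trip, all three agree.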