Currently, to use the cuDNN block scaled dot kernel, the weight scales must be arranged in a specific scale_vec::1X layout. As an optimization, we would prefer to pre-swizzle these weight scales at load time rather than swizzle them dynamically at runtime.
In xla/backends/gpu/transforms/block_scaling_rewriter.cc, the BuildCudnnScaledDotInput function applies an unconditional swizzle (Reshape -> Transpose -> Reshape) to both lhs_scale and rhs_scale:
TF_ASSIGN_OR_RETURN(Shape scale_valid_shape, builder.GetShape(scale_op));
int64_t scale_rows = scale_valid_shape.dimensions(rank - 2);
int64_t scale_cols = scale_valid_shape.dimensions(rank - 1);
scale_op =
    Reshape(scale_op, {batch_size, scale_rows / kInputNonContractingTileSize,
                       4, 32, scale_cols / kScaleContractingTileSize,
                       kScaleContractingTileSize});
scale_op = Transpose(scale_op, {0, 1, 4, 3, 2, 5});
scale_op = Reshape(scale_op, scale_valid_shape.dimensions());
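For reference, here is a minimal host-side sketch of what pre-swizzling at load time could look like, mirroring the Reshape -> Transpose({0, 1, 4, 3, 2, 5}) -> Reshape sequence above on a flat row-major buffer. It is not XLA code. The function name PreSwizzleScales is hypothetical, the scales are assumed to be one byte each, and the tile sizes are assumptions: the row tile must be 4 * 32 = 128 for the reshape above to be valid, while the value 4 for kScaleContractingTileSize is a guess that should be checked against the rewriter source.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Assumed tile sizes (not taken from the XLA source; verify before use).
constexpr std::size_t kRowTile = 128;  // 4 * 32, implied by the reshape
constexpr std::size_t kColTile = 4;    // assumed kScaleContractingTileSize

// Hypothetical helper: reorders a row-major [batch, rows, cols] scale buffer
// the same way Reshape -> Transpose({0, 1, 4, 3, 2, 5}) -> Reshape would.
std::vector<std::uint8_t> PreSwizzleScales(
    const std::vector<std::uint8_t>& scales, std::size_t batch,
    std::size_t rows, std::size_t cols) {
  if (rows % kRowTile != 0 || cols % kColTile != 0 ||
      scales.size() != batch * rows * cols) {
    throw std::invalid_argument("scale shape is not tile-aligned");
  }
  std::vector<std::uint8_t> out(scales.size());
  const std::size_t col_tiles = cols / kColTile;
  for (std::size_t b = 0; b < batch; ++b) {
    const std::size_t base = b * rows * cols;
    for (std::size_t row = 0; row < rows; ++row) {
      // Split the row index into the reshaped dims {rows / 128, 4, 32}.
      const std::size_t row_tile = row / kRowTile;
      const std::size_t quarter = (row % kRowTile) / 32;
      const std::size_t lane = row % 32;
      for (std::size_t col = 0; col < cols; ++col) {
        // Split the column index into {cols / kColTile, kColTile}.
        const std::size_t col_tile = col / kColTile;
        const std::size_t inner = col % kColTile;
        // Transposed dim order: {row_tile, col_tile, lane, quarter, inner}.
        const std::size_t dst =
            (((row_tile * col_tiles + col_tile) * 32 + lane) * 4 + quarter) *
                kColTile + inner;
        out[base + dst] = scales[base + row * cols + col];
      }
    }
  }
  return out;
}
```

With something like this applied once when the weights are loaded, the only remaining piece is a way to tell the rewriter that the scales it receives are already in this order.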
We also cannot bypass this by giving the scales a specific layout, because GetCudnnMxType explicitly rejects any layout that is not monotonic with dim 0 major:
CudnnMxType GetCudnnMxType(const Shape& input_shape, const Shape& scale_shape,
                           std::optional<int64_t> block_size) {
  // Non-default layout is not supported.
  if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) ||
      !LayoutUtil::IsMonotonicWithDim0Major(scale_shape.layout())) {
    return CudnnMxType::UNSUPPORTED_TYPE;
  }
The compiler should provide a way to inform the BlockScalingRewriter that the scales (lhs_scale and/or rhs_scale) are already swizzled and that the transpose step in BuildCudnnScaledDotInput should be skipped.
I suggest two solutions to avoid the unconditional swizzling:
- Add boolean flags to the BlockScaledDotBackendConfig (e.g., lhs_scale_pre_swizzled, rhs_scale_pre_swizzled).
- Modify GetCudnnMxType to accept specific non-default XLA layouts that explicitly represent the scale_vec::1X swizzle pattern. BuildCudnnScaledDotInput could then conditionally skip the transpose.