[GPU] Allow avoiding scale swizzling in cuDNN BlockScaledDot (BlockScalingRewriter) for pre-swizzled weights #41201

@loupicaaa

Currently, to use the cuDNN block scaled dot kernel, the weight scales must be arranged in a specific scale_vec::1X layout. As an optimization, we would prefer to pre-swizzle these weight scales once at load time rather than swizzle them dynamically at runtime.

In xla/backends/gpu/transforms/block_scaling_rewriter.cc, the BuildCudnnScaledDotInput function applies an unconditional swizzle (Reshape -> Transpose -> Reshape) to both lhs_scale and rhs_scale:

  TF_ASSIGN_OR_RETURN(Shape scale_valid_shape, builder.GetShape(scale_op));
  int64_t scale_rows = scale_valid_shape.dimensions(rank - 2);
  int64_t scale_cols = scale_valid_shape.dimensions(rank - 1);
  scale_op =
      Reshape(scale_op, {batch_size, scale_rows / kInputNonContractingTileSize,
                         4, 32, scale_cols / kScaleContractingTileSize,
                         kScaleContractingTileSize});
  scale_op = Transpose(scale_op, {0, 1, 4, 3, 2, 5});
  scale_op = Reshape(scale_op, scale_valid_shape.dimensions());
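
For reference, here is a minimal host-side sketch of what the equivalent load-time pre-swizzle could look like. It follows the Reshape -> Transpose -> Reshape index mapping from the snippet above on a flat row-major buffer. `PreSwizzleScales` is a hypothetical helper, not an XLA function. The row tile of 128 follows from the `{4, 32}` split in the reshape; the default `k_tile = 4` is an assumed value for kScaleContractingTileSize, and one byte per scale is also an assumption.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of XLA): reorders a row-major
// [batch, rows, cols] scale buffer the same way as
// Reshape -> Transpose({0, 1, 4, 3, 2, 5}) -> Reshape.
// `k_tile` stands in for kScaleContractingTileSize; 4 is an assumed value.
std::vector<uint8_t> PreSwizzleScales(const std::vector<uint8_t>& scales,
                                      int64_t batch, int64_t rows,
                                      int64_t cols, int64_t k_tile = 4) {
  constexpr int64_t kRowTile = 4 * 32;  // implied by the {4, 32} split
  if (batch < 0 || k_tile <= 0 || rows % kRowTile != 0 ||
      cols % k_tile != 0 ||
      static_cast<int64_t>(scales.size()) != batch * rows * cols) {
    throw std::invalid_argument("scale shape must be padded to whole tiles");
  }
  const int64_t row_tiles = rows / kRowTile;
  const int64_t col_tiles = cols / k_tile;
  std::vector<uint8_t> out(scales.size());
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t row = 0; row < rows; ++row) {
      const int64_t r = row / kRowTile;  // row tile index
      const int64_t i = (row / 32) % 4;  // group of 32 rows inside the tile
      const int64_t j = row % 32;        // row inside the group
      for (int64_t col = 0; col < cols; ++col) {
        const int64_t c = col / k_tile;  // column tile index
        const int64_t k = col % k_tile;  // column inside the tile
        // After the transpose the dimension order is [b, r, c, j, i, k].
        const int64_t dst =
            ((((b * row_tiles + r) * col_tiles + c) * 32 + j) * 4 + i) *
                k_tile + k;
        out[dst] = scales[(b * rows + row) * cols + col];
      }
    }
  }
  return out;
}
```

Running something like this once when the weights are loaded is the work we would like the rewriter to stop repeating on every execution.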

We also cannot bypass this by using a specific layout, because GetCudnnMxType explicitly rejects any layout that is not monotonic with dim 0 major:

CudnnMxType GetCudnnMxType(const Shape& input_shape, const Shape& scale_shape,
                           std::optional<int64_t> block_size) {
  // Non-default layout is not supported.
  if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) ||
      !LayoutUtil::IsMonotonicWithDim0Major(scale_shape.layout())) {
    return CudnnMxType::UNSUPPORTED_TYPE;
  }

The compiler should provide a way to inform the BlockScalingRewriter that the scales (lhs_scale and/or rhs_scale) are already swizzled and that the transpose step in BuildCudnnScaledDotInput should be skipped.

I suggest two possible solutions to avoid the unconditional swizzling:

  • Add boolean flags to the BlockScaledDotBackendConfig (e.g., lhs_scale_pre_swizzled, rhs_scale_pre_swizzled).
  • Modify GetCudnnMxType to accept specific non-default XLA layouts that explicitly represent the scale_vec::1X swizzle pattern. BuildCudnnScaledDotInput could then conditionally skip the transpose.
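
To illustrate the first option, here is a rough standalone sketch of the per-operand decision. It deliberately does not use XLA types: `ScaleSwizzleConfig`, `MaybeSwizzle`, and `PrepareScales` are hypothetical names, and the two flags only mirror the field names suggested above (they do not exist in BlockScaledDotBackendConfig today). The swizzle itself is passed in as a callback, so the sketch only shows the conditional skip.

```cpp
#include <cstdint>
#include <functional>
#include <vector>

using Scales = std::vector<uint8_t>;
using SwizzleFn = std::function<Scales(const Scales&)>;

// Hypothetical mirror of the proposed backend-config flags.
// These fields are a suggestion from this issue, not an existing XLA API.
struct ScaleSwizzleConfig {
  bool lhs_scale_pre_swizzled = false;
  bool rhs_scale_pre_swizzled = false;
};

// Returns the scales unchanged when the caller declares them pre-swizzled;
// otherwise applies the supplied swizzle.
inline Scales MaybeSwizzle(const Scales& scales, bool pre_swizzled,
                           const SwizzleFn& swizzle) {
  return pre_swizzled ? scales : swizzle(scales);
}

struct PreparedScales {
  Scales lhs;
  Scales rhs;
};

// Makes the decision independently for each operand, so that only the
// weight-side scales need to be pre-swizzled at load time.
inline PreparedScales PrepareScales(const Scales& lhs, const Scales& rhs,
                                    const ScaleSwizzleConfig& config,
                                    const SwizzleFn& swizzle) {
  return {MaybeSwizzle(lhs, config.lhs_scale_pre_swizzled, swizzle),
          MaybeSwizzle(rhs, config.rhs_scale_pre_swizzled, swizzle)};
}
```

Keeping the flags separate per operand matters for our use case: activation scales would still be swizzled at runtime, while weight scales marked as pre-swizzled would pass through untouched.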

Metadata

Labels

GPU (XLA on GPU)
