Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 83 additions & 9 deletions vllm/model_executor/kernels/mhc/tilelang.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,87 @@ def mhc_pre_tilelang(
comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32
layer_input: shape (..., hidden_size), dtype torch.bfloat16
"""
import vllm._tilelang_ops as tilelang_ops
from vllm._tilelang_ops import (

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me check where would be a better place to put this.
Sorry for breaking this.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for fixing this.

compute_num_split,
mhc_pre_big_fuse_tilelang,
)
from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm
from vllm.utils.math_utils import cdiv

return tilelang_ops.mhc_pre(
residual,
assert residual.dtype == torch.bfloat16
assert fn.dtype == torch.float32
assert hc_scale.dtype == torch.float32
assert hc_base.dtype == torch.float32

hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
hc_mult2 = hc_mult * hc_mult
hc_mult3 = hc_mult * 2 + hc_mult2

hc_hidden_size = hc_mult * hidden_size
assert fn.shape[0] == hc_mult3
assert fn.shape[1] == hc_hidden_size
assert hc_scale.shape == (3,)
assert hc_base.shape == (hc_mult3,)

outer_shape = residual.shape[:-2]

residual_flat = residual.view(-1, hc_mult, hidden_size)
num_tokens = residual_flat.shape[0]

# These values are taken from the DeepGEMM kernel implementation.
block_k = 64
block_m = 64
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))

post_mix = torch.empty(
num_tokens, hc_mult, dtype=torch.float32, device=residual.device
)
comb_mix = torch.empty(
num_tokens, hc_mult2, dtype=torch.float32, device=residual.device
)
layer_input = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device
)

gemm_out_mul = torch.empty(
n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device
)
gemm_out_sqrsum = torch.empty(
n_splits, num_tokens, dtype=torch.float32, device=residual.device
)

tf32_hc_prenorm_gemm(
residual_flat.view(num_tokens, hc_mult * hidden_size),
fn,
gemm_out_mul,
gemm_out_sqrsum,
n_splits,
)

mhc_pre_big_fuse_tilelang(
gemm_out_mul,
gemm_out_sqrsum,
hc_scale,
hc_base,
residual_flat,
post_mix,
comb_mix,
layer_input,
hidden_size,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_mult_value,
sinkhorn_repeat,
n_splits,
hc_mult,
)

return (
post_mix.view(*outer_shape, hc_mult, 1),
comb_mix.view(*outer_shape, hc_mult, hc_mult),
layer_input.view(*outer_shape, hidden_size),
)


Expand Down Expand Up @@ -100,14 +168,19 @@ def mhc_post_tilelang(
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
) -> torch.Tensor:
import vllm._tilelang_ops as tilelang_ops
from vllm._tilelang_ops import mhc_post_tilelang as _mhc_post_kernel

return tilelang_ops.mhc_post(
x,
residual,
post_layer_mix,
out = torch.empty_like(residual)
_mhc_post_kernel(
comb_res_mix,
residual,
post_layer_mix.squeeze(-1),
x,
out,
residual.shape[-2],
residual.shape[-1],
)
return out


def mhc_fused_post_pre_tilelang(
Expand Down Expand Up @@ -137,11 +210,12 @@ def mhc_fused_post_pre_tilelang(
"""

from vllm._tilelang_ops import (
compute_num_split,
mhc_fused_tilelang,
mhc_post_tilelang,
mhc_pre_big_fuse_tilelang,
)
from vllm.utils.math import cdiv, compute_num_split
from vllm.utils.math_utils import cdiv

assert residual.dtype == torch.bfloat16
assert x.dtype == torch.bfloat16
Expand Down
3 changes: 0 additions & 3 deletions vllm/model_executor/layers/mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# this import will also register the custom ops
import vllm.model_executor.kernels.mhc as mhc_kernels
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform


@CustomOp.register("mhc_pre")
Expand Down Expand Up @@ -133,7 +132,6 @@ class HCHeadOp(CustomOp):
def enabled(cls) -> bool:
return True

@torch.compile(backend=current_platform.simple_compile_backend)
def forward_cuda(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -164,7 +162,6 @@ def forward_cuda(
)
return out.view(*outer_shape, hidden_size)

@torch.compile(backend=current_platform.simple_compile_backend)
def forward_hip(
self,
hidden_states: torch.Tensor,
Expand Down