From 5ff62ebdc1514f47096628ef956f12e05641f129 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Wed, 11 Jun 2025 13:21:40 +0800 Subject: [PATCH 01/26] [feat] add triton builder --- compiler/python/byteir/compile.py | 116 +++++++++++++ .../byteir/dialects/cat/ir_processor.py | 136 ++++++++++++++-- .../cat/ir_translator/backend/tit_registry.py | 39 +++++ .../dialects/cat/ir_translator/tit_builder.py | 152 ++++++++++++++++++ .../python/byteir/dialects/cat/tit_cache.py | 108 +++++++++++++ compiler/python/byteir/tools/compiler.py | 2 +- compiler/python/byteir/utils.py | 2 +- 7 files changed, 543 insertions(+), 12 deletions(-) create mode 100644 compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py create mode 100644 compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py create mode 100644 compiler/python/byteir/dialects/cat/tit_cache.py diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index f53c4c7f7..3b3dd6529 100755 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -287,6 +287,122 @@ def _compile_cuda_with_ait( raise ValueError("module asm has be changed after byre serialization") +@register_byteir_compiler_backend(target="cuda_with_triton", device="cuda") +def _compile_cuda_with_triton( + compile_options: CompileOptions, +) -> None: + from .dialects.cat import IRProcessor + + target = "cuda" + module = compile_options.module + entry_func = compile_options.entry_func + gpu_arch = compile_options.gpu_arch + verbose = compile_options.verbose + name = compile_options.name + enable_tf32 = compile_options.enable_tf32 + parallelism = compile_options.parallelism + disable_byteir_ait_cache = compile_options.disable_byteir_ait_cache + + output_file_dir = compile_options.output_dir + output_file_prefix = compile_options.output_file_prefix + output_type = compile_options.output_type + useBarePtrCallConv = True # all tensor must have static shapes if True + + context = module.context + + entry_func_str = "entry-func={}".format(entry_func) + target_str = "target={}".format(target) + + with context: + PassManager().parse("builtin.module(hlo-graph-opt{" + entry_func_str + " " + target_str + "})").run(module.operation) + _print_verbose(module, "// IR Dump After Hlo Graph Opt:") if verbose else ... + + processor = IRProcessor(name, + "./workspace", + enable_tf32=enable_tf32, + compile_parallelism=parallelism, + disable_byteir_ait_cache=disable_byteir_ait_cache, + verbose=verbose) + processor.module = module + + processor.preprocess_pass() + processor.cat_opt_pass(anchor_only=False) + + with context: + pm = PassManager().parse("builtin.module(hlo-fusion-opt{outline-single-elemwise-op outline-cat-op})") + pm.run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After Hlo Fusion Opt (with Cat):") if verbose else ... + print(processor.module) + # not generate ait lib .so for cat functions + processor.triton_opt_pass(output_file_dir) + module = processor.module + + with context: + PassManager.parse("builtin.module(linalg-tensor-opt)").run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After Linalg Tensor Opt:") if verbose else ... + with context: + if enable_tf32: + PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types enable-tf32 {}}})".format(entry_func_str)).run(processor.module.operation) + else: + PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types {}}})".format(entry_func_str)).run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After Byre Tensor Opt:") if verbose else ... + with context: + PassManager.parse("builtin.module(byteir-bufferize-opt)").run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After ByteIR Bufferize Opt:") if verbose else ... + with context: + PassManager.parse("builtin.module(linalg-memref-opt)").run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After Linalg Memref Opt:") if verbose else ... + with context: + PassManager.parse("builtin.module(scf-opt)").run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After SCF Opt:") if verbose else ... + with context: + if useBarePtrCallConv: + PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true device-file-name="+ output_file_prefix + ".ptx" + "})").run(module.operation) + else: + PassManager.parse("builtin.module(gpu-opt{device-file-name=" + output_file_prefix + ".ptx" + "})").run(module.operation) + _print_verbose(processor.module, "// IR Dump After GPU Opt:") if verbose else ... + with context: + PassManager.parse("builtin.module(inline)").run(processor.module.operation) + PassManager.parse("builtin.module(func.func(lccl-to-byre))").run(module.operation) + PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(processor.module.operation) + PassManager.parse("builtin.module(func.func(set-op-space{" + entry_func_str + " space={}".format(target) + "}))").run(processor.module.operation) + PassManager.parse("builtin.module(set-arg-space{" + entry_func_str + " all-space={}".format(target) + "})").run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After Set Space Opt:") if verbose else ... + with context: + PassManager.parse("builtin.module(byre-opt{append-arg-types " + entry_func_str + "})").run(processor.module.operation) + _print_verbose(processor.module, "// IR Dump After Byre Opt:") if verbose else ... + + # create device module + module_str = processor.module.operation.get_asm(print_generic_op_form=True) + device_module = ir.Module.parse(module_str, context) + with context: + if useBarePtrCallConv: + PassManager.parse("builtin.module(nvvm-codegen{use-bare-ptr-memref-call-conv=true " + f" gpu-arch={gpu_arch}" + "})").run(device_module.operation) + else: + PassManager.parse("builtin.module(nvvm-codegen{" + f" gpu-arch= {gpu_arch}" + "})").run(device_module.operation) + _print_verbose(device_module, "// IR Dump After NVVM Codegen:") if verbose else ... + # write to output device ptx + byteir.translate_to_ptx(device_module, output_file_dir + "/" + output_file_prefix, gpu_arch) + + # create host module + with context: + PassManager.parse("builtin.module(byre-host)").run(processor.module.operation) + PassManager.parse("builtin.module(remove-module-tag{attr-name=gpu.container_module})").run(module.operation) + PassManager.parse("builtin.module(remove-module-tag{attr-name=torch.debug_module_name})").run(module.operation) + _print_verbose(processor.module, "// IR Dump After Byre Host:") if verbose else ... + + output_host_mlir_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIR.value) + output_host_mlirbc_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIRBC.value) + # write to output host mlir file + with open(output_host_mlir_path, "w") as f: + f.write(module.operation.get_asm()) + if output_type is OutputType.MLIRBC: + byteir.serialize_byre(module, compile_options.byre_serial_version, output_host_mlirbc_path) + deserialized_module = byteir.deserialize_byre(open(output_host_mlirbc_path, "rb").read(), context) + if (module.operation.get_asm() != deserialized_module.operation.get_asm()): + raise ValueError("module asm has be changed after byre serialization") + + @register_byteir_compiler_backend(target="cpu", device="cpu") def _compile_cpu( compile_options: CompileOptions, diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index 23fbb7133..cf3387d39 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -4,6 +4,7 @@ from byteir.utils import get_gpu_type from .ait_cache import AITCache +from .tit_cache import TITCache from pathlib import Path from shutil import copyfile, copymode @@ -55,7 +56,8 @@ def __init__(self, enable_tf32 = False, compile_parallelism = 1, disable_byteir_ait_cache = False, - verbose = False): + verbose = False, + device = "cuda"): self.job_name = job_name self.workdir = workdir self.module = None @@ -65,11 +67,10 @@ def __init__(self, self.pool = multiprocessing.Pool(compile_parallelism) else: self.pool = None - self.byteir_cache = AITCache() self.verbose = verbose - self.disable_byteir_ait_cache = disable_byteir_ait_cache - if not disable_byteir_ait_cache: - self.byteir_cache.load_or_create_cache() + #TODO: signature rename + self.disable_byteir_cache = disable_byteir_ait_cache + self.device = device def _get_builder(self, func, subgraph_name, backend="ait"): assert func != None @@ -77,6 +78,9 @@ def _get_builder(self, func, subgraph_name, backend="ait"): if backend == "ait": from byteir.dialects.cat.ir_translator.ait_builder import AITBuilder return AITBuilder(func, workdir=self.workdir, subgraph_name=subgraph_name, enable_tf32=self.enable_tf32) + elif backend == "triton": + from byteir.dialects.cat.ir_translator.tit_builder import TRITONTBuilder + return TRITONTBuilder(func, workdir=self.workdir, subgraph_name=subgraph_name, enable_tf32=self.enable_tf32, device=self.device) else: raise RuntimeError(f"Unsupported runtime backend {backend}") @@ -100,6 +104,9 @@ def cat_opt_pass(self, anchor_only=False): return self.module def ait_opt_pass(self, output_dir): + self.byteir_cache = AITCache() + if not self.disable_byteir_cache: + self.byteir_cache.load_or_create_cache() funcNameArg = [] aitLibPathArg = [] @@ -141,7 +148,7 @@ def ait_opt_pass(self, output_dir): print("compilation finished in {}s".format(t_ed-t_st)) # update byteir cache - if not self.disable_byteir_ait_cache: + if not self.disable_byteir_cache: for key, lib_path in libs_to_add_to_cache.items(): self.byteir_cache.add(gpu_type, key, lib_path, override=False) self.byteir_cache._save() @@ -153,6 +160,99 @@ def ait_opt_pass(self, output_dir): _print_verbose(self.module, "// IR Dump After Gen AIT Config:") if self.verbose else ... return self.module + + def triton_opt_pass(self, output_dir): + + def decouple_triton_args(triton_args): + func_name_args = [] + ptx_path_args = [] + gridsize_x_args = [] + gridsize_y_args = [] + gridsize_z_args = [] + blocksize_x_args = [] + blocksize_y_args = [] + blocksize_z_args = [] + for func_name,ptx_path,gridsize,blocksize in triton_args: + func_name_args.append(func_name) + ptx_path_args.append(ptx_path) + gridsize_x_args.append(str(gridsize[0])) + gridsize_y_args.append(str(gridsize[1])) + gridsize_z_args.append(str(gridsize[2])) + #TODO: blocksize is 1d for now, need to check + blocksize_x_args.append(str(blocksize)) + blocksize_y_args.append(str(1)) + blocksize_z_args.append(str(1)) + return func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args + + self.pool=None + + self.byteir_cache = TITCache() + if not self.disable_byteir_cache: + self.byteir_cache.load_or_create_cache() + triton_args = [] + + gpu_type = get_gpu_type() + if gpu_type == None: + raise RuntimeError("No gpu found in this machine! cannot perform triton-opt-pass") + work_items = [] # work items of FuncOp + + for func in self.module.body.operations: + if BYTEIR_CAT_ATTR not in func.attributes: + continue + output_ptx_path = os.path.join(output_dir, func.name.value + ".ptx") + hash_str = func_hash_str(func, gpu_type) + # TODO: gridsize order need to be checked + # gridsize form (x,y,z), blocksize form (x,y,z) + cached_argv = self.byteir_cache.find(gpu_type, hash_str) + if cached_argv: + cache_ptx,gridsize,blocksize = cached_argv + print(f"func {func.name.value} cache hit") + copyfile(cache_ptx, output_ptx_path) + copymode(cache_ptx, output_ptx_path) + triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize)) + else: + work_items.append(func) + + # compile and benchmark + print("compile triton module using {} processes".format(min(len(work_items), self.compile_parallelism))) + print("\n".join([str(func) for func in work_items])) + t_st = time.time() + + new_args = [] + for func in work_items: + output_ptx_path = os.path.join(output_dir, func.name.value + ".ptx") + if self.pool: + new_args.append(self.pool.apply_async(_parallel_tit_compile, + (self.workdir, func, output_ptx_path, self.enable_tf32))) + else: + new_args.append(_parallel_tit_compile(self.workdir, func, output_ptx_path, self.enable_tf32)) + + if self.pool: + self.pool.close() + self.pool.join() + + for func,output_ptx_path,gridsize,blocksize in new_args: + triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize)) + self.byteir_cache.load_or_create_cache() + self.byteir_cache.add(gpu_type, func_hash_str(func, gpu_type), (output_ptx_path, gridsize, blocksize), override=False) + self.byteir_cache._save() + self.byteir_cache.close_cache() + + t_ed = time.time() + print("compilation finished in {}s".format(t_ed-t_st)) + + func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args = decouple_triton_args(triton_args) + + with self.module.context: + pm_str="builtin.module(func.func(gen-tit-config{{func-names={} tit-ptx-paths={} gridsize-x-args={} gridsize-y-args={} gridsize-z-args={} blocksize-x-args={} blocksize-y-args={} blocksize-z-args={}}}))".format(",".join(func_name_args), ",".join(ptx_path_args), ",".join(gridsize_x_args), ",".join(gridsize_y_args), ",".join(gridsize_z_args), ",".join(blocksize_x_args), ",".join(blocksize_y_args), ",".join(blocksize_z_args)) + print(pm_str) + pm = PassManager.parse(pm_str) + exit + pm.run(self.module.operation) + _print_verbose(self.module, "// IR Dump After Gen AIT Config:") if self.verbose else ... + + + return self.module def execute(self, inputs, func_name=None, backend="ait"): if func_name is None: @@ -177,10 +277,26 @@ def profile(self, backend="ait"): def _parallel_ait_compile(workdir: str, func: FuncOp, output_lib_path, enable_tf32): + + def touch_blank_file(file_path): + with open(file_path, 'w') as f: + pass # os.environ["CUDA_VISIBLE_DEVICES"]=str(os.getpid() % available_cuda_device_num) from byteir.dialects.cat.ir_translator.ait_builder import AITBuilder - builder = AITBuilder(func, workdir=workdir, subgraph_name=func.name.value, enable_tf32=enable_tf32) + # builder = AITBuilder(func, workdir=workdir, subgraph_name=func.name.value, enable_tf32=enable_tf32) + # builder.compile() + # builder.benchmark() + # copyfile(builder.ait_module_path, output_lib_path) + # copymode(builder.ait_module_path, output_lib_path) + touch_blank_file(output_lib_path) + +def _parallel_tit_compile(workdir: str, func: FuncOp, output_ptx_path, enable_tf32): + + # os.environ["CUDA_VISIBLE_DEVICES"]=str(os.getpid() % available_cuda_device_num) + from byteir.dialects.cat.ir_translator.tit_builder import TITBuilder + builder = TITBuilder(func, workdir=workdir, subgraph_name=func.name.value, enable_tf32=enable_tf32) builder.compile() - builder.benchmark() - copyfile(builder.ait_module_path, output_lib_path) - copymode(builder.ait_module_path, output_lib_path) + blockSize,gridsize=builder.blocksize,builder.gridsize + copyfile(builder.tit_module_path, output_ptx_path) + copymode(builder.tit_module_path, output_ptx_path) + return func,output_ptx_path,gridsize,blockSize \ No newline at end of file diff --git a/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py new file mode 100644 index 000000000..da3746312 --- /dev/null +++ b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py @@ -0,0 +1,39 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from tritontemplate.compiler.base import Tensor +import tritontemplate.compiler.ops as tit_ops + +from ..translator import IRTranslator +from byteir import ir +from byteir.utils import mlir_attr_to_pyobj, mlir_type_to_torch_str + +class TRITONTemplateIRTranslator(IRTranslator): + pass + +@TRITONTemplateIRTranslator.register("mhlo.constant") +def _dispatch_mhlo_constant(op, inputs): + shaped_type = ir.ShapedType(op.result.type) + shape = shaped_type.shape + output = Tensor(shape, dtype=mlir_type_to_torch_str(shaped_type.element_type)) + return [output] + +@TRITONTemplateIRTranslator.register("cat.gemm_rcr_bias_relu") +def _dispatch_cat_gemm_rcr_bias_relu(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rcr", is_bias=True, activation="relu") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.gemm_rcr_bias") +def _dispatch_cat_gemm_rcr_bias(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rcr", is_bias=True) + return [Y] \ No newline at end of file diff --git a/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py new file mode 100644 index 000000000..43d37e522 --- /dev/null +++ b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py @@ -0,0 +1,152 @@ +# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import random +import torch +import os +import numpy as np +from typing import Set + +from tritontemplate.compiler.base import Tensor, IntImm +from tritontemplate.compiler import compile_kernel + +from .backend.tit_registry import * + +from byteir.utils import mlir_type_to_torch_str, torch_dtype_from_str + +class TITBuilder: + # stores a graph + # use stores mlir.Value to ait.Tensor map + _value2tensor = None + _op2parent_block = None + _im_vals = None + + def __init__(self, func, workdir="./workspace", subgraph_name="model", enable_tf32=False,device="cuda"): + self.func = func + self._value2tensor = {} + self._op2parent_block = {} + self._im_vals = set() + self.inputs : list[Tensor] = [] + self.outputs : list[Tensor] = [] + + self.tit_module_path = None + self.tit_model = None + + self.subgraph_name = subgraph_name + self.ptx_name = subgraph_name + ".ptx" + self.workdir = workdir + self.enable_tf32 = enable_tf32 + self.test_name = "./" + subgraph_name + self.constants = {} + self.constant_idx = 0 + + self.device = device + + + # init arguments + for idx,i in enumerate(self.func.arguments): + shaped_type = ir.ShapedType(i.type) + shape = shaped_type.shape + dtype = mlir_type_to_torch_str(shaped_type.element_type) + self._value2tensor[i] = Tensor(shape=shape, dtype=dtype, name=f"input_tensor_{idx}") + self.inputs.append(self._value2tensor[i]) + + # Note: variable `lib_path` is constructed according to the code in compile_model() + os.makedirs(os.path.join(self.workdir, self.test_name), exist_ok=True) + ptx_path = os.path.join(self.workdir, self.test_name, self.ptx_name) + print("TIT module path {} for {}".format(ptx_path, self.ptx_name)) + self.tit_module_path = ptx_path + + def compile(self): + self._visit_block(self.func.entry_block) + assert self.tit_module_path is not None + assert self.tit_kernel is not None + + def _visit_op(self, op): + # analyze here + # call bt APIs to create tensor & op + inputs = list(self._lookup_tensor(i) for i in op.operands) + if op.operation.name == "func.return": + self._gen_tit_kernel(inputs) + return + if hasattr(op, "operands"): + outputs = TRITONTemplateIRTranslator.translate(op, inputs) + if op.operation.name == "mhlo.constant": + # TODO: need to support FP16 + shaped_type = ir.ShapedType(op.result.type) + outputs[0]._attrs["name"] = f"const_tensor_{self.constant_idx}" + np_array = mlir_attr_to_pyobj(op.attributes["value"]) + data = torch.from_numpy(np_array).contiguous().cuda() + data = data.to(torch_dtype_from_str(mlir_type_to_torch_str(shaped_type.element_type))) + self.constants[outputs[0]._attrs["name"]] = data + self.constant_idx += 1 + if op.operation.name != "mhlo.constant": + for value in op.results: + self._im_vals.add(value) + for output, value in zip(outputs, op.results): + self._value2tensor[value] = output + + for region in op.operation.regions: + for block in region.blocks: + self._visit_block(block) + + def _visit_block(self, block): + for i in block.operations: + self._op2parent_block[i] = block + self._visit_op(i) + + def _lookup_tensor(self, val): + # return a bt.graph.Tensor + assert val in self._value2tensor + return self._value2tensor[val] + + def _gen_tit_kernel(self, results): + idx = 0 + for out in results: + out._attrs["name"] = f"output_tensor_{idx}" + idx += 1 + assert len(results) == 1, "only support single cat op" + result=results[0] + self.tit_kernel = compile_kernel( + op=result, + device=self.device, + workdir=self.workdir, + ) + # kernel rename + with open(self.tit_module_path, "w") as f: + f.write(self.tit_kernel.kernel_ptx(self.subgraph_name)) + self.gridsize = self.tit_kernel.gridsize + self.blocksize = self.tit_kernel.blocksize + + def _gen_runtime_tensor(self, tensor: Tensor): + shape = tensor.shape() + rt_shape = [] + for s in shape: + if isinstance(s, IntImm): + rt_shape.append(s.value()) + dtype = torch_dtype_from_str(tensor.dtype()) + if len(rt_shape) == 0: + return torch.tensor(1).to(torch_dtype_from_str(tensor.dtype())).cuda() + if dtype == torch.bool: + return torch.randint(high=1, size=rt_shape, dtype=dtype, device="cuda") + elif dtype in [torch.int8, torch.int, torch.int16, torch.int32, torch.int64]: + return torch.randint(high=100, size=rt_shape, dtype=dtype, device="cuda") + else: + return torch.randn(*rt_shape, device="cuda").to(torch_dtype_from_str(tensor.dtype())) + + def execute(self, np_inputs, num_trials=1, benchmark=False): + raise NotImplementedError("TITBuilder.execute() is not implemented yet") + + def benchmark(self, num_trials=5): + raise NotImplementedError("TITBuilder.benchmark() is not implemented yet") diff --git a/compiler/python/byteir/dialects/cat/tit_cache.py b/compiler/python/byteir/dialects/cat/tit_cache.py new file mode 100644 index 000000000..0bdb8f4e4 --- /dev/null +++ b/compiler/python/byteir/dialects/cat/tit_cache.py @@ -0,0 +1,108 @@ +import os +import re +import json +import fcntl +import tempfile + +from shutil import copyfile, copymode + +HOME_DIR = os.getenv("HOME") +if HOME_DIR is None: + HOME_DIR = tempfile.gettempdir() +DEFAULT_CACHE_DIR = os.path.join(HOME_DIR, ".byteir", "tit_cache") +CACHE_FILE_NAME = "tit_global_cache.json" +IDX_KEY = "byteir_tit_cache_auto_increment_idx" + +class TITCache: + def __init__(self, cache_dir = DEFAULT_CACHE_DIR) -> None: + self.idx = 0 # unique id of saved compiled .so + self.cache_dir = cache_dir + self.cache = { IDX_KEY : self.idx } # key: tit op hash str, value: relative path of compiled .so + self.fp = None + + def _open(self): + if not os.path.exists(os.path.join(self.cache_dir, CACHE_FILE_NAME)): + self.fp = open(os.path.join(self.cache_dir, CACHE_FILE_NAME), "w") + fcntl.flock(self.fp.fileno(), fcntl.LOCK_EX) + self.cache = { IDX_KEY : self.idx } + json.dump(self.cache, self.fp, indent=2) + fcntl.flock(self.fp.fileno(), fcntl.LOCK_UN) + self.fp.close() + self.fp = open(os.path.join(self.cache_dir, CACHE_FILE_NAME), "r+") + # print("try to acquire file lock...") + # acquire lock + fcntl.flock(self.fp.fileno(), fcntl.LOCK_EX) + # print("file lock acquired.") + + def _close(self): + # release lock + fcntl.flock(self.fp.fileno(), fcntl.LOCK_UN) + self.fp.close() + + def _load(self): + try: + self.cache = json.load(self.fp) + except json.decoder.JSONDecodeError: + self.cache = { IDX_KEY : self.idx } + + def _save(self): + self.fp.seek(0) + json.dump(self.cache, self.fp, indent=2) + + def sync_cache(self): + # sync cache with the files in cache dir + # remove invalid key-value pairs + max_idx = 0 + for gpu_type in self.cache: + if gpu_type == IDX_KEY: + continue + keys_to_check = list(self.cache[gpu_type].keys()) + for key in keys_to_check: + value = self.cache[gpu_type][key] + if not os.path.exists(os.path.join(self.cache_dir, value[0])): + self.cache[gpu_type].pop(key) + continue + if self.get_ptx_idx(value[0]) > max_idx: + max_idx = self.get_ptx_idx(value[0]) + self.idx = max_idx + 1 + self.cache[IDX_KEY] = self.idx + + def load_or_create_cache(self): + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir, exist_ok=True) + self._open() + self._load() + self.sync_cache() + + def add(self, gpu_type, key, argv, override = False): + if gpu_type not in self.cache: + self.cache[gpu_type] = {} + if override or key not in self.cache[gpu_type]: + value = "{:0>16}.ptx".format(self.idx) + self.idx += 1 + self.cache[IDX_KEY] = self.idx + self.cache[gpu_type][key] = (value, argv[1], argv[2]) + ptx_path = argv[0] + copyfile(ptx_path, os.path.join(self.cache_dir, value)) + copymode(ptx_path, os.path.join(self.cache_dir, value)) + + def find(self, gpu_type, key): + if gpu_type not in self.cache: + return None + if key in self.cache[gpu_type]: + lib_path, grid_size, block_size = self.cache[gpu_type][key] + return os.path.join(self.cache_dir, lib_path), grid_size, block_size + else: + return None + + def get_ptx_idx(self, ptx_name): + try: + idx=re.search(r'(\d+)\.ptx', ptx_name).group(1) + return int(idx) + except: + print("invalid ptx name: {}".format(ptx_name)) + raise RuntimeError("invalid ptx name") + + def close_cache(self): + if self.fp != None: + self._close() \ No newline at end of file diff --git a/compiler/python/byteir/tools/compiler.py b/compiler/python/byteir/tools/compiler.py index 1c01e06bd..03ee2e7c3 100644 --- a/compiler/python/byteir/tools/compiler.py +++ b/compiler/python/byteir/tools/compiler.py @@ -11,7 +11,7 @@ parser.add_argument("--target", type=str, default="cuda", - choices=["cuda", "cuda_with_ait", "cpu"], + choices=["cuda", "cuda_with_ait", "cuda_with_triton", "cpu"], help="target device name") parser.add_argument("--gpu_arch", type=str, diff --git a/compiler/python/byteir/utils.py b/compiler/python/byteir/utils.py index 8d5dc1e4e..715779891 100644 --- a/compiler/python/byteir/utils.py +++ b/compiler/python/byteir/utils.py @@ -164,7 +164,7 @@ def detect_gpu_arch_with_nvidia_smi(): "sm_70": ["V100"], "sm_75": ["T4", "Quadro T2000"], "sm_80": ["PG509", "A100", "A10", "RTX 30", "A30", "RTX 40", "A16"], - "sm_90": ["H100"], + "sm_90": ["H100","H800"], } for sm, names in sm_names.items(): if any(name in stdout for name in names): From 24309688c3a6273a77e7e48368b0ffa2c777ab1c Mon Sep 17 00:00:00 2001 From: liushanghao Date: Wed, 11 Jun 2025 15:44:54 +0800 Subject: [PATCH 02/26] [feat] add tit attr --- compiler/include/byteir/Conversion/Passes.h | 1 + compiler/include/byteir/Conversion/Passes.td | 27 +++++ .../include/byteir/Conversion/ToTIT/ToTIT.h | 45 ++++++++ compiler/lib/Conversion/CMakeLists.txt | 1 + compiler/lib/Conversion/ToTIT/CMakeLists.txt | 14 +++ .../lib/Conversion/ToTIT/GenTITConfig.cpp | 105 ++++++++++++++++++ .../byteir/dialects/cat/ir_processor.py | 2 +- 7 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 compiler/include/byteir/Conversion/ToTIT/ToTIT.h create mode 100644 compiler/lib/Conversion/ToTIT/CMakeLists.txt create mode 100644 compiler/lib/Conversion/ToTIT/GenTITConfig.cpp diff --git a/compiler/include/byteir/Conversion/Passes.h b/compiler/include/byteir/Conversion/Passes.h index 1bf88fff0..e3d3f11cf 100644 --- a/compiler/include/byteir/Conversion/Passes.h +++ b/compiler/include/byteir/Conversion/Passes.h @@ -28,6 +28,7 @@ #include "byteir/Conversion/LcclToByre/LcclToByre.h" #include "byteir/Conversion/MemrefToByre/MemrefToByre.h" #include "byteir/Conversion/ToAIT/ToAIT.h" +#include "byteir/Conversion/ToTIT/ToTIT.h" #include "byteir/Conversion/ToAce/MhloToAce.h" #include "byteir/Conversion/ToByre/ToByre.h" #include "byteir/Conversion/ToGPU/ToGPU.h" diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td index 92269e369..b0ff13397 100755 --- a/compiler/include/byteir/Conversion/Passes.td +++ b/compiler/include/byteir/Conversion/Passes.td @@ -155,6 +155,33 @@ def GenAITConfig : Pass<"gen-ait-config", "func::FuncOp"> { ]; } +//===----------------------------------------------------------------------===// +// ToTIT +//===----------------------------------------------------------------------===// + +def GenTITConfig : Pass<"gen-tit-config", "func::FuncOp"> { + let summary = "Generate TIT configuration"; + let constructor = "mlir::createGenTITConfigPass()"; + let options = [ + ListOption<"funcNames", "func-names", "std::string", + "names of all cat func for TIT backends.">, + ListOption<"titPtxPaths", "tit-ptx-paths", "std::string", + "paths to all TIT-generated .ptx files">, + ListOption<"gridsizeXArgs", "gridsize-x-args", "std::string", + "gridsize x args for TIT backends.">, + ListOption<"gridsizeYArgs", "gridsize-y-args", "std::string", + "gridsize y args for TIT backends.">, + ListOption<"gridsizeZArgs", "gridsize-z-args", "std::string", + "gridsize z args for TIT backends.">, + ListOption<"blocksizeXArgs", "blocksize-x-args", "std::string", + "blocksize x args for TIT backends.">, + ListOption<"blocksizeYArgs", "blocksize-y-args", "std::string", + "blocksize y args for TIT backends.">, + ListOption<"blocksizeZArgs", "blocksize-z-args", "std::string", + "blocksize z args for TIT backends.">, + ]; +} + //===----------------------------------------------------------------------===// // ToByre //===----------------------------------------------------------------------===// diff --git a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h new file mode 100644 index 000000000..a2199b078 --- /dev/null +++ b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h @@ -0,0 +1,45 @@ +//===- ToTIT.h ------------------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#ifndef BYTEIR_CONVERSION_TOTIT_H +#define BYTEIR_CONVERSION_TOTIT_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func +class ModuleOp; + +constexpr StringRef getByteIRTITOpKernelName() { return "TITOp"; } + +std::unique_ptr> +createGenTITConfigPass(ArrayRef funcNames = {""}, + ArrayRef titPtxPaths = {""}, + ArrayRef gridsizeXArgs = {""}, + ArrayRef gridsizeYArgs = {""}, + ArrayRef gridsizeZArgs = {""}, + ArrayRef blocksizeXArgs = {""}, + ArrayRef blocksizeYArgs = {""}, + ArrayRef blocksizeZArgs = {""} + ); + +} // namespace mlir + +#endif // BYTEIR_CONVERSION_TOTIT_H diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index a47af1137..4d6f77fa1 100755 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -7,6 +7,7 @@ add_subdirectory(HloToTensor) add_subdirectory(MemrefToByre) add_subdirectory(ToAce) add_subdirectory(ToAIT) +add_subdirectory(ToTIT) add_subdirectory(ToByre) add_subdirectory(ToGPU) add_subdirectory(ToHlo) diff --git a/compiler/lib/Conversion/ToTIT/CMakeLists.txt b/compiler/lib/Conversion/ToTIT/CMakeLists.txt new file mode 100644 index 000000000..2309c1fc3 --- /dev/null +++ b/compiler/lib/Conversion/ToTIT/CMakeLists.txt @@ -0,0 +1,14 @@ +add_byteir_conversion_library(ByteIRToTIT + GenTITConfig.cpp + + ADDITIONAL_HEADER_DIRS + ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Conversion/ToTIT + + DEPENDS + ByteIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRBufferizationTransforms + ByteIRUtils + ) diff --git a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp new file mode 100644 index 000000000..98b7cad90 --- /dev/null +++ b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp @@ -0,0 +1,105 @@ +//===- GenTITConfig.cpp ---------------------------------------*--- C++ -*-===// +// +// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +#include "byteir/Conversion/ToTIT/ToTIT.h" +#include "byteir/Dialect/mhlo/Transforms/HloFuser.h" +#include "byteir/Utils/FuncUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/Support/Debug.h" + +#include "../PassDetail.h" + +using namespace mlir; + +namespace { +static void AttachTITConfigToAttr(func::FuncOp func, + const std::string &titPtxPath, + const std::string &gridsizeXArg, + const std::string &gridsizeYArg, + const std::string &gridsizeZArg, + const std::string &blocksizeXArg, + const std::string &blocksizeYArg, + const std::string &blocksizeZArg + ) { + addGenericFuncAttrs(func, getByteIRTITOpKernelName().str()); + + mlir::OpBuilder opBuilder(func); + func->setAttr("tit_ptx_file", + opBuilder.getStringAttr(titPtxPath)); + func->setAttr("gridsize_x", + opBuilder.getStringAttr(gridsizeXArg)); + func->setAttr("gridsize_y", + opBuilder.getStringAttr(gridsizeYArg)); + func->setAttr("gridsize_z", + opBuilder.getStringAttr(gridsizeZArg)); + func->setAttr("blocksize_x", + opBuilder.getStringAttr(blocksizeXArg)); + func->setAttr("blocksize_y", + opBuilder.getStringAttr(blocksizeYArg)); + func->setAttr("blocksize_z", + opBuilder.getStringAttr(blocksizeZArg)); +} + +struct GenTITConfigPass : public GenTITConfigBase { + GenTITConfigPass( + ArrayRef funcNames, + ArrayRef titPtxPaths, + ArrayRef gridsizeXArgs, + ArrayRef gridsizeYArgs, + ArrayRef gridsizeZArgs, + ArrayRef blocksizeXArgs, + ArrayRef blocksizeYArgs, + ArrayRef blocksizeZArgs + ) + : GenTITConfigBase() { + this->funcNames = funcNames; + this->titPtxPaths = titPtxPaths; + this->gridsizeXArgs = gridsizeXArgs; + this->gridsizeYArgs = gridsizeYArgs; + this->gridsizeZArgs = gridsizeZArgs; + this->blocksizeXArgs = blocksizeXArgs; + this->blocksizeYArgs = blocksizeYArgs; + this->blocksizeZArgs = blocksizeZArgs; + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + if (!func->hasAttr(getByteIRCatFusionAttrName())) + return; + for (size_t i = 0; i < funcNames.size(); ++i) + if (func.getSymName() == funcNames[i]) + AttachTITConfigToAttr(func, titPtxPaths[i], gridsizeXArgs[i], gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i], blocksizeYArgs[i], blocksizeZArgs[i]); + } +}; + +} // namespace + +std::unique_ptr> +mlir::createGenTITConfigPass( + ArrayRef funcNames, + ArrayRef titPtxPaths, + ArrayRef gridsizeXArgs, + ArrayRef gridsizeYArgs, + ArrayRef gridsizeZArgs, + ArrayRef blocksizeXArgs, + ArrayRef blocksizeYArgs, + ArrayRef blocksizeZArgs +) { + return std::make_unique(funcNames, titPtxPaths, gridsizeXArgs, gridsizeYArgs, gridsizeZArgs, blocksizeXArgs, blocksizeYArgs, blocksizeZArgs); +} diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index cf3387d39..722770d67 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -249,7 +249,7 @@ def decouple_triton_args(triton_args): pm = PassManager.parse(pm_str) exit pm.run(self.module.operation) - _print_verbose(self.module, "// IR Dump After Gen AIT Config:") if self.verbose else ... + _print_verbose(self.module, "// IR Dump After Gen TIT Config:") if self.verbose else ... return self.module From 8ae5ae3dac5406c6dd6bea4d849f8b4bc3b8090c Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 12 Jun 2025 13:53:44 +0800 Subject: [PATCH 03/26] [feat] tit demo --- .../lib/Conversion/ToTIT/GenTITConfig.cpp | 85 +++++++++++++------ .../byteir/dialects/cat/ir_processor.py | 3 +- 2 files changed, 62 insertions(+), 26 deletions(-) diff --git a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp index 98b7cad90..149694e31 100644 --- a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp +++ b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp @@ -16,6 +16,7 @@ //===----------------------------------------------------------------------===// #include "byteir/Conversion/ToTIT/ToTIT.h" +#include "byteir/Dialect/Byre/Common.h" #include "byteir/Dialect/mhlo/Transforms/HloFuser.h" #include "byteir/Utils/FuncUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -28,32 +29,61 @@ using namespace mlir; namespace { -static void AttachTITConfigToAttr(func::FuncOp func, - const std::string &titPtxPath, - const std::string &gridsizeXArg, - const std::string &gridsizeYArg, - const std::string &gridsizeZArg, - const std::string &blocksizeXArg, - const std::string &blocksizeYArg, - const std::string &blocksizeZArg - ) { + +static LogicalResult AttachTITConfigToAttr( + func::FuncOp func, + const std::string &titPtxPath, + const std::string &gridsizeXArg, + const std::string &gridsizeYArg, + const std::string &gridsizeZArg, + const std::string &blocksizeXArg, + const std::string &blocksizeYArg, + const std::string &blocksizeZArg) { + addGenericFuncAttrs(func, getByteIRTITOpKernelName().str()); + + std::string device_name; + if (titPtxPath.find(".ptx") != std::string::npos) { + device_name = "cuda"; + } + + if (device_name.empty()) { + return func.emitError("Invalid device type for TIT configuration"); + } mlir::OpBuilder opBuilder(func); - func->setAttr("tit_ptx_file", - opBuilder.getStringAttr(titPtxPath)); - func->setAttr("gridsize_x", - opBuilder.getStringAttr(gridsizeXArg)); - func->setAttr("gridsize_y", - opBuilder.getStringAttr(gridsizeYArg)); - func->setAttr("gridsize_z", - opBuilder.getStringAttr(gridsizeZArg)); - func->setAttr("blocksize_x", - opBuilder.getStringAttr(blocksizeXArg)); - func->setAttr("blocksize_y", - opBuilder.getStringAttr(blocksizeYArg)); - func->setAttr("blocksize_z", - opBuilder.getStringAttr(blocksizeZArg)); + llvm::StringMap titConfig; + + // Attach the Byre Tensor Info + titConfig["call_convention"] = opBuilder.getStringAttr("bare_ptr"); + titConfig["device"] = opBuilder.getStringAttr(device_name); + titConfig["device_file_name"] = opBuilder.getStringAttr(titPtxPath); + + + llvm::StringMap gpuLaunchArgs = { + {"BlockSize.x", blocksizeXArg}, + {"BlockSize.y", blocksizeYArg}, + {"BlockSize.z", blocksizeZArg}, + {"GridSize.x", gridsizeXArg}, + {"GridSize.y", gridsizeYArg}, + {"GridSize.z", gridsizeZArg}}; + + for (auto &kv : gpuLaunchArgs) { + int val; + if (kv.second.getAsInteger(0, val)) { + return func.emitError("Invalid integer format for ") << kv.first(); + } + if (val <= 0) { + return func.emitError("Value must be positive for ") << kv.first(); + } + titConfig[kv.first()] = opBuilder.getI32IntegerAttr(val); + } + + for (auto &kv : titConfig) { + func->setAttr(byre::getByrePrefix() + kv.first().str(), kv.second); + } + + return success(); } struct GenTITConfigPass : public GenTITConfigBase { @@ -83,8 +113,13 @@ struct GenTITConfigPass : public GenTITConfigBase { if (!func->hasAttr(getByteIRCatFusionAttrName())) return; for (size_t i = 0; i < funcNames.size(); ++i) - if (func.getSymName() == funcNames[i]) - AttachTITConfigToAttr(func, titPtxPaths[i], gridsizeXArgs[i], gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i], blocksizeYArgs[i], blocksizeZArgs[i]); + if (func.getSymName() == funcNames[i]) { + if (failed(AttachTITConfigToAttr(func, titPtxPaths[i], gridsizeXArgs[i], + gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i], + blocksizeYArgs[i], blocksizeZArgs[i]))) { + return signalPassFailure(); + } + } } }; diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index 722770d67..6991a7c0e 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -242,7 +242,8 @@ def decouple_triton_args(triton_args): print("compilation finished in {}s".format(t_ed-t_st)) func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args = decouple_triton_args(triton_args) - + ptx_path_args= [os.path.split(path)[-1] for path in ptx_path_args] + with self.module.context: pm_str="builtin.module(func.func(gen-tit-config{{func-names={} tit-ptx-paths={} gridsize-x-args={} gridsize-y-args={} gridsize-z-args={} blocksize-x-args={} blocksize-y-args={} blocksize-z-args={}}}))".format(",".join(func_name_args), ",".join(ptx_path_args), ",".join(gridsize_x_args), ",".join(gridsize_y_args), ",".join(gridsize_z_args), ",".join(blocksize_x_args), ",".join(blocksize_y_args), ",".join(blocksize_z_args)) print(pm_str) From 584341d4298c0a648e1577b15468aa7dc23fcd17 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 12 Jun 2025 17:27:48 +0800 Subject: [PATCH 04/26] [fix] kernel_name bug --- compiler/include/byteir/Conversion/ToTIT/ToTIT.h | 2 -- compiler/lib/Conversion/ToTIT/GenTITConfig.cpp | 6 ++++-- .../python/byteir/dialects/cat/ir_processor.py | 15 +++++---------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h index a2199b078..ca063dc55 100644 --- a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h +++ b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h @@ -27,8 +27,6 @@ class FuncOp; } // namespace func class ModuleOp; -constexpr StringRef getByteIRTITOpKernelName() { return "TITOp"; } - std::unique_ptr> createGenTITConfigPass(ArrayRef funcNames = {""}, ArrayRef titPtxPaths = {""}, diff --git a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp index 149694e31..928e008c1 100644 --- a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp +++ b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp @@ -40,16 +40,18 @@ static LogicalResult AttachTITConfigToAttr( const std::string &blocksizeYArg, const std::string &blocksizeZArg) { - addGenericFuncAttrs(func, getByteIRTITOpKernelName().str()); std::string device_name; + std::string byreKernelName; if (titPtxPath.find(".ptx") != std::string::npos) { device_name = "cuda"; + byreKernelName="PTXOp"; } - if (device_name.empty()) { + if (device_name.empty()|| byreKernelName.empty()) { return func.emitError("Invalid device type for TIT configuration"); } + addGenericFuncAttrs(func, byreKernelName); mlir::OpBuilder opBuilder(func); llvm::StringMap titConfig; diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index 6991a7c0e..86ff7d4e1 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -278,18 +278,13 @@ def profile(self, backend="ait"): def _parallel_ait_compile(workdir: str, func: FuncOp, output_lib_path, enable_tf32): - - def touch_blank_file(file_path): - with open(file_path, 'w') as f: - pass # os.environ["CUDA_VISIBLE_DEVICES"]=str(os.getpid() % available_cuda_device_num) from byteir.dialects.cat.ir_translator.ait_builder import AITBuilder - # builder = AITBuilder(func, workdir=workdir, subgraph_name=func.name.value, enable_tf32=enable_tf32) - # builder.compile() - # builder.benchmark() - # copyfile(builder.ait_module_path, output_lib_path) - # copymode(builder.ait_module_path, output_lib_path) - touch_blank_file(output_lib_path) + builder = AITBuilder(func, workdir=workdir, subgraph_name=func.name.value, enable_tf32=enable_tf32) + builder.compile() + builder.benchmark() + copyfile(builder.ait_module_path, output_lib_path) + copymode(builder.ait_module_path, output_lib_path) def _parallel_tit_compile(workdir: str, func: FuncOp, output_ptx_path, enable_tf32): From e343abe3b2cd8edcf6ee6fca005684bc084be79d Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 12 Jun 2025 17:28:51 +0800 Subject: [PATCH 05/26] [fix] compile print bug --- compiler/python/byteir/compile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index 3b3dd6529..8852b2ad8 100755 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -332,7 +332,6 @@ def _compile_cuda_with_triton( pm = PassManager().parse("builtin.module(hlo-fusion-opt{outline-single-elemwise-op outline-cat-op})") pm.run(processor.module.operation) _print_verbose(processor.module, "// IR Dump After Hlo Fusion Opt (with Cat):") if verbose else ... - print(processor.module) # not generate ait lib .so for cat functions processor.triton_opt_pass(output_file_dir) module = processor.module From b8facfaa02e57310c0024e20b29c080bb8807af8 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 12 Jun 2025 17:31:00 +0800 Subject: [PATCH 06/26] [fix] pm print --- compiler/python/byteir/dialects/cat/ir_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index 86ff7d4e1..59c94ec67 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -246,7 +246,6 @@ def decouple_triton_args(triton_args): with self.module.context: pm_str="builtin.module(func.func(gen-tit-config{{func-names={} tit-ptx-paths={} gridsize-x-args={} gridsize-y-args={} gridsize-z-args={} blocksize-x-args={} blocksize-y-args={} blocksize-z-args={}}}))".format(",".join(func_name_args), ",".join(ptx_path_args), ",".join(gridsize_x_args), ",".join(gridsize_y_args), ",".join(gridsize_z_args), ",".join(blocksize_x_args), ",".join(blocksize_y_args), ",".join(blocksize_z_args)) - print(pm_str) pm = PassManager.parse(pm_str) exit pm.run(self.module.operation) From b489b68c8365a77408b16b693b45a11c46cdd6cb Mon Sep 17 00:00:00 2001 From: liushanghao Date: Mon, 16 Jun 2025 13:46:20 +0800 Subject: [PATCH 07/26] [module] add tritontemplate --- external/TritonTemplate/.gitignore | 146 ++++++++++++ external/TritonTemplate/README.md | 0 external/TritonTemplate/python/setup.py | 59 +++++ .../python/tritontemplate/__init__.py | 5 + .../python/tritontemplate/_libinfo.py | 1 + .../python/tritontemplate/backend/__init__.py | 0 .../tritontemplate/backend/cuda/__init__.py | 0 .../backend/cuda/gemm/__init__.py | 1 + .../backend/cuda/gemm/gemm_rcr.py | 114 ++++++++++ .../backend/cuda/gemm/gemm_rrr.py | 0 .../tritontemplate/compiler/__init__.py | 4 + .../python/tritontemplate/compiler/base.py | 113 ++++++++++ .../tritontemplate/compiler/compiler.py | 20 ++ .../python/tritontemplate/compiler/dtype.py | 153 +++++++++++++ .../python/tritontemplate/compiler/kernel.py | 20 ++ .../tritontemplate/compiler/op_registry.py | 23 ++ .../tritontemplate/compiler/ops/__init__.py | 1 + .../compiler/ops/gemm/__init__.py | 1 + .../tritontemplate/compiler/ops/gemm/gemm.py | 117 ++++++++++ .../tritontemplate/compiler/stable_set.py | 100 +++++++++ .../tritontemplate/compiler/symbolic.py | 154 +++++++++++++ .../python/tritontemplate/compiler/utils.py | 10 + .../python/tritontemplate/testing/__init__.py | 0 .../python/tritontemplate/testing/aot_demo.py | 208 ++++++++++++++++++ .../tritontemplate/testing/ptx_gen_demo.py | 77 +++++++ .../tritontemplate/testing/test_matmul.py | 62 ++++++ .../python/tritontemplate/utils/__init__.py | 0 .../tritontemplate/utils/tensor_utils.py | 28 +++ .../tritontemplate/utils/torch_utils.py | 86 ++++++++ 29 files changed, 1503 insertions(+) create mode 100644 external/TritonTemplate/.gitignore create mode 100644 external/TritonTemplate/README.md create mode 100644 external/TritonTemplate/python/setup.py create mode 100644 external/TritonTemplate/python/tritontemplate/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/_libinfo.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/base.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/compiler.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/dtype.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/kernel.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/op_registry.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/stable_set.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/symbolic.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/utils.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/aot_demo.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/ptx_gen_demo.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/test_matmul.py create mode 100644 external/TritonTemplate/python/tritontemplate/utils/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/utils/tensor_utils.py create mode 100644 external/TritonTemplate/python/tritontemplate/utils/torch_utils.py diff --git a/external/TritonTemplate/.gitignore b/external/TritonTemplate/.gitignore new file mode 100644 index 000000000..8897298b9 --- /dev/null +++ b/external/TritonTemplate/.gitignore @@ -0,0 +1,146 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# tmp +tmp/ + +tags + +# macOS dir files +.DS_Store + +# PyCharm files +.idea + +# vscode +.vscode + +# vim temp files +*.swp diff --git a/external/TritonTemplate/README.md b/external/TritonTemplate/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/setup.py b/external/TritonTemplate/python/setup.py new file mode 100644 index 000000000..8b369e707 --- /dev/null +++ b/external/TritonTemplate/python/setup.py @@ -0,0 +1,59 @@ +import os +import shutil + +from setuptools import find_packages, setup + +CURRENT_DIR = os.path.dirname(__file__) +libinfo_py = os.path.join(CURRENT_DIR, "tritontemplate", "_libinfo.py") +libinfo = {} +with open(libinfo_py, "r") as f: + exec(f.read(), libinfo) +__version__ = libinfo["__version__"] + +def gen_file_list(srcs, f_cond): + file_list = [] + for src in srcs: + for root, _, files in os.walk(src): + value = [] + for file in files: + if f_cond(file): + path = os.path.join(root, file) + value.append(path.replace("tritontemplate/", "")) + file_list.extend(value) + return file_list + +def gen_backend_common_file_list(): + srcs = ["tritontemplate/backend"] + f_cond = lambda x: True if x.endswith(".py") else False + return gen_file_list(srcs, f_cond) + +def gen_utils_file_list(): + srcs = ["tritontemplate/utils"] + f_cond = lambda x: True if x.endswith(".py") else False + return gen_file_list(srcs, f_cond) + +def gen_compiler_file_list(): + srcs = ["tritontemplate/compiler"] + f_cond = lambda x: True if x.endswith(".py") else False + return gen_file_list(srcs, f_cond) + +setup_kwargs = {} +include_libs = True +wheel_include_libs = True + +setup( + name="tritontemplate", + version=__version__, + description="TritonTemplate: Make Flex Triton Templates for AI", + zip_safe=True, + install_requires=["torch>=2.1.0","triton"], + packages=find_packages(), + package_data={ + "tritontemplate": [] + + gen_utils_file_list() + + gen_backend_common_file_list() + + gen_compiler_file_list() + }, + python_requires=">=3.7, <4", + **setup_kwargs +) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/__init__.py b/external/TritonTemplate/python/tritontemplate/__init__.py new file mode 100644 index 000000000..a1bf64a27 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/__init__.py @@ -0,0 +1,5 @@ +import sys +from tritontemplate import backend,compiler,testing,utils +from tritontemplate._libinfo import __version__ + +__all__ = ["backend", "compiler", "testing", "utils"] diff --git a/external/TritonTemplate/python/tritontemplate/_libinfo.py b/external/TritonTemplate/python/tritontemplate/_libinfo.py new file mode 100644 index 000000000..45c421e85 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/_libinfo.py @@ -0,0 +1 @@ +__version__ = "dev0" \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py new file mode 100644 index 000000000..8bf9daa68 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py @@ -0,0 +1 @@ +from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr,gemm_rcr_bias,gen_grid_gemm_rcr \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py new file mode 100644 index 000000000..ed06e3514 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py @@ -0,0 +1,114 @@ +import triton +import triton.language as tl + +def gen_grid_gemm_rcr(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + + return ( + triton.cdiv(M, BLOCK_SIZE_M)*triton.cdiv(N, BLOCK_SIZE_N),1,1) + +@triton.jit +def gemm_rcr_bias( + a_ptr, b_ptr, bias_ptr, c_ptr, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + stride_a0: tl.constexpr, stride_a1: tl.constexpr, + stride_b0: tl.constexpr, stride_b1: tl.constexpr, + stride_c0: tl.constexpr, stride_c1: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ACTIVATION: tl.constexpr +): + """ + Kernel for GEMM RCR (Row-Col-Row) + Bias + ReLU. + A (M, K) @ B (N, K)^T + Bias (N) -> C (M, N) + B is stored as (N, K) but accessed as if it's (K, N) for the matmul. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1) + b_ptrs = b_ptr + (offs_bn[None, :] * stride_b0 + offs_k[:, None] * stride_b1) # K is outer, N is inner + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_am[:, None] < M), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) + + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_a1 + b_ptrs += BLOCK_SIZE_K * stride_b1 + + if bias_ptr is not None: + offs_bias_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + bias_ptrs = bias_ptr + offs_bias_n + bias_vals = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + accumulator = accumulator + bias_vals[None, :] # Broadcast bias across M + + # Apply activation function + if ACTIVATION == 'relu': + accumulator = tl.maximum(accumulator, 0) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def gemm_rcr( + a_ptr, b_ptr, c_ptr, + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + stride_a0: tl.constexpr, stride_a1: tl.constexpr, + stride_b0: tl.constexpr, stride_b1: tl.constexpr, + stride_c0: tl.constexpr, stride_c1: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ACTIVATION: tl.constexpr +): + """ + Kernel for GEMM RCR (Row-Col-Row) (+ReLU). + A (M, K) @ B (N, K)^T -> C (M, N) + B is stored as (N, K) but accessed as if it's (K, N) for the matmul. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1) + b_ptrs = b_ptr + (offs_bn[None, :] * stride_b0 + offs_k[:, None] * stride_b1) # K is outer, N is inner + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_am[:, None] < M), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) + + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_a1 + b_ptrs += BLOCK_SIZE_K * stride_b1 + + # Apply activation function + if ACTIVATION == 'relu': + accumulator = tl.maximum(accumulator, 0) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/compiler/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/__init__.py new file mode 100644 index 000000000..aa1659943 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/__init__.py @@ -0,0 +1,4 @@ +from tritontemplate.compiler import base,dtype,op_registry,ops,symbolic +from tritontemplate.compiler.compiler import compile_kernel + +__all__ = ["base", "compile_kernel","dtype","op_registry","ops","symbolic",] \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py new file mode 100644 index 000000000..bf41aaf2d --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py @@ -0,0 +1,113 @@ +from abc import ABC,abstractmethod +from pprint import pformat +from typing import Any, Dict, Iterable, List, Optional, Set, Union + +class BaseType(ABC): + def __init__(self) -> None: + super().__init__() + self._attrs: Dict[str, Any] = {"name": None, "nop": False} + + def __str__(self) -> str: + return pformat(self._attrs, indent=2, depth=2) + + def __repr__(self) -> str: + return self.__str__() + + +class IntImm(int, BaseType): + def __new__(cls, val: int, divisibility: Optional[int] = None, name: Optional[str] = None): + instance = super().__new__(cls, val) + return instance + + def __init__( + self, + val: int, + divisibility: Optional[int] = None, + name: Optional[str] = None, + ) -> None: + BaseType.__init__(self) + self._attrs['val'] = int(self) + if name is not None: + self._attrs['name'] = name + self._attrs['divisibility'] = divisibility + + @property + def name(self) -> Optional[str]: + return self._attrs.get('name') + + @property + def divisibility(self) -> Optional[int]: + return self._attrs.get('divisibility') + + + + # set divisibility: + # @divisibility.setter + # def divisibility(self, value: Optional[int]): + # self._attrs['divisibility'] = value + + @property + def val(self) -> int: + return int(self) + + +class Tensor(BaseType): + """ + """ + def __init__( + self, + shape: List[IntImm], + dtype: str = "float16", + name: Optional[str] = None, + ) -> None: + super().__init__() + if name is not None: + self._attrs['name'] = name + self._attrs['dtype'] = dtype + self._attrs['shape'] = shape + + @property + def name(self) -> Optional[str]: + return self._attrs.get('name') + + @property + def dtype(self) -> str: + return self._attrs['dtype'] + + @property + def shape(self) -> List[IntImm]: + return self._attrs['shape'] + + +class Operation(BaseType): + """ + """ + def __init__( + self, + inputs: List[BaseType], + outputs: Optional[List[BaseType]], + name: Optional[str] = None, + ) -> None: + super().__init__() + self._attrs['inputs'] = inputs + if name is not None: + self._attrs['name'] = name + if outputs is not None: + self._attrs['outputs'] = outputs + + @property + def name(self) -> Optional[str]: + return self._attrs.get('name') + + @property + def inputs(self) -> List[Tensor]: + return self._attrs['inputs'] + + @property + def outputs(self) -> Optional[List[Tensor]]: + return self._attrs['outputs'] + + @abstractmethod + def compile(self,target_name,workdir): + raise NotImplementedError + diff --git a/external/TritonTemplate/python/tritontemplate/compiler/compiler.py b/external/TritonTemplate/python/tritontemplate/compiler/compiler.py new file mode 100644 index 000000000..f97477306 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/compiler.py @@ -0,0 +1,20 @@ +from typing import List, Optional, Union +import logging +import importlib + +from tritontemplate import compiler,backend +from tritontemplate.compiler.kernel import TritonExecutor + +_LOGGER = logging.getLogger(__name__) + +def compile_kernel( + op: compiler.base.Operation, + device: str='cuda', + workdir: str='./workshop', +)->TritonExecutor: + try: + _ = importlib.import_module(f'tritontemplate.backend.{device}') + except ModuleNotFoundError: + raise ModuleNotFoundError(f'Target {device} not found') + return op.compile(device, workdir) + \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/dtype.py b/external/TritonTemplate/python/tritontemplate/compiler/dtype.py new file mode 100644 index 000000000..5afdff586 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/dtype.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified by somehow6 on 2025-06-11 to tritontemplate. +""" +dtype definitions and utility functions of Tritontemplate +""" + + +_DTYPE2BYTE = { + "bool": 1, + "float16": 2, + "float32": 4, + "float": 4, + "int": 4, + "int32": 4, + "int64": 8, + "bfloat16": 2, +} + +_DTYPETRITONSIGNATURE = { + "float16": "fp16", + "float32": "fp32", + "float": "fp32", + "int": "i32", + "int32": "i32", + "int64": "i64", + "bfloat16": "bf16" +} + + +# Maps dtype strings to AITemplateDtype enum in model_interface.h. +# Must be kept in sync! +# We can consider defining an AITemplateDtype enum to use on the Python +# side at some point, but stick to strings for now to keep things consistent +# with other Python APIs. +_DTYPE_TO_ENUM = { + "float16": 1, + "float32": 2, + "float": 2, + "int": 3, + "int32": 3, + "int64": 4, + "bool": 5, + "bfloat16": 6, +} + + +def get_dtype_size(dtype: str) -> int: + """Returns size (in bytes) of the given dtype str. + + Parameters + ---------- + dtype: str + A data type string. + + Returns + ---------- + int + Size (in bytes) of this dtype. + """ + + if dtype not in _DTYPE2BYTE: + raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}") + return _DTYPE2BYTE[dtype] + + +def normalize_dtype(dtype: str) -> str: + """Returns a normalized dtype str. + + Parameters + ---------- + dtype: str + A data type string. + + Returns + ---------- + str + normalized dtype str. + """ + if dtype == "int": + return "int32" + if dtype == "float": + return "float32" + return dtype + + +def dtype_str_to_enum(dtype: str) -> int: + """Returns the AITemplateDtype enum value (defined in model_interface.h) of + the given dtype str. + + Parameters + ---------- + dtype: str + A data type string. + + Returns + ---------- + int + the AITemplateDtype enum value. + """ + if dtype not in _DTYPE_TO_ENUM: + raise ValueError( + f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPE_TO_ENUM.keys())}" + ) + return _DTYPE_TO_ENUM[dtype] + +def dtype_str_to_triton_signature(dtype: str) -> str: + """Returns the AITemplateDtype enum value (defined in model_interface.h) of + the given dtype str. + Parameters + ---------- + dtype: str + A data type string. + Returns + ---------- + int + the AITemplateDtype enum value. + """ + if dtype not in _DTYPETRITONSIGNATURE: + raise ValueError( + f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPETRITONSIGNATURE.keys())}" + ) + return _DTYPETRITONSIGNATURE[dtype] + + +def is_same_dtype(dtype1: str, dtype2: str) -> bool: + """Returns True if dtype1 and dtype2 are the same dtype and False otherwise. + + Parameters + ---------- + dtype1: str + A data type string. + dtype2: str + A data type string. + + Returns + ---------- + bool + whether dtype1 and dtype2 are the same dtype + """ + return normalize_dtype(dtype1) == normalize_dtype(dtype2) + diff --git a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py new file mode 100644 index 000000000..e53535571 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py @@ -0,0 +1,20 @@ +from typing import Sequence +import triton + + +class TritonExecutor: + def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_size:Sequence[int],warp_size:int=32): + self.triton_kernel = triton_kernel + self.gridsize = grid_size + self.blocksize = triton_kernel.num_warps * warp_size + self.warpsize = warp_size + self.name = triton_kernel.metadata['name'] + + def __call__(self, *args, **kwds): + return self.triton_kernel[self.gridsize](*args, **kwds) + + def kernel_ptx(self,func_name:str): + ptx = self.triton_kernel.asm['ptx'] + ptx = ptx.replace(self.name, func_name) + return ptx + \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/op_registry.py b/external/TritonTemplate/python/tritontemplate/compiler/op_registry.py new file mode 100644 index 000000000..7b870c869 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/op_registry.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Registry for basic operators and math functions. +""" +from typing import Callable, Dict + +# OP_REGISTRY defines a mapping from a FuncEnum name to a function to create this elementwise operator. +# This object is initialized in elementwise.py, and referenced in base.py and math.py. +OP_REGISTRY: Dict[str, Callable] = {} diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py new file mode 100644 index 000000000..3b1fb08a3 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py @@ -0,0 +1 @@ +from tritontemplate.compiler.ops.gemm import Gemm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/__init__.py new file mode 100644 index 000000000..3eec361e5 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/__init__.py @@ -0,0 +1 @@ +from tritontemplate.compiler.ops.gemm.gemm import Gemm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py new file mode 100644 index 000000000..68280fe27 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -0,0 +1,117 @@ +from typing import List,Optional +import importlib + +import triton + +from tritontemplate.compiler.base import IntImm, Tensor, Operation +from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.kernel import TritonExecutor +from tritontemplate.compiler.utils import get_warpsize + +_supported_layouts = ['rcr','rrr'] +_supported_activations = ['relu',None] + + +_exec_metadata = { + 'num_warps': 4, + 'num_stages': 3, +} + +class Gemm(Operation): + def __init__( + self, + inputs: List[Tensor], + layout: str, + is_bias: bool = False, + is_transpose: bool = False, + outputs: Optional[List[Tensor]] = None, + activation: Optional[str] = None, + name: Optional[str] = None, + ) -> None: + assert layout in _supported_layouts, f'layout {layout} not supported' + assert activation in _supported_activations, f'activation {activation} not supported' + + super().__init__(inputs, outputs, name) + self.layout = layout + self.is_bias= is_bias + self._attrs['is_transpose'] = is_transpose + self._attrs['activation'] = activation + self._attrs['inputs'] = inputs + self._attrs['outputs'] = outputs if outputs is not None else self._induce_output_shape() + + + def _induce_output_shape(self): + # TODO: support transpose, by swap A,B + assert not self._attrs['is_transpose'], 'transpose not supported' + if self.layout == 'rcr': + M,N,K = self._attrs['inputs'][0].shape[0],self._attrs['inputs'][1].shape[0],self._attrs['inputs'][0].shape[1] + elif self.layout == 'rrr': + M,K,N = self._attrs['inputs'][0].shape[0],self._attrs['inputs'][1].shape[0],self._attrs['inputs'][0].shape[1] + else: + raise NotImplementedError(f'layout {self.layout} not supported') + return [Tensor(shape=[M,N],dtype=self._attrs['inputs'][0].dtype)] + def _gen_signature_divisiability(self): + signature_metadata={} + divisiability={1:[],16:[]} + for i,input in enumerate(self._attrs['inputs']+self._attrs['outputs']): + if isinstance(input,Tensor): + + try: + sptype='*'+dtype_str_to_triton_signature(input.dtype) + except KeyError: + raise KeyError(f'dtype {input.dtype} not supported') + signature_metadata[i]=sptype + # the ptr from torch is 16-byte aligned + if divisiability.get(16,None) is None: + divisiability[16]=[i] + else: + divisiability[16].append(i) + else: + raise NotImplementedError(f'input {input} not supported') + + return signature_metadata,divisiability + + def _gen_constants(self): + const_metadata={} + const_metadata['ACTIVATION'] = self._attrs['activation'] + if self.layout == 'rcr': + input=self._attrs['inputs'] + M,N,K=input[0].shape[0],input[1].shape[0],input[0].shape[1] + const_metadata['M']=M + const_metadata['N']=N + const_metadata['K']=K + const_metadata['stride_a0']=K + const_metadata['stride_a1']=1 + const_metadata['stride_b0']=N + const_metadata['stride_b1']=1 + const_metadata['stride_c0']=N + const_metadata['stride_c1']=1 + else: + raise NotImplementedError(f'layout {self.layout} not supported') + + const_metadata['BLOCK_SIZE_M']= 128 if M>=128 else triton.next_power_of_2(M) + const_metadata['BLOCK_SIZE_N']= 128 if N>=128 else triton.next_power_of_2(N) + const_metadata['BLOCK_SIZE_K']= 128 if K>=128 else triton.next_power_of_2(K) + return const_metadata + + def _gen_exec_metadata(self): + return _exec_metadata.copy() + + def compile(self,target_name,workdir)->TritonExecutor: + triton_kernel_name=f'gemm_{self.layout}'+ ('' if not self.is_bias else '_bias') + triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),triton_kernel_name) + gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm_{self.layout}') + + signature,divisiability=self._gen_signature_divisiability() + constants=self._gen_constants() + exec_metadata=self._gen_exec_metadata() + + num_warps=exec_metadata['num_warps'] + num_stages=exec_metadata['num_stages'] + config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) + triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) + + exec_grid=gen_grid(constants['M'],constants['N'],constants['BLOCK_SIZE_M'],constants['BLOCK_SIZE_N']) + return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name)) + + diff --git a/external/TritonTemplate/python/tritontemplate/compiler/stable_set.py b/external/TritonTemplate/python/tritontemplate/compiler/stable_set.py new file mode 100644 index 000000000..82f945078 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/stable_set.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A stable set is like a Python set which produces deterministic results. +It also tries to preserve the original element order as much as possible, which could +potentially make debugging (e.g. comparison with the original graph, comparison between +AIT GPU trace and other GPU traces) easier. +""" +from collections import abc +from typing import Any, Iterable + + +class StableSet(abc.MutableSet): + def __init__(self, s: Iterable[Any] = None): + if s is None: + s = [] + self._d = {item: None for item in s} + + def add(self, value) -> None: + self._d[value] = None + + def update(self, other) -> None: + for item in other: + self._d[item] = None + + def discard(self, value) -> None: + self._d.pop(value, None) + + def remove(self, value) -> None: + self._d.pop(value) + + def copy(self): + return StableSet(list(self._d)) + + def clear(self): + self._d = {} + + def __sub__(self, other): + res = self.copy() + for item in other: + res.discard(item) + return res + + def __str__(self) -> str: + return str(list(self._d)) + + def __repr__(self) -> str: + return str(list(self._d)) + + def __len__(self) -> int: + return len(self._d) + + def __contains__(self, value: Any) -> int: + return value in self._d + + def __iter__(self): + return list(self._d).__iter__() + + def _type_check(self, other): + if not isinstance(other, StableSet): + raise RuntimeError( + f"A StableSet can only be operated with another StableSet! " + f"Current type: {type(other)}." + ) + + def __eq__(self, other): + self._type_check(other) + return set(other._d) == set(self._d) + + def __le__(self, other): + self._type_check(other) + return set(self._d) <= set(other._d) + + def __lt__(self, other): + self._type_check(other) + return set(self._d) < set(other._d) + + def __ge__(self, other): + self._type_check(other) + return set(self._d) >= set(other._d) + + def __gt__(self, other): + self._type_check(other) + return set(self._d) > set(other._d) + + def __getitem__(self, idx): + return list(self._d)[idx] diff --git a/external/TritonTemplate/python/tritontemplate/compiler/symbolic.py b/external/TritonTemplate/python/tritontemplate/compiler/symbolic.py new file mode 100644 index 000000000..2edcfb195 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/symbolic.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Symbolic helpers for AITemplate. +AITemplate leverages Sympy to do symbolic computations for shapes. +The core of Sympy is surrounded around the class "Symbol". We could apply operations +on Symbols (i.e. add/mul/power/etc.) Which could help us do basic arithmetic with +unknown values. +The symbolic-ness comes from representation that includes Symbol (i.e. sym_1 + 100.) + +Example Usage: +A = IntVar(...) +sym_A = A.symbolic_value() # equivalent of A._attrs["symbolic_value"] + +# do something about sym_A, some common usage include: +new_sym = sym_A + 100 +new_sym = sym_A - 100 +new_sym = sym_A * 2 +new_sym = sym_A * sym_B + +# We could then assign the symbolic value to a new IntVar. +new_var = IntVar(..., symbolic_value=new_sym) + +For more advanced usage on Sympy, check: https://docs.sympy.org/latest/tutorials/intro-tutorial/intro.html +""" +from __future__ import annotations + +import itertools + +from numbers import Number +from typing import Any, List, Optional, Set + +import sympy + + +_k_symbolic_to_intvar = {} +_k_symbolic_index = 0 +_k_symbolic_value = {} + + +def create_new_symbol( + name: Optional[str] = None, + values: Optional[List[int]] = None, + check_duplicate: bool = False, +) -> sympy.Symbol: + """ + Creates and memoizing symbols. + + Parameters + ---------- + name : Optional[str] + The symbol name that is going to be used. If None is provided, an unused + name would be created. + values : Optional[List[int]] + The values for IntVar, which indicates the range of which the symbol could + represent. + check_duplicate : bool + If set as True and name is provided, we check whether the name and values + provided matches the corresponding symbol recorded. + """ + global _k_symbolic_index + global _k_symbolic_value + + if name is None: + while True: + name = f"_sym_{_k_symbolic_index}" + _k_symbolic_index += 1 + + if name not in _k_symbolic_value: + break + + values = sorted(set(values)) if values is not None else values + if ( + check_duplicate + and name in _k_symbolic_value + and _k_symbolic_value[name] != values + ): + raise ValueError( + f"Symbol ({name}) has different values! New value is {values}, stored value is {_k_symbolic_value[name]}" + ) + + _k_symbolic_value[name] = values + return sympy.Symbol(name) + + +def is_symbol(sym_val: Any) -> bool: + return isinstance(sym_val, sympy.Symbol) + + +def is_symbolic(sym_val: Any) -> bool: + """ + Check whether sym_val is a sympy class. + """ + return isinstance(sym_val, sympy.Basic) + + +def is_integer(sym_val: Any) -> bool: + # We wrap this since None is returned if sympy can't determine the property. + if is_symbolic(sym_val): + return sym_val.is_number and int(sym_val) - sym_val == 0 + elif isinstance(sym_val, Number): + return int(sym_val) - sym_val == 0 + + return False + + +def get_global_symbol_set() -> Set: + global _k_symbolic_value + return set(_k_symbolic_value.keys()) + + +def get_intvar(sym_name: str): + global _k_symbolic_to_intvar + + return _k_symbolic_to_intvar.get(sym_name, None) + + +def store_intvar(sym_name: str, int_var) -> None: + global _k_symbolic_to_intvar + + _k_symbolic_to_intvar[sym_name] = int_var + + +def simplify_intvar_values(sym_val: sympy.Basic): + """ + Given a symbolic value, resolve the symbol's value range. + + Example: + 'symbol_A' has value range of [10, 20] + simplify_intvar_values(symbol_A * 3 + 4) returns [34, 64] + """ + global _k_symbolic_value + + symbols = list(sym_val.free_symbols) + symbol_shapes = [_k_symbolic_value[s.name] for s in symbols] + symbol_shapes = [s for s in symbol_shapes if s is not None] + shape_perms = list(itertools.product(*symbol_shapes)) + + new_shape = [int(sym_val.subs(zip(symbols, s))) for s in shape_perms] + new_shape = sorted(set(new_shape)) + + return new_shape diff --git a/external/TritonTemplate/python/tritontemplate/compiler/utils.py b/external/TritonTemplate/python/tritontemplate/compiler/utils.py new file mode 100644 index 000000000..96ab615b0 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/utils.py @@ -0,0 +1,10 @@ + +_TARGET2WARPSIZE={ + 'cuda':32, +} + +def get_warpsize(target_name): + try: + return _TARGET2WARPSIZE[target_name] + except KeyError: + raise KeyError(f'target {target_name} not supported') \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/__init__.py b/external/TritonTemplate/python/tritontemplate/testing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/testing/aot_demo.py b/external/TritonTemplate/python/tritontemplate/testing/aot_demo.py new file mode 100644 index 000000000..ef4a5f40a --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/aot_demo.py @@ -0,0 +1,208 @@ +import glob +import os +import subprocess +import sys +import tempfile + +import numpy as np + +import triton +from triton.common import cuda_include_dir, libcuda_dirs + +kernel_utils_src = """ +import triton + +@triton.jit +def mul(x, y): + return x * y +""" + +kernel_src = """ +import triton +import triton.language as tl +import kernel_utils + +@triton.jit +def kernel(C, A, B, + stride_cm, stride_cn, + stride_am, stride_ak, + stride_bk, stride_bn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + ms = tl.arange(0, BLOCK_M) + ns = tl.arange(0, BLOCK_N) + ks = tl.arange(0, BLOCK_K) + a = tl.load(A + ms[:, None] * stride_am + ks[None, :] * stride_ak) + b = tl.load(B + ks[:, None] * stride_bk + ns[None, :] * stride_bn) + c = tl.dot(a, b) + c = kernel_utils.mul(c, c) + tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c) +""" + +test_src = """ +#include +#include +#include +#include +#include "kernel.h" + +static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) { + FILE *file = fopen(filename, "w"); + if (file == NULL) { + printf(\"Could not open file %s\\n\", filename); + return; + } + for (int i = 0; i < size; i++) { + fprintf(file, "%d", buffer[i]); + if (i < size - 1) { + fprintf(file, ","); + } + } + fclose(file); +} + +static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + printf(\"Could not open file %s\\n\", filename); + return; + } + int index = 0; + while (fscanf(file, "%hd,", &buffer[index]) != EOF && index < size) { + index++; + } + fclose(file); +} + +int main(int argc, char **argv) { + int M = 16, N = 16, K = 16; + int BM = 16, BN = 16, BK = 16; + + // initialize CUDA handles + CUdevice dev; + CUcontext ctx; + CUstream stream; + CUdeviceptr A, B, C; + CUresult err = 0; + cuInit(0); + cuDeviceGet(&dev, 0); + cuCtxCreate(&ctx, 0, dev); + cuMemAlloc(&A, M * K * 2); + cuMemAlloc(&B, K * N * 2); + cuMemAlloc(&C, M * N * 4); + cuStreamCreate(&stream, 0); + load_matmul_fp16xfp16_16x16x16(); + + // initialize input data + int16_t hA[M*K]; + int16_t hB[K*N]; + memset(hA, 0, M*K*2); + memset(hB, 0, K*N*2); + read_csv_to_buffer(argv[1], hA, M*K); + read_csv_to_buffer(argv[2], hB, K*N); + cuMemcpyHtoD(A, hA, M*K*2); + cuMemcpyHtoD(B, hB, K*N*2); + + // launch kernel + int gX = 1, gY = 1, gZ = 1; + cuStreamSynchronize(stream); + matmul_fp16xfp16_16x16x16(stream, M/BM, N/BN, 1, C, A, B, N, K, N); + cuStreamSynchronize(stream); + + // read data + int32_t hC[M*N]; + memset(hC, 0, M*N*4); + cuMemcpyDtoH(hC, C, M*N*4); + write_buffer_to_csv(argv[3], hC, M*N); + + + // free cuda handles + unload_matmul_fp16xfp16_16x16x16(); + cuMemFree(A); + cuMemFree(B); + cuMemFree(C); + cuCtxDestroy(ctx); +} +""" + + +def test_compile_link_matmul(): + np.random.seed(3) + + with tempfile.TemporaryDirectory() as tmp_dir: + kernel_path = os.path.join(tmp_dir, "kernel.py") + with open(kernel_path, "w") as file: + file.write(kernel_src) + + kernel_utils_path = os.path.join(tmp_dir, "kernel_utils.py") + with open(kernel_utils_path, "w") as file: + file.write(kernel_utils_src) + + compiler_path = os.path.join(triton.tools.__path__[0], "compile.py") + linker_path = os.path.join(triton.tools.__path__[0], "link.py") + + dtype = "fp16" + M, N, K = 16, 16, 16 + BM, BN, BK = 16, 16, 16 + + # compile all desired configs + hints = [":16", ""] + for ha in hints: + for hb in hints: + sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, 1, i32{hb}, 1, i32:16, 1, {BM}, {BN}, {BK}' + name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}" + subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=tmp_dir) + + # link all desired configs + h_files = glob.glob(os.path.join(tmp_dir, "*.h")) + subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=tmp_dir) + + # compile test case + with open(os.path.join(tmp_dir, "test.c"), "w") as file: + file.write(test_src) + c_files = glob.glob(os.path.join(tmp_dir, "*.c")) + subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(), + "-L", libcuda_dirs()[0], + "-l", "cuda", + "-o", "test"], check=True, cwd=tmp_dir) + + # initialize test data + a = np.random.randn(M * K).astype(np.float16).reshape((M, K)) + b = np.random.randn(M * K).astype(np.float16).reshape((K, N)) + a_path = os.path.join(tmp_dir, "a.csv") + b_path = os.path.join(tmp_dir, "b.csv") + c_path = os.path.join(tmp_dir, "c.csv") + for x, path in [(a, a_path), (b, b_path)]: + x.view(np.int16).ravel().tofile(path, sep=",") + + # run test case + subprocess.run(["./test", a_path, b_path, c_path], check=True, cwd=tmp_dir) + + # read data and compare against reference + c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) + c_tri = c.reshape((M, N)).view(np.float32) + c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32)) + np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.) + + +def test_ttgir_to_ptx(): + src = """ +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { + tt.return + } +} +""" + with tempfile.TemporaryDirectory() as tmp_dir: + kernel_path = os.path.join(tmp_dir, "empty_kernel.ttgir") + with open(kernel_path, "w") as fp: + fp.write(src) + k = triton.compile(kernel_path, cc=80) + ptx = k.asm["ptx"] + assert ".target sm_80" in ptx + assert ".address_size 64" in ptx + +if __name__ == "__main__": + test_compile_link_matmul() + test_ttgir_to_ptx() \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/ptx_gen_demo.py b/external/TritonTemplate/python/tritontemplate/testing/ptx_gen_demo.py new file mode 100644 index 000000000..a652b73dc --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/ptx_gen_demo.py @@ -0,0 +1,77 @@ +import triton +import triton.language as tl +import torch + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, # Pointers to matrices + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, # Matrix dimensions + stride_am: tl.constexpr, stride_ak: tl.constexpr, # A matrix strides + stride_bk: tl.constexpr, stride_bn: tl.constexpr, # B matrix strides + stride_cm: tl.constexpr, stride_cn: tl.constexpr, # C matrix strides + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # Tile sizes +): + # Get program IDs + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + # Create offsets for the block + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + rk = tl.arange(0, BLOCK_SIZE_K) + + # Create masks to avoid out-of-bounds accesses + a_mask = (rm[:, None] < M) & (rk[None, :] < K) + b_mask = (rk[:, None] < K) & (rn[None, :] < N) + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + + # Initialize accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Loop over K dimension + for k in range(0, K, BLOCK_SIZE_K): + # Load tiles from A and B + a = tl.load(a_ptr + rm[:, None] * stride_am + rk[None, :] * stride_ak, mask=a_mask) + b = tl.load(b_ptr + rk[:, None] * stride_bk + rn[None, :] * stride_bn, mask=b_mask) + + # Compute matrix multiplication + acc += tl.dot(a, b) + + # Store result + tl.store(c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn, acc, mask=c_mask) + +def compile_matmul_kernel_to_ptx(filename="matmul_kernel.ptx"): + # Define the signature of the kernel + signature = { + 0: '*fp32', # a_ptr + 1: '*fp32', # b_ptr + 2: '*fp32', # c_ptr + } + + # Define compile-time constants + constants = { + 'M': 1024, 'N': 1024, 'K': 1024, # Example matrix dimensions + 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, # Tile sizes + 'stride_am': 1024, 'stride_ak': 1, # Assuming row-major layout + 'stride_bk': 1024, 'stride_bn': 1, # Assuming row-major layout + 'stride_cm': 1024, 'stride_cn': 1, # Assuming row-major layout + } + # AOT compile the kernel + compiled_kernel = triton.compile( + matmul_kernel, + signature=signature, + constants=constants, + num_warps=16, + ) + # Get the PTX assembly code + ptx_code = compiled_kernel.asm['ptx'] + print(f"Number of warps: {compiled_kernel.num_warps}") + + + # Save the PTX code to a file + with open(filename, "w") as f: + f.write(ptx_code) + +if __name__ == "__main__": + # AOT compile to PTX + compile_matmul_kernel_to_ptx("matmul_kernel.ptx") \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py new file mode 100644 index 000000000..faa056a05 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py @@ -0,0 +1,62 @@ +import torch +import pytest + +from tritontemplate.compiler.base import IntImm,Tensor +from tritontemplate.compiler.ops.gemm import Gemm +from tritontemplate.compiler.compiler import compile_kernel + +def gen_gemm_rcr_bias_relu(M, N, K): + A = Tensor(name='A', dtype='float16', shape=[IntImm(M), IntImm(K)]) + B = Tensor(name='B', dtype='float16', shape=[IntImm(N), IntImm(K)]) + Bias = Tensor(name='Bias', dtype='float16', shape=[IntImm(N)]) + C = Tensor(name='C', dtype='float16', shape=[IntImm(M), IntImm(N)]) + + gemm_op = Gemm( + inputs=[A, B, Bias], + outputs=[C], + layout='rcr', + is_bias=True, + is_transpose=False, + activation='relu', + ) + + kernel = compile_kernel(gemm_op, target_name='cuda') + return kernel + +def gemm_rcr_bias_relu(a, b, bias): + M, K = a.shape + N, K_b = b.shape + assert K == K_b, "K dimension mismatch" + + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + kernel = gen_gemm_rcr_bias_relu(M, N, K) + kernel(a, b, c, bias) + return c + +@pytest.mark.parametrize( + 'M, N, K', + [ + (1024, 1024, 1024), + (1024, 1024, 512), + (1024, 1024, 256), + (1023,1023,1023), + (1025,1025,511), + ], +) +def test_gemm_rcr_bias_relu(M, N, K): + A = torch.randn((M, K), dtype=torch.float16, device='cuda') + B = torch.randn((N, K), dtype=torch.float16, device='cuda') + Bias = torch.randn((N,), dtype=torch.float16, device='cuda') + + # Triton and PyTorch outputs + c_triton = gemm_rcr_bias_relu(A, B, Bias) + y_torch = torch.relu(torch.nn.functional.linear(A, B, bias=Bias)) + + if not torch.allclose(c_triton, y_torch, atol=1e-2, rtol=1e-2): + print("Outputs mismatch!") + diff = torch.abs(c_triton - y_torch) + print("Max diff:", torch.max(diff), "Mean diff:", torch.mean(diff)) + + + diff --git a/external/TritonTemplate/python/tritontemplate/utils/__init__.py b/external/TritonTemplate/python/tritontemplate/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/utils/tensor_utils.py b/external/TritonTemplate/python/tritontemplate/utils/tensor_utils.py new file mode 100644 index 000000000..2042a44ec --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/utils/tensor_utils.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Util functions to handle tensor shapes. +""" + + +def wrap_dim(idx, rank): + """ + Wrap tensor index, idx, if it's negative. + """ + assert isinstance(idx, int), "idx must be int, but got {}".format(type(idx)) + if idx < 0: + idx = idx + rank + assert idx < rank, "idx {} out of range; rank {}".format(idx, rank) + return idx diff --git a/external/TritonTemplate/python/tritontemplate/utils/torch_utils.py b/external/TritonTemplate/python/tritontemplate/utils/torch_utils.py new file mode 100644 index 000000000..133e120eb --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/utils/torch_utils.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Functions for working with torch Tensors. +AITemplate doesn't depend on PyTorch, but it exposes +many APIs that work with torch Tensors anyways. + +The functions in this file may assume that +`import torch` will work. +""" + +import struct + +import torch + +from tritontemplate.compiler.dtype import dtype_str_to_enum, get_dtype_size, normalize_dtype + + +def types_mapping(): + from torch import bfloat16, bool, float16, float32, int32, int64 + + yield (float16, "float16") + yield (bfloat16, "bfloat16") + yield (float32, "float32") + yield (int32, "int32") + yield (int64, "int64") + yield (bool, "bool") + + +def torch_dtype_to_string(dtype): + for (torch_dtype, ait_dtype) in types_mapping(): + if dtype == torch_dtype: + return ait_dtype + raise ValueError( + f"Got unsupported input dtype {dtype}! " + f"Supported dtypes are: {list(types_mapping())}" + ) + + +def string_to_torch_dtype(string_dtype): + if string_dtype is None: + # Many torch functions take optional dtypes, so + # handling None is useful here. + return None + + for (torch_dtype, ait_dtype) in types_mapping(): + if string_dtype == ait_dtype: + return torch_dtype + raise ValueError( + f"Got unsupported ait dtype {string_dtype}! " + f"Supported dtypes are: {list(types_mapping())}" + ) + + +def write_tensor_binary(tensor: "torch.Tensor", file_handle) -> None: + tensor = tensor.detach().cpu().contiguous() + endianness = "@" # system endianness + dtype_str = normalize_dtype(torch_dtype_to_string(tensor.dtype)) + dtype_int = dtype_str_to_enum(dtype_str) + sizeof_dtype = get_dtype_size(dtype_str) + num_dims = len(tensor.shape) + file_handle.write(struct.pack(endianness + "I", dtype_int)) # unsigned int + file_handle.write(struct.pack(endianness + "I", sizeof_dtype)) # unsigned int + file_handle.write(struct.pack(endianness + "I", num_dims)) # unsigned int + total_size = sizeof_dtype + for dim in tensor.shape: + file_handle.write(struct.pack(endianness + "N", dim)) # size_t + total_size *= dim + file_handle.write(struct.pack(endianness + "N", total_size)) # size_t + bytedata = tensor.numpy().tobytes() + # just as a safety check + if len(bytedata) != total_size: + raise RuntimeError("Tensor has wrong number of bytes!") + file_handle.write(bytedata) From 396edb1389bf57afc221a62ca19aa47da194b5c4 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Tue, 17 Jun 2025 17:18:11 +0800 Subject: [PATCH 08/26] [feat] ShareMemSize info --- compiler/include/byteir/Conversion/Passes.td | 2 ++ .../include/byteir/Conversion/ToTIT/ToTIT.h | 1 + .../lib/Conversion/ToTIT/GenTITConfig.cpp | 9 +++++-- .../byteir/dialects/cat/ir_processor.py | 25 ++++++++++--------- .../dialects/cat/ir_translator/tit_builder.py | 1 + .../python/byteir/dialects/cat/tit_cache.py | 6 ++--- .../backend/cuda/gemm/gemm_rcr.py | 9 ++++++- .../python/tritontemplate/compiler/base.py | 2 -- .../python/tritontemplate/compiler/kernel.py | 13 ++++++++++ .../tritontemplate/compiler/ops/gemm/gemm.py | 2 +- .../python/tritontemplate/compiler/utils.py | 25 ++++++++++++++++++- .../tritontemplate/testing/test_matmul.py | 1 + 12 files changed, 74 insertions(+), 22 deletions(-) diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td index b0ff13397..760c4542f 100755 --- a/compiler/include/byteir/Conversion/Passes.td +++ b/compiler/include/byteir/Conversion/Passes.td @@ -167,6 +167,8 @@ def GenTITConfig : Pass<"gen-tit-config", "func::FuncOp"> { "names of all cat func for TIT backends.">, ListOption<"titPtxPaths", "tit-ptx-paths", "std::string", "paths to all TIT-generated .ptx files">, + ListOption<"smemsizeArgs", "smemsize-args", "std::string", + "smemsize args for TIT backends.">, ListOption<"gridsizeXArgs", "gridsize-x-args", "std::string", "gridsize x args for TIT backends.">, ListOption<"gridsizeYArgs", "gridsize-y-args", "std::string", diff --git a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h index ca063dc55..61998abac 100644 --- a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h +++ b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h @@ -30,6 +30,7 @@ class ModuleOp; std::unique_ptr> createGenTITConfigPass(ArrayRef funcNames = {""}, ArrayRef titPtxPaths = {""}, + ArrayRef smemsizeArgs = {""}, ArrayRef gridsizeXArgs = {""}, ArrayRef gridsizeYArgs = {""}, ArrayRef gridsizeZArgs = {""}, diff --git a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp index 928e008c1..a331c6e37 100644 --- a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp +++ b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp @@ -33,6 +33,7 @@ namespace { static LogicalResult AttachTITConfigToAttr( func::FuncOp func, const std::string &titPtxPath, + const std::string &smemsizeArg, const std::string &gridsizeXArg, const std::string &gridsizeYArg, const std::string &gridsizeZArg, @@ -63,6 +64,7 @@ static LogicalResult AttachTITConfigToAttr( llvm::StringMap gpuLaunchArgs = { + {"SharedMemorySize", smemsizeArg}, {"BlockSize.x", blocksizeXArg}, {"BlockSize.y", blocksizeYArg}, {"BlockSize.z", blocksizeZArg}, @@ -92,6 +94,7 @@ struct GenTITConfigPass : public GenTITConfigBase { GenTITConfigPass( ArrayRef funcNames, ArrayRef titPtxPaths, + ArrayRef smemsizeArgs, ArrayRef gridsizeXArgs, ArrayRef gridsizeYArgs, ArrayRef gridsizeZArgs, @@ -102,6 +105,7 @@ struct GenTITConfigPass : public GenTITConfigBase { : GenTITConfigBase() { this->funcNames = funcNames; this->titPtxPaths = titPtxPaths; + this->smemsizeArgs = smemsizeArgs; this->gridsizeXArgs = gridsizeXArgs; this->gridsizeYArgs = gridsizeYArgs; this->gridsizeZArgs = gridsizeZArgs; @@ -116,7 +120,7 @@ struct GenTITConfigPass : public GenTITConfigBase { return; for (size_t i = 0; i < funcNames.size(); ++i) if (func.getSymName() == funcNames[i]) { - if (failed(AttachTITConfigToAttr(func, titPtxPaths[i], gridsizeXArgs[i], + if (failed(AttachTITConfigToAttr(func, titPtxPaths[i], smemsizeArgs[i], gridsizeXArgs[i], gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i], blocksizeYArgs[i], blocksizeZArgs[i]))) { return signalPassFailure(); @@ -131,6 +135,7 @@ std::unique_ptr> mlir::createGenTITConfigPass( ArrayRef funcNames, ArrayRef titPtxPaths, + ArrayRef smemsizeArgs, ArrayRef gridsizeXArgs, ArrayRef gridsizeYArgs, ArrayRef gridsizeZArgs, @@ -138,5 +143,5 @@ mlir::createGenTITConfigPass( ArrayRef blocksizeYArgs, ArrayRef blocksizeZArgs ) { - return std::make_unique(funcNames, titPtxPaths, gridsizeXArgs, gridsizeYArgs, gridsizeZArgs, blocksizeXArgs, blocksizeYArgs, blocksizeZArgs); + return std::make_unique(funcNames, titPtxPaths,smemsizeArgs, gridsizeXArgs, gridsizeYArgs, gridsizeZArgs, blocksizeXArgs, blocksizeYArgs, blocksizeZArgs); } diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index 59c94ec67..7c2e3a47d 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -172,7 +172,8 @@ def decouple_triton_args(triton_args): blocksize_x_args = [] blocksize_y_args = [] blocksize_z_args = [] - for func_name,ptx_path,gridsize,blocksize in triton_args: + smemsize_args = [] + for func_name,ptx_path,gridsize,blocksize,smem_size in triton_args: func_name_args.append(func_name) ptx_path_args.append(ptx_path) gridsize_x_args.append(str(gridsize[0])) @@ -182,7 +183,8 @@ def decouple_triton_args(triton_args): blocksize_x_args.append(str(blocksize)) blocksize_y_args.append(str(1)) blocksize_z_args.append(str(1)) - return func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args + smemsize_args.append(str(smem_size)) + return func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args,smemsize_args self.pool=None @@ -205,11 +207,11 @@ def decouple_triton_args(triton_args): # gridsize form (x,y,z), blocksize form (x,y,z) cached_argv = self.byteir_cache.find(gpu_type, hash_str) if cached_argv: - cache_ptx,gridsize,blocksize = cached_argv + cache_ptx,gridsize,blocksize,smemsize = cached_argv print(f"func {func.name.value} cache hit") copyfile(cache_ptx, output_ptx_path) copymode(cache_ptx, output_ptx_path) - triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize)) + triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize,smemsize)) else: work_items.append(func) @@ -231,23 +233,22 @@ def decouple_triton_args(triton_args): self.pool.close() self.pool.join() - for func,output_ptx_path,gridsize,blocksize in new_args: - triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize)) + for func,output_ptx_path,gridsize,blocksize,smemsize in new_args: + triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize,smemsize)) self.byteir_cache.load_or_create_cache() - self.byteir_cache.add(gpu_type, func_hash_str(func, gpu_type), (output_ptx_path, gridsize, blocksize), override=False) + self.byteir_cache.add(gpu_type, func_hash_str(func, gpu_type), (output_ptx_path, gridsize, blocksize,smemsize), override=False) self.byteir_cache._save() self.byteir_cache.close_cache() t_ed = time.time() print("compilation finished in {}s".format(t_ed-t_st)) - func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args = decouple_triton_args(triton_args) + func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args,smemsize_args = decouple_triton_args(triton_args) ptx_path_args= [os.path.split(path)[-1] for path in ptx_path_args] with self.module.context: - pm_str="builtin.module(func.func(gen-tit-config{{func-names={} tit-ptx-paths={} gridsize-x-args={} gridsize-y-args={} gridsize-z-args={} blocksize-x-args={} blocksize-y-args={} blocksize-z-args={}}}))".format(",".join(func_name_args), ",".join(ptx_path_args), ",".join(gridsize_x_args), ",".join(gridsize_y_args), ",".join(gridsize_z_args), ",".join(blocksize_x_args), ",".join(blocksize_y_args), ",".join(blocksize_z_args)) + pm_str="builtin.module(func.func(gen-tit-config{{func-names={} tit-ptx-paths={} smemsize-args={} gridsize-x-args={} gridsize-y-args={} gridsize-z-args={} blocksize-x-args={} blocksize-y-args={} blocksize-z-args={}}}))".format(",".join(func_name_args), ",".join(ptx_path_args),",".join(smemsize_args), ",".join(gridsize_x_args), ",".join(gridsize_y_args), ",".join(gridsize_z_args), ",".join(blocksize_x_args), ",".join(blocksize_y_args), ",".join(blocksize_z_args)) pm = PassManager.parse(pm_str) - exit pm.run(self.module.operation) _print_verbose(self.module, "// IR Dump After Gen TIT Config:") if self.verbose else ... @@ -291,7 +292,7 @@ def _parallel_tit_compile(workdir: str, func: FuncOp, output_ptx_path, enable_tf from byteir.dialects.cat.ir_translator.tit_builder import TITBuilder builder = TITBuilder(func, workdir=workdir, subgraph_name=func.name.value, enable_tf32=enable_tf32) builder.compile() - blockSize,gridsize=builder.blocksize,builder.gridsize + blockSize,gridsize,smemsize=builder.blocksize,builder.gridsize,builder.smemsize copyfile(builder.tit_module_path, output_ptx_path) copymode(builder.tit_module_path, output_ptx_path) - return func,output_ptx_path,gridsize,blockSize \ No newline at end of file + return func,output_ptx_path,gridsize,blockSize,smemsize \ No newline at end of file diff --git a/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py index 43d37e522..d7ff884b3 100644 --- a/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py +++ b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py @@ -128,6 +128,7 @@ def _gen_tit_kernel(self, results): f.write(self.tit_kernel.kernel_ptx(self.subgraph_name)) self.gridsize = self.tit_kernel.gridsize self.blocksize = self.tit_kernel.blocksize + self.smemsize = self.tit_kernel.smemsize def _gen_runtime_tensor(self, tensor: Tensor): shape = tensor.shape() diff --git a/compiler/python/byteir/dialects/cat/tit_cache.py b/compiler/python/byteir/dialects/cat/tit_cache.py index 0bdb8f4e4..8c28c8813 100644 --- a/compiler/python/byteir/dialects/cat/tit_cache.py +++ b/compiler/python/byteir/dialects/cat/tit_cache.py @@ -81,7 +81,7 @@ def add(self, gpu_type, key, argv, override = False): value = "{:0>16}.ptx".format(self.idx) self.idx += 1 self.cache[IDX_KEY] = self.idx - self.cache[gpu_type][key] = (value, argv[1], argv[2]) + self.cache[gpu_type][key] = (value, argv[1], argv[2],argv[3]) ptx_path = argv[0] copyfile(ptx_path, os.path.join(self.cache_dir, value)) copymode(ptx_path, os.path.join(self.cache_dir, value)) @@ -90,8 +90,8 @@ def find(self, gpu_type, key): if gpu_type not in self.cache: return None if key in self.cache[gpu_type]: - lib_path, grid_size, block_size = self.cache[gpu_type][key] - return os.path.join(self.cache_dir, lib_path), grid_size, block_size + lib_path, gridsize, blocksize,smemsize = self.cache[gpu_type][key] + return os.path.join(self.cache_dir, lib_path), gridsize, blocksize,smemsize else: return None diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py index ed06e3514..6b9c75704 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py @@ -6,6 +6,12 @@ def gen_grid_gemm_rcr(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): return ( triton.cdiv(M, BLOCK_SIZE_M)*triton.cdiv(N, BLOCK_SIZE_N),1,1) +# smem_size=val*dtype_size*num_stage +smem_demand_per_stage ={ + 'gemm_rcr_bias': 128*128*2, + 'gemm_rcr': 128*128*2, +} + @triton.jit def gemm_rcr_bias( a_ptr, b_ptr, bias_ptr, c_ptr, @@ -111,4 +117,5 @@ def gemm_rcr( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) \ No newline at end of file + tl.store(c_ptrs, accumulator, mask=c_mask) + diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py index bf41aaf2d..3b0a110fe 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/base.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py @@ -39,8 +39,6 @@ def name(self) -> Optional[str]: def divisibility(self) -> Optional[int]: return self._attrs.get('divisibility') - - # set divisibility: # @divisibility.setter # def divisibility(self, value: Optional[int]): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py index e53535571..45463b3a7 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py @@ -1,6 +1,7 @@ from typing import Sequence import triton +from tritontemplate.compiler.utils import get_device_max_shared_memory,get_cuda_device_name class TritonExecutor: def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_size:Sequence[int],warp_size:int=32): @@ -9,6 +10,18 @@ def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_siz self.blocksize = triton_kernel.num_warps * warp_size self.warpsize = warp_size self.name = triton_kernel.metadata['name'] + self.smemsize = triton_kernel.shared + + self.device_name = get_cuda_device_name() + try: + self.device_name = get_cuda_device_name() + assert self.smemsize <= get_device_max_shared_memory(self.device_name), \ + f'kernel {self.name} smem size {self.smemsize} exceeds device {self.device_name} max smem size {get_device_max_shared_memory(self.device_name)}' + except KeyError as e: + # Log the error and continue with default values + import logging + logging.warning(f"Unsupported device detected: {str(e)}. Continuing with default configuration.") + self.device_name = "unknown" def __call__(self, *args, **kwds): return self.triton_kernel[self.gridsize](*args, **kwds) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 68280fe27..7b709761d 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -14,7 +14,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 3, + 'num_stages': 1, } class Gemm(Operation): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/utils.py b/external/TritonTemplate/python/tritontemplate/compiler/utils.py index 96ab615b0..28cb397eb 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/utils.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/utils.py @@ -1,10 +1,33 @@ +import subprocess _TARGET2WARPSIZE={ 'cuda':32, } +_DEVICE_MAX_SHARED_MEMORY={ + "NVIDIA H800": 227 * 1024, + "NVIDIA H100": 227 * 1024, + "NVIDIA A100": 164 * 1024, + "NVIDIA A800": 164 * 1024, + "NVIDIA V100": 96 * 1024, + "NVIDIA T4": 64 * 1024, +} + +def get_cuda_device_name(idx=0): + cmd = "nvidia-smi --query-gpu=name --format=csv,noheader" + result = subprocess.check_output(cmd, shell=True) + gpu_names = result.decode().strip().split("\n") + return gpu_names[idx] + def get_warpsize(target_name): try: return _TARGET2WARPSIZE[target_name] except KeyError: - raise KeyError(f'target {target_name} not supported') \ No newline at end of file + raise KeyError(f'target {target_name} not supported') + +def get_device_max_shared_memory(target_name): + try: + return _DEVICE_MAX_SHARED_MEMORY[target_name] + except KeyError: + raise KeyError(f'target {target_name} not supported, please add max smem size info') + \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py index faa056a05..ff936b06f 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py +++ b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py @@ -40,6 +40,7 @@ def gemm_rcr_bias_relu(a, b, bias): (1024, 1024, 1024), (1024, 1024, 512), (1024, 1024, 256), + (128, 512, 256), (1023,1023,1023), (1025,1025,511), ], From f9c6370fef5182819ec51e91ba579eb8b61c64dc Mon Sep 17 00:00:00 2001 From: liushanghao Date: Wed, 18 Jun 2025 11:08:27 +0800 Subject: [PATCH 09/26] [feat] runtime share memory supported --- .../lib/backends/cuda/device/cuda_work_queue.cc | 17 +++++++++++++++-- .../cuda/providers/default/codegen/ptx.cc | 9 ++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/runtime/lib/backends/cuda/device/cuda_work_queue.cc b/runtime/lib/backends/cuda/device/cuda_work_queue.cc index 8eecda95e..615c0e270 100644 --- a/runtime/lib/backends/cuda/device/cuda_work_queue.cc +++ b/runtime/lib/backends/cuda/device/cuda_work_queue.cc @@ -16,7 +16,6 @@ //===----------------------------------------------------------------------===// #include "brt/backends/cuda/device/cuda_work_queue.h" - #include "brt/backends/cuda/device/common/cuda_call.h" #include "brt/core/common/common.h" #include "brt/core/ir/ir.h" @@ -73,9 +72,23 @@ inline common::Status ComputeDrv(const void *func, void **args, dim3 *grid = static_cast(args[0]); dim3 *block = static_cast(args[1]); size_t *shared_size = static_cast(args[2]); + //TODO: unsafe operation? + CUfunction hFunc=reinterpret_cast(const_cast(func)); + + //extend the shared memory + int shared_optin; + int device_id=-1; + BRT_CUDA_CHECK(cudaGetDevice(&device_id)); + BRT_CUDA_CHECK(cudaDeviceGetAttribute(&shared_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); + + if (shared_optin>49152&&(*shared_size)>49152){ + BRT_CU_CHECK(cuFuncSetCacheConfig(hFunc, CU_FUNC_CACHE_PREFER_SHARED)); + BRT_CU_CHECK(cuFuncSetAttribute(hFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)); + } + void **kernel_args = args + 3; return BRT_CU_CALL( - cuLaunchKernel(reinterpret_cast(const_cast(func)), + cuLaunchKernel(hFunc, (*grid).x, (*grid).y, (*grid).z, (*block).x, (*block).y, (*block).z, *shared_size, stream, kernel_args, 0)); } diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc index dfd6349f6..25256d620 100644 --- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc +++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc @@ -37,6 +37,7 @@ using namespace mlir; #define FILE_NAME_ATTR "device_file_name" #define KERNEL_NAME_ATTR "kernel_name" +#define SHARED_SIZE_ATTR "SharedMemorySize" #define GRID_SIZE_X_ATTR "GridSize.x" #define GRID_SIZE_Y_ATTR "GridSize.y" #define GRID_SIZE_Z_ATTR "GridSize.z" @@ -138,6 +139,7 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr"); } + size_t shared_size=0; int gx = static_cast(info.GetOperation() ->getAttrOfType(GRID_SIZE_X_ATTR) .getInt()), @@ -167,6 +169,11 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) ->getAttrOfType(BLOCK_SIZE_Z_ATTR) .getInt()); } + if(info.GetOperation()->hasAttrOfType(SHARED_SIZE_ATTR)){ + shared_size=static_cast(info.GetOperation() + ->getAttrOfType(SHARED_SIZE_ATTR) + .getInt()); + } std::vector ranks; if (info.GetOperation()->hasAttrOfType(ARG_RANKS_ATTR)) { @@ -181,7 +188,7 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) auto num_arg = GetOpArgNum(info_); impl_->grid = dim3(gx, gy, gz); impl_->block = dim3(bx, by, bz); - impl_->shared_size = 0; + impl_->shared_size = shared_size; impl_->arg_reserve_size = 3; // initial 3 for grid/block/shared_size // store tensor meta From 514502607974a10b2e200313f1acc7f079a38eac Mon Sep 17 00:00:00 2001 From: liushanghao Date: Wed, 18 Jun 2025 17:08:48 +0800 Subject: [PATCH 10/26] [feat] tit e2e --- .../backend/cuda/gemm/gemm_rcr.py | 71 ++++++++++--------- .../python/tritontemplate/compiler/kernel.py | 3 +- .../tritontemplate/compiler/ops/gemm/gemm.py | 18 ++--- .../tritontemplate/testing/test_matmul.py | 45 +++++++----- .../examples/inference/tit_mlp.py | 49 +++++++++++++ 5 files changed, 125 insertions(+), 61 deletions(-) create mode 100644 frontends/torch-frontend/examples/inference/tit_mlp.py diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py index 6b9c75704..50373e03d 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py @@ -14,20 +14,24 @@ def gen_grid_gemm_rcr(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): @triton.jit def gemm_rcr_bias( - a_ptr, b_ptr, bias_ptr, c_ptr, + # Pointers to matrices + a_ptr, b_ptr, bias_ptr, c_ptr, + # Matrix dimensions M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - stride_a0: tl.constexpr, stride_a1: tl.constexpr, - stride_b0: tl.constexpr, stride_b1: tl.constexpr, - stride_c0: tl.constexpr, stride_c1: tl.constexpr, + # Strides for matrices + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bn: tl.constexpr, stride_bk: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + stride_biasn: tl.constexpr, + # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr + ACTIVATION: tl.constexpr # 'relu' or None ): """ Kernel for GEMM RCR (Row-Col-Row) + Bias + ReLU. A (M, K) @ B (N, K)^T + Bias (N) -> C (M, N) B is stored as (N, K) but accessed as if it's (K, N) for the matmul. """ - pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -38,53 +42,55 @@ def gemm_rcr_bias( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1) - b_ptrs = b_ptr + (offs_bn[None, :] * stride_b0 + offs_k[:, None] * stride_b1) # K is outer, N is inner - + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_k[None, :] * stride_bk) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_am[:, None] < M), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_bn[:, None] < N), other=0.0) + accumulator += tl.dot(a, tl.trans(b)) - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_a1 - b_ptrs += BLOCK_SIZE_K * stride_b1 + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk if bias_ptr is not None: offs_bias_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - bias_ptrs = bias_ptr + offs_bias_n + bias_ptrs = bias_ptr + offs_bias_n * stride_biasn bias_vals = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) - accumulator = accumulator + bias_vals[None, :] # Broadcast bias across M + accumulator = accumulator + bias_vals[None, :] - # Apply activation function if ACTIVATION == 'relu': accumulator = tl.maximum(accumulator, 0) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) + @triton.jit def gemm_rcr( + # Pointers to matrices a_ptr, b_ptr, c_ptr, + # Matrix dimensions M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - stride_a0: tl.constexpr, stride_a1: tl.constexpr, - stride_b0: tl.constexpr, stride_b1: tl.constexpr, - stride_c0: tl.constexpr, stride_c1: tl.constexpr, + # Strides for matrices + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bn: tl.constexpr, stride_bk: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr + ACTIVATION: tl.constexpr # 'relu' or None ): """ - Kernel for GEMM RCR (Row-Col-Row) (+ReLU). + Kernel for GEMM RCR (Row-Col-Row) + ReLU. A (M, K) @ B (N, K)^T -> C (M, N) B is stored as (N, K) but accessed as if it's (K, N) for the matmul. """ - pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -95,27 +101,24 @@ def gemm_rcr( offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1) - b_ptrs = b_ptr + (offs_bn[None, :] * stride_b0 + offs_k[:, None] * stride_b1) # K is outer, N is inner - + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_k[None, :] * stride_bk) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for k in tl.static_range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_am[:, None] < M), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_bn[:, None] < N), other=0.0) + accumulator += tl.dot(a, tl.trans(b)) - accumulator += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_a1 - b_ptrs += BLOCK_SIZE_K * stride_b1 + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk - # Apply activation function if ACTIVATION == 'relu': accumulator = tl.maximum(accumulator, 0) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) - diff --git a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py index 45463b3a7..012c54c19 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py @@ -4,7 +4,8 @@ from tritontemplate.compiler.utils import get_device_max_shared_memory,get_cuda_device_name class TritonExecutor: - def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_size:Sequence[int],warp_size:int=32): + def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_size:Sequence[int],warp_size:int=32,constants:dict=None): + self.call_constants = constants self.triton_kernel = triton_kernel self.gridsize = grid_size self.blocksize = triton_kernel.num_warps * warp_size diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 7b709761d..e2d9a62f0 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -50,12 +50,12 @@ def _induce_output_shape(self): else: raise NotImplementedError(f'layout {self.layout} not supported') return [Tensor(shape=[M,N],dtype=self._attrs['inputs'][0].dtype)] + def _gen_signature_divisiability(self): signature_metadata={} divisiability={1:[],16:[]} for i,input in enumerate(self._attrs['inputs']+self._attrs['outputs']): if isinstance(input,Tensor): - try: sptype='*'+dtype_str_to_triton_signature(input.dtype) except KeyError: @@ -80,12 +80,14 @@ def _gen_constants(self): const_metadata['M']=M const_metadata['N']=N const_metadata['K']=K - const_metadata['stride_a0']=K - const_metadata['stride_a1']=1 - const_metadata['stride_b0']=N - const_metadata['stride_b1']=1 - const_metadata['stride_c0']=N - const_metadata['stride_c1']=1 + const_metadata['stride_am']=K + const_metadata['stride_ak']=1 + const_metadata['stride_bn']=K + const_metadata['stride_bk']=1 + const_metadata['stride_cm']=N + const_metadata['stride_cn']=1 + if self.is_bias: + const_metadata['stride_biasn']=1 else: raise NotImplementedError(f'layout {self.layout} not supported') @@ -112,6 +114,6 @@ def compile(self,target_name,workdir)->TritonExecutor: triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) exec_grid=gen_grid(constants['M'],constants['N'],constants['BLOCK_SIZE_M'],constants['BLOCK_SIZE_N']) - return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name)) + return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants) diff --git a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py index ff936b06f..064483254 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py +++ b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py @@ -1,9 +1,11 @@ import torch import pytest +import triton from tritontemplate.compiler.base import IntImm,Tensor from tritontemplate.compiler.ops.gemm import Gemm from tritontemplate.compiler.compiler import compile_kernel +from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr_bias as gemm_rcr_bias_kernel def gen_gemm_rcr_bias_relu(M, N, K): A = Tensor(name='A', dtype='float16', shape=[IntImm(M), IntImm(K)]) @@ -17,32 +19,35 @@ def gen_gemm_rcr_bias_relu(M, N, K): layout='rcr', is_bias=True, is_transpose=False, - activation='relu', + activation=None, ) - kernel = compile_kernel(gemm_op, target_name='cuda') + kernel = compile_kernel(gemm_op, device='cuda') return kernel def gemm_rcr_bias_relu(a, b, bias): M, K = a.shape N, K_b = b.shape - assert K == K_b, "K dimension mismatch" # Allocates output. c = torch.empty((M, N), device=a.device, dtype=a.dtype) + a=a.contiguous() + b=b.contiguous() + bias=bias.contiguous() + c=c.contiguous() kernel = gen_gemm_rcr_bias_relu(M, N, K) - kernel(a, b, c, bias) + kernel(a, b, bias, c) return c @pytest.mark.parametrize( 'M, N, K', [ - (1024, 1024, 1024), - (1024, 1024, 512), - (1024, 1024, 256), - (128, 512, 256), - (1023,1023,1023), - (1025,1025,511), + (128, 256,512), + (128, 512, 256), + (128, 256, 256), + (128, 512, 512), + (256,128,256), + (256,256,256), ], ) def test_gemm_rcr_bias_relu(M, N, K): @@ -51,13 +56,17 @@ def test_gemm_rcr_bias_relu(M, N, K): Bias = torch.randn((N,), dtype=torch.float16, device='cuda') # Triton and PyTorch outputs - c_triton = gemm_rcr_bias_relu(A, B, Bias) - y_torch = torch.relu(torch.nn.functional.linear(A, B, bias=Bias)) - - if not torch.allclose(c_triton, y_torch, atol=1e-2, rtol=1e-2): - print("Outputs mismatch!") - diff = torch.abs(c_triton - y_torch) - print("Max diff:", torch.max(diff), "Mean diff:", torch.mean(diff)) - + c_triton = torch.zeros((M, N), device=A.device, dtype=A.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + triton_aot=gemm_rcr_bias_relu(A,B,Bias) + gemm_rcr_bias_kernel[grid](A,B,Bias,c_triton, M, N, K,A.stride(0),A.stride(1),B.stride(0),B.stride(1),c_triton.stride(0),c_triton.stride(1),Bias.stride(0),64,64,64,None) + + assert torch.allclose(c_triton, triton_aot, atol=1e-2, rtol=1e-2), \ + f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" + c=torch.nn.functional.linear(A,B,bias=Bias) + assert torch.allclose(c, triton_aot, atol=1e-2, rtol=1e-2), \ + f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" diff --git a/frontends/torch-frontend/examples/inference/tit_mlp.py b/frontends/torch-frontend/examples/inference/tit_mlp.py new file mode 100644 index 000000000..66181c12d --- /dev/null +++ b/frontends/torch-frontend/examples/inference/tit_mlp.py @@ -0,0 +1,49 @@ +import os + +import torch +from torch import nn +import torch_frontend +import byteir + +from brt_backend import BRTBackend + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(256, 256,dtype=torch.float16) + + def forward(self, x): + x = self.linear1(x) + # x = torch.nn.functional.relu(x) + return x + +workspace = "./workspace" +os.makedirs(workspace, exist_ok=True) +with torch.no_grad(): + model = MLP().cuda().eval() + inputs = [torch.randn(128, 256, dtype=torch.float16).cuda()] + traced_model = torch.jit.trace(model, inputs) + + stablehlo_file = workspace + "/model.stablehlo.mlir" + byre_file = workspace + "/model.byre.mlir" + module = torch_frontend.compile(traced_model, inputs, "stablehlo") + with open(stablehlo_file, "w") as f: + f.write(module.operation.get_asm()) + + byteir.compile(stablehlo_file, byre_file, entry_func="forward", target="cuda_with_triton") + + backend = BRTBackend("cuda", byre_file) + byteir_outputs = backend.execute(inputs) + + torch_outputs = model(*inputs) + torch_jit_outputs = traced_model(*inputs) + if len(byteir_outputs) == 1: + byteir_outputs = byteir_outputs[0] + + torch.testing.assert_close(torch_outputs, torch_jit_outputs, rtol=1e-3, atol=1e-3) + try: + torch.testing.assert_close(torch_outputs, byteir_outputs, rtol=1e-3, atol=1e-3) + except AssertionError as e: + diff=torch.abs(torch_outputs-byteir_outputs) + print("diff:",diff) + raise e From 997dc4c31bcb42bc8c20d0983b985359b5bdae01 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Wed, 18 Jun 2025 17:20:44 +0800 Subject: [PATCH 11/26] [fix] tritontemplate example --- .../python/tritontemplate/testing/test_matmul.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py index 064483254..6175b2c91 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py +++ b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py @@ -19,7 +19,7 @@ def gen_gemm_rcr_bias_relu(M, N, K): layout='rcr', is_bias=True, is_transpose=False, - activation=None, + activation='relu', ) kernel = compile_kernel(gemm_op, device='cuda') @@ -61,11 +61,11 @@ def test_gemm_rcr_bias_relu(M, N, K): triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) triton_aot=gemm_rcr_bias_relu(A,B,Bias) - gemm_rcr_bias_kernel[grid](A,B,Bias,c_triton, M, N, K,A.stride(0),A.stride(1),B.stride(0),B.stride(1),c_triton.stride(0),c_triton.stride(1),Bias.stride(0),64,64,64,None) + gemm_rcr_bias_kernel[grid](A,B,Bias,c_triton, M, N, K,A.stride(0),A.stride(1),B.stride(0),B.stride(1),c_triton.stride(0),c_triton.stride(1),Bias.stride(0),128,128,128,'relu') assert torch.allclose(c_triton, triton_aot, atol=1e-2, rtol=1e-2), \ f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" - c=torch.nn.functional.linear(A,B,bias=Bias) + c=torch.nn.functional.relu(torch.nn.functional.linear(A,B,bias=Bias)) assert torch.allclose(c, triton_aot, atol=1e-2, rtol=1e-2), \ f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" From a7c01f1a9b8c68bd9fdbe15ec6c0e72b5dc84458 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Wed, 18 Jun 2025 17:39:04 +0800 Subject: [PATCH 12/26] [fix] tiriton template mlp complement --- frontends/torch-frontend/examples/inference/tit_mlp.py | 10 ++++++++-- runtime/lib/backends/cuda/device/cuda_work_queue.cc | 1 - 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/frontends/torch-frontend/examples/inference/tit_mlp.py b/frontends/torch-frontend/examples/inference/tit_mlp.py index 66181c12d..404effdbd 100644 --- a/frontends/torch-frontend/examples/inference/tit_mlp.py +++ b/frontends/torch-frontend/examples/inference/tit_mlp.py @@ -10,11 +10,16 @@ class MLP(nn.Module): def __init__(self): super().__init__() - self.linear1 = nn.Linear(256, 256,dtype=torch.float16) + self.linear1 = nn.Linear(256, 512,dtype=torch.float16) + self.linear2 = nn.Linear(512, 256,dtype=torch.float16) + self.linear3 = nn.Linear(256, 128,dtype=torch.float16) def forward(self, x): x = self.linear1(x) - # x = torch.nn.functional.relu(x) + x = torch.nn.functional.relu(x) + x = self.linear2(x) + x = torch.nn.functional.relu(x) + x = self.linear3(x) return x workspace = "./workspace" @@ -47,3 +52,4 @@ def forward(self, x): diff=torch.abs(torch_outputs-byteir_outputs) print("diff:",diff) raise e + print("byteir tit backend success") \ No newline at end of file diff --git a/runtime/lib/backends/cuda/device/cuda_work_queue.cc b/runtime/lib/backends/cuda/device/cuda_work_queue.cc index 615c0e270..d119a5fd7 100644 --- a/runtime/lib/backends/cuda/device/cuda_work_queue.cc +++ b/runtime/lib/backends/cuda/device/cuda_work_queue.cc @@ -72,7 +72,6 @@ inline common::Status ComputeDrv(const void *func, void **args, dim3 *grid = static_cast(args[0]); dim3 *block = static_cast(args[1]); size_t *shared_size = static_cast(args[2]); - //TODO: unsafe operation? CUfunction hFunc=reinterpret_cast(const_cast(func)); //extend the shared memory From 2b49e12db060753d2513798ae379dc8882bbeb1a Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 20 Jun 2025 15:21:32 +0800 Subject: [PATCH 13/26] [feat] tf32 option --- .../byteir/dialects/cat/ir_processor.py | 1 - .../dialects/cat/ir_translator/tit_builder.py | 1 + .../backend/cuda/gemm/gemm_rcr.py | 14 ++++++++--- .../tritontemplate/compiler/compiler.py | 3 ++- .../tritontemplate/compiler/ops/gemm/gemm.py | 18 ++++++++----- .../tritontemplate/testing/test_matmul.py | 25 ++++++++++--------- .../examples/inference/tit_mlp.py | 9 ++++--- 7 files changed, 43 insertions(+), 28 deletions(-) diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py index 7c2e3a47d..8bd48c846 100644 --- a/compiler/python/byteir/dialects/cat/ir_processor.py +++ b/compiler/python/byteir/dialects/cat/ir_processor.py @@ -179,7 +179,6 @@ def decouple_triton_args(triton_args): gridsize_x_args.append(str(gridsize[0])) gridsize_y_args.append(str(gridsize[1])) gridsize_z_args.append(str(gridsize[2])) - #TODO: blocksize is 1d for now, need to check blocksize_x_args.append(str(blocksize)) blocksize_y_args.append(str(1)) blocksize_z_args.append(str(1)) diff --git a/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py index d7ff884b3..e382c2a40 100644 --- a/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py +++ b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py @@ -122,6 +122,7 @@ def _gen_tit_kernel(self, results): op=result, device=self.device, workdir=self.workdir, + enable_tf32=self.enable_tf32, ) # kernel rename with open(self.tit_module_path, "w") as f: diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py index 50373e03d..38b159283 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py @@ -25,13 +25,16 @@ def gemm_rcr_bias( stride_biasn: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr # 'relu' or None + ACTIVATION: tl.constexpr, # 'relu' or None + enable_tf32: tl.constexpr, ): """ Kernel for GEMM RCR (Row-Col-Row) + Bias + ReLU. A (M, K) @ B (N, K)^T + Bias (N) -> C (M, N) B is stored as (N, K) but accessed as if it's (K, N) for the matmul. """ + # _TF32_ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" + # DTYPE: tl.constexpr = k.dtype.element_ty pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -49,8 +52,10 @@ def gemm_rcr_bias( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_am[:, None] < M), other=0.0) b = tl.load(b_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_bn[:, None] < N), other=0.0) - - accumulator += tl.dot(a, tl.trans(b)) + # if enable_tf32 and DTYPE == tl.float32: + # a = tl.inline_asm_elementwise(_TF32_ASM, "=r, r", [a], dtype=tl.float32, is_pure=True, pack=1) + # b = tl.inline_asm_elementwise(_TF32_ASM, "=r, r", [b], dtype=tl.float32, is_pure=True, pack=1) + accumulator += tl.dot(a, tl.trans(b),allow_tf32=enable_tf32) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -84,7 +89,8 @@ def gemm_rcr( stride_cm: tl.constexpr, stride_cn: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr # 'relu' or None + ACTIVATION: tl.constexpr, # 'relu' or None + enable_tf32: tl.constexpr, ): """ Kernel for GEMM RCR (Row-Col-Row) + ReLU. diff --git a/external/TritonTemplate/python/tritontemplate/compiler/compiler.py b/external/TritonTemplate/python/tritontemplate/compiler/compiler.py index f97477306..5a4726283 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/compiler.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/compiler.py @@ -11,10 +11,11 @@ def compile_kernel( op: compiler.base.Operation, device: str='cuda', workdir: str='./workshop', + enable_tf32: bool=False )->TritonExecutor: try: _ = importlib.import_module(f'tritontemplate.backend.{device}') except ModuleNotFoundError: raise ModuleNotFoundError(f'Target {device} not found') - return op.compile(device, workdir) + return op.compile(device, workdir,enable_tf32) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index e2d9a62f0..506f50ec0 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -23,7 +23,6 @@ def __init__( inputs: List[Tensor], layout: str, is_bias: bool = False, - is_transpose: bool = False, outputs: Optional[List[Tensor]] = None, activation: Optional[str] = None, name: Optional[str] = None, @@ -34,7 +33,6 @@ def __init__( super().__init__(inputs, outputs, name) self.layout = layout self.is_bias= is_bias - self._attrs['is_transpose'] = is_transpose self._attrs['activation'] = activation self._attrs['inputs'] = inputs self._attrs['outputs'] = outputs if outputs is not None else self._induce_output_shape() @@ -42,7 +40,6 @@ def __init__( def _induce_output_shape(self): # TODO: support transpose, by swap A,B - assert not self._attrs['is_transpose'], 'transpose not supported' if self.layout == 'rcr': M,N,K = self._attrs['inputs'][0].shape[0],self._attrs['inputs'][1].shape[0],self._attrs['inputs'][0].shape[1] elif self.layout == 'rrr': @@ -71,9 +68,17 @@ def _gen_signature_divisiability(self): return signature_metadata,divisiability - def _gen_constants(self): + def _gen_constants(self,enable_tf32): const_metadata={} const_metadata['ACTIVATION'] = self._attrs['activation'] + + any_float32=False + for input in self._attrs['inputs']: + if input.dtype == 'float32': + any_float32=True + break + + const_metadata['enable_tf32'] = True if (enable_tf32 and any_float32) else False if self.layout == 'rcr': input=self._attrs['inputs'] M,N,K=input[0].shape[0],input[1].shape[0],input[0].shape[1] @@ -99,13 +104,14 @@ def _gen_constants(self): def _gen_exec_metadata(self): return _exec_metadata.copy() - def compile(self,target_name,workdir)->TritonExecutor: + #TODO:enable_tf32 https://github.com/triton-lang/triton/issues/4574 + def compile(self,target_name,workdir,enable_tf32: bool = False,)->TritonExecutor: triton_kernel_name=f'gemm_{self.layout}'+ ('' if not self.is_bias else '_bias') triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),triton_kernel_name) gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm_{self.layout}') signature,divisiability=self._gen_signature_divisiability() - constants=self._gen_constants() + constants=self._gen_constants(enable_tf32) exec_metadata=self._gen_exec_metadata() num_warps=exec_metadata['num_warps'] diff --git a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py index 6175b2c91..1f6035469 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py +++ b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py @@ -8,17 +8,16 @@ from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr_bias as gemm_rcr_bias_kernel def gen_gemm_rcr_bias_relu(M, N, K): - A = Tensor(name='A', dtype='float16', shape=[IntImm(M), IntImm(K)]) - B = Tensor(name='B', dtype='float16', shape=[IntImm(N), IntImm(K)]) - Bias = Tensor(name='Bias', dtype='float16', shape=[IntImm(N)]) - C = Tensor(name='C', dtype='float16', shape=[IntImm(M), IntImm(N)]) + A = Tensor(name='A', dtype='float32', shape=[IntImm(M), IntImm(K)]) + B = Tensor(name='B', dtype='float32', shape=[IntImm(N), IntImm(K)]) + Bias = Tensor(name='Bias', dtype='float32', shape=[IntImm(N)]) + C = Tensor(name='C', dtype='float32', shape=[IntImm(M), IntImm(N)]) gemm_op = Gemm( inputs=[A, B, Bias], outputs=[C], layout='rcr', is_bias=True, - is_transpose=False, activation='relu', ) @@ -51,22 +50,24 @@ def gemm_rcr_bias_relu(a, b, bias): ], ) def test_gemm_rcr_bias_relu(M, N, K): - A = torch.randn((M, K), dtype=torch.float16, device='cuda') - B = torch.randn((N, K), dtype=torch.float16, device='cuda') - Bias = torch.randn((N,), dtype=torch.float16, device='cuda') + torch.backends.cuda.matmul.allow_tf32=False + A = torch.randn((M, K), dtype=torch.float32, device='cuda') + B = torch.randn((N, K), dtype=torch.float32, device='cuda') + Bias = torch.randn((N,), dtype=torch.float32, device='cuda') # Triton and PyTorch outputs - c_triton = torch.zeros((M, N), device=A.device, dtype=A.dtype) + c_triton = torch.empty((M, N), device=A.device, dtype=A.dtype) grid = lambda META: ( triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) triton_aot=gemm_rcr_bias_relu(A,B,Bias) - gemm_rcr_bias_kernel[grid](A,B,Bias,c_triton, M, N, K,A.stride(0),A.stride(1),B.stride(0),B.stride(1),c_triton.stride(0),c_triton.stride(1),Bias.stride(0),128,128,128,'relu') + gemm_rcr_bias_kernel[grid](A,B,Bias,c_triton, M, N, K,A.stride(0),A.stride(1),B.stride(0),B.stride(1),c_triton.stride(0),c_triton.stride(1),Bias.stride(0),64,64,64,'relu',enable_tf32=False) assert torch.allclose(c_triton, triton_aot, atol=1e-2, rtol=1e-2), \ - f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" + f"Outputs mismatch between aot and jit for M={M}, N={N}, K={K}\n" c=torch.nn.functional.relu(torch.nn.functional.linear(A,B,bias=Bias)) assert torch.allclose(c, triton_aot, atol=1e-2, rtol=1e-2), \ f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" - +if __name__ == '__main__': + test_gemm_rcr_bias_relu(128, 256, 512) \ No newline at end of file diff --git a/frontends/torch-frontend/examples/inference/tit_mlp.py b/frontends/torch-frontend/examples/inference/tit_mlp.py index 404effdbd..79467e543 100644 --- a/frontends/torch-frontend/examples/inference/tit_mlp.py +++ b/frontends/torch-frontend/examples/inference/tit_mlp.py @@ -10,9 +10,9 @@ class MLP(nn.Module): def __init__(self): super().__init__() - self.linear1 = nn.Linear(256, 512,dtype=torch.float16) - self.linear2 = nn.Linear(512, 256,dtype=torch.float16) - self.linear3 = nn.Linear(256, 128,dtype=torch.float16) + self.linear1 = nn.Linear(256, 512,dtype=torch.float32) + self.linear2 = nn.Linear(512, 256,dtype=torch.float32) + self.linear3 = nn.Linear(256, 128,dtype=torch.float32) def forward(self, x): x = self.linear1(x) @@ -25,8 +25,9 @@ def forward(self, x): workspace = "./workspace" os.makedirs(workspace, exist_ok=True) with torch.no_grad(): + torch.backends.cuda.matmul.allow_tf32=False model = MLP().cuda().eval() - inputs = [torch.randn(128, 256, dtype=torch.float16).cuda()] + inputs = [torch.randn(128, 256, dtype=torch.float32).cuda()] traced_model = torch.jit.trace(model, inputs) stablehlo_file = workspace + "/model.stablehlo.mlir" From 0e7ae32d7a01aab645f1c39d0fc09292a980da78 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Tue, 24 Jun 2025 13:06:39 +0800 Subject: [PATCH 14/26] [feat] Add size under 32 support --- .../tritontemplate/compiler/ops/gemm/gemm.py | 16 ++++-- .../tritontemplate/testing/test_matmul.py | 53 +++++++++++-------- 2 files changed, 42 insertions(+), 27 deletions(-) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 506f50ec0..7f1647081 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -67,7 +67,15 @@ def _gen_signature_divisiability(self): raise NotImplementedError(f'input {input} not supported') return signature_metadata,divisiability - + + @staticmethod + def _block_size(x): + if x>=128: + return 128 + elif x<=32: + return 32 + return triton.next_power_of_2(x) + def _gen_constants(self,enable_tf32): const_metadata={} const_metadata['ACTIVATION'] = self._attrs['activation'] @@ -96,9 +104,9 @@ def _gen_constants(self,enable_tf32): else: raise NotImplementedError(f'layout {self.layout} not supported') - const_metadata['BLOCK_SIZE_M']= 128 if M>=128 else triton.next_power_of_2(M) - const_metadata['BLOCK_SIZE_N']= 128 if N>=128 else triton.next_power_of_2(N) - const_metadata['BLOCK_SIZE_K']= 128 if K>=128 else triton.next_power_of_2(K) + const_metadata['BLOCK_SIZE_M']= self._block_size(M) + const_metadata['BLOCK_SIZE_N']= self._block_size(N) + const_metadata['BLOCK_SIZE_K']= self._block_size(K) return const_metadata def _gen_exec_metadata(self): diff --git a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py index 1f6035469..e681a512e 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py +++ b/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py @@ -7,11 +7,11 @@ from tritontemplate.compiler.compiler import compile_kernel from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr_bias as gemm_rcr_bias_kernel -def gen_gemm_rcr_bias_relu(M, N, K): - A = Tensor(name='A', dtype='float32', shape=[IntImm(M), IntImm(K)]) - B = Tensor(name='B', dtype='float32', shape=[IntImm(N), IntImm(K)]) - Bias = Tensor(name='Bias', dtype='float32', shape=[IntImm(N)]) - C = Tensor(name='C', dtype='float32', shape=[IntImm(M), IntImm(N)]) +def gen_gemm_rcr_bias_relu(M, N, K, stype): + A = Tensor(name='A', dtype=stype, shape=[IntImm(M), IntImm(K)]) + B = Tensor(name='B', dtype=stype, shape=[IntImm(N), IntImm(K)]) + Bias = Tensor(name='Bias', dtype=stype, shape=[IntImm(N)]) + C = Tensor(name='C', dtype=stype, shape=[IntImm(M), IntImm(N)]) gemm_op = Gemm( inputs=[A, B, Bias], @@ -24,7 +24,7 @@ def gen_gemm_rcr_bias_relu(M, N, K): kernel = compile_kernel(gemm_op, device='cuda') return kernel -def gemm_rcr_bias_relu(a, b, bias): +def gemm_rcr_bias_relu(a, b, bias,stype): M, K = a.shape N, K_b = b.shape @@ -34,33 +34,43 @@ def gemm_rcr_bias_relu(a, b, bias): b=b.contiguous() bias=bias.contiguous() c=c.contiguous() - kernel = gen_gemm_rcr_bias_relu(M, N, K) + kernel = gen_gemm_rcr_bias_relu(M, N, K,stype) kernel(a, b, bias, c) return c @pytest.mark.parametrize( - 'M, N, K', + 'M, N, K, stype', [ - (128, 256,512), - (128, 512, 256), - (128, 256, 256), - (128, 512, 512), - (256,128,256), - (256,256,256), + (2, 128, 31,'float32'), + (128,2,31,'float16'), + (128,128,31,'float32'), + (31,128,2,'float16'), + (129,128,128,'float32'), + (128, 257,512,'float16'), + (128, 512, 257,'float32'), + (127, 256, 256,'float16'), + (128, 511, 512,'float32'), + (256,128,255,'float16'), + (1,256,256,'float32'), ], ) -def test_gemm_rcr_bias_relu(M, N, K): - torch.backends.cuda.matmul.allow_tf32=False - A = torch.randn((M, K), dtype=torch.float32, device='cuda') - B = torch.randn((N, K), dtype=torch.float32, device='cuda') - Bias = torch.randn((N,), dtype=torch.float32, device='cuda') +def test_gemm_rcr_bias_relu(M, N, K, stype): + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32=False + dtype=torch.float32 + else: + dtype=torch.float16 + + A = torch.randn((M, K), dtype=dtype, device='cuda') + B = torch.randn((N, K), dtype=dtype, device='cuda') + Bias = torch.randn((N,), dtype=dtype, device='cuda') # Triton and PyTorch outputs c_triton = torch.empty((M, N), device=A.device, dtype=A.dtype) grid = lambda META: ( triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - triton_aot=gemm_rcr_bias_relu(A,B,Bias) + triton_aot=gemm_rcr_bias_relu(A,B,Bias,stype) gemm_rcr_bias_kernel[grid](A,B,Bias,c_triton, M, N, K,A.stride(0),A.stride(1),B.stride(0),B.stride(1),c_triton.stride(0),c_triton.stride(1),Bias.stride(0),64,64,64,'relu',enable_tf32=False) assert torch.allclose(c_triton, triton_aot, atol=1e-2, rtol=1e-2), \ @@ -68,6 +78,3 @@ def test_gemm_rcr_bias_relu(M, N, K): c=torch.nn.functional.relu(torch.nn.functional.linear(A,B,bias=Bias)) assert torch.allclose(c, triton_aot, atol=1e-2, rtol=1e-2), \ f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" - -if __name__ == '__main__': - test_gemm_rcr_bias_relu(128, 256, 512) \ No newline at end of file From 1658f0e381cf1d5c9e364fd21a11ea901f625870 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Tue, 24 Jun 2025 22:02:11 +0800 Subject: [PATCH 15/26] [format] after clang-format --- compiler/include/byteir/Conversion/Passes.h | 2 +- .../include/byteir/Conversion/ToTIT/ToTIT.h | 3 +- .../lib/Conversion/ToTIT/GenTITConfig.cpp | 76 ++++++++----------- .../backends/cuda/device/cuda_work_queue.cc | 23 +++--- .../cuda/providers/default/codegen/ptx.cc | 11 +-- 5 files changed, 51 insertions(+), 64 deletions(-) diff --git a/compiler/include/byteir/Conversion/Passes.h b/compiler/include/byteir/Conversion/Passes.h index e3d3f11cf..2d29ef4ea 100644 --- a/compiler/include/byteir/Conversion/Passes.h +++ b/compiler/include/byteir/Conversion/Passes.h @@ -28,7 +28,6 @@ #include "byteir/Conversion/LcclToByre/LcclToByre.h" #include "byteir/Conversion/MemrefToByre/MemrefToByre.h" #include "byteir/Conversion/ToAIT/ToAIT.h" -#include "byteir/Conversion/ToTIT/ToTIT.h" #include "byteir/Conversion/ToAce/MhloToAce.h" #include "byteir/Conversion/ToByre/ToByre.h" #include "byteir/Conversion/ToGPU/ToGPU.h" @@ -36,6 +35,7 @@ #include "byteir/Conversion/ToLLVM/ToLLVM.h" #include "byteir/Conversion/ToLinalg/ToLinalg.h" #include "byteir/Conversion/ToPTX/ToPTX.h" +#include "byteir/Conversion/ToTIT/ToTIT.h" namespace mlir { diff --git a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h index 61998abac..308023a26 100644 --- a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h +++ b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h @@ -36,8 +36,7 @@ createGenTITConfigPass(ArrayRef funcNames = {""}, ArrayRef gridsizeZArgs = {""}, ArrayRef blocksizeXArgs = {""}, ArrayRef blocksizeYArgs = {""}, - ArrayRef blocksizeZArgs = {""} - ); + ArrayRef blocksizeZArgs = {""}); } // namespace mlir diff --git a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp index a331c6e37..93c793da1 100644 --- a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp +++ b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp @@ -31,25 +31,20 @@ using namespace mlir; namespace { static LogicalResult AttachTITConfigToAttr( - func::FuncOp func, - const std::string &titPtxPath, - const std::string &smemsizeArg, - const std::string &gridsizeXArg, - const std::string &gridsizeYArg, - const std::string &gridsizeZArg, - const std::string &blocksizeXArg, - const std::string &blocksizeYArg, + func::FuncOp func, const std::string &titPtxPath, + const std::string &smemsizeArg, const std::string &gridsizeXArg, + const std::string &gridsizeYArg, const std::string &gridsizeZArg, + const std::string &blocksizeXArg, const std::string &blocksizeYArg, const std::string &blocksizeZArg) { - - + std::string device_name; std::string byreKernelName; if (titPtxPath.find(".ptx") != std::string::npos) { device_name = "cuda"; - byreKernelName="PTXOp"; + byreKernelName = "PTXOp"; } - if (device_name.empty()|| byreKernelName.empty()) { + if (device_name.empty() || byreKernelName.empty()) { return func.emitError("Invalid device type for TIT configuration"); } addGenericFuncAttrs(func, byreKernelName); @@ -62,14 +57,10 @@ static LogicalResult AttachTITConfigToAttr( titConfig["device"] = opBuilder.getStringAttr(device_name); titConfig["device_file_name"] = opBuilder.getStringAttr(titPtxPath); - llvm::StringMap gpuLaunchArgs = { - {"SharedMemorySize", smemsizeArg}, - {"BlockSize.x", blocksizeXArg}, - {"BlockSize.y", blocksizeYArg}, - {"BlockSize.z", blocksizeZArg}, - {"GridSize.x", gridsizeXArg}, - {"GridSize.y", gridsizeYArg}, + {"SharedMemorySize", smemsizeArg}, {"BlockSize.x", blocksizeXArg}, + {"BlockSize.y", blocksizeYArg}, {"BlockSize.z", blocksizeZArg}, + {"GridSize.x", gridsizeXArg}, {"GridSize.y", gridsizeYArg}, {"GridSize.z", gridsizeZArg}}; for (auto &kv : gpuLaunchArgs) { @@ -91,17 +82,15 @@ static LogicalResult AttachTITConfigToAttr( } struct GenTITConfigPass : public GenTITConfigBase { - GenTITConfigPass( - ArrayRef funcNames, - ArrayRef titPtxPaths, - ArrayRef smemsizeArgs, - ArrayRef gridsizeXArgs, - ArrayRef gridsizeYArgs, - ArrayRef gridsizeZArgs, - ArrayRef blocksizeXArgs, - ArrayRef blocksizeYArgs, - ArrayRef blocksizeZArgs - ) + GenTITConfigPass(ArrayRef funcNames, + ArrayRef titPtxPaths, + ArrayRef smemsizeArgs, + ArrayRef gridsizeXArgs, + ArrayRef gridsizeYArgs, + ArrayRef gridsizeZArgs, + ArrayRef blocksizeXArgs, + ArrayRef blocksizeYArgs, + ArrayRef blocksizeZArgs) : GenTITConfigBase() { this->funcNames = funcNames; this->titPtxPaths = titPtxPaths; @@ -120,8 +109,9 @@ struct GenTITConfigPass : public GenTITConfigBase { return; for (size_t i = 0; i < funcNames.size(); ++i) if (func.getSymName() == funcNames[i]) { - if (failed(AttachTITConfigToAttr(func, titPtxPaths[i], smemsizeArgs[i], gridsizeXArgs[i], - gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i], + if (failed(AttachTITConfigToAttr( + func, titPtxPaths[i], smemsizeArgs[i], gridsizeXArgs[i], + gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i], blocksizeYArgs[i], blocksizeZArgs[i]))) { return signalPassFailure(); } @@ -131,17 +121,13 @@ struct GenTITConfigPass : public GenTITConfigBase { } // namespace -std::unique_ptr> -mlir::createGenTITConfigPass( - ArrayRef funcNames, - ArrayRef titPtxPaths, - ArrayRef smemsizeArgs, - ArrayRef gridsizeXArgs, - ArrayRef gridsizeYArgs, - ArrayRef gridsizeZArgs, - ArrayRef blocksizeXArgs, - ArrayRef blocksizeYArgs, - ArrayRef blocksizeZArgs -) { - return std::make_unique(funcNames, titPtxPaths,smemsizeArgs, gridsizeXArgs, gridsizeYArgs, gridsizeZArgs, blocksizeXArgs, blocksizeYArgs, blocksizeZArgs); +std::unique_ptr> mlir::createGenTITConfigPass( + ArrayRef funcNames, ArrayRef titPtxPaths, + ArrayRef smemsizeArgs, ArrayRef gridsizeXArgs, + ArrayRef gridsizeYArgs, ArrayRef gridsizeZArgs, + ArrayRef blocksizeXArgs, ArrayRef blocksizeYArgs, + ArrayRef blocksizeZArgs) { + return std::make_unique( + funcNames, titPtxPaths, smemsizeArgs, gridsizeXArgs, gridsizeYArgs, + gridsizeZArgs, blocksizeXArgs, blocksizeYArgs, blocksizeZArgs); } diff --git a/runtime/lib/backends/cuda/device/cuda_work_queue.cc b/runtime/lib/backends/cuda/device/cuda_work_queue.cc index d119a5fd7..35ec971f1 100644 --- a/runtime/lib/backends/cuda/device/cuda_work_queue.cc +++ b/runtime/lib/backends/cuda/device/cuda_work_queue.cc @@ -72,24 +72,25 @@ inline common::Status ComputeDrv(const void *func, void **args, dim3 *grid = static_cast(args[0]); dim3 *block = static_cast(args[1]); size_t *shared_size = static_cast(args[2]); - CUfunction hFunc=reinterpret_cast(const_cast(func)); + CUfunction hFunc = reinterpret_cast(const_cast(func)); - //extend the shared memory + // extend the shared memory int shared_optin; - int device_id=-1; + int device_id = -1; BRT_CUDA_CHECK(cudaGetDevice(&device_id)); - BRT_CUDA_CHECK(cudaDeviceGetAttribute(&shared_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); - - if (shared_optin>49152&&(*shared_size)>49152){ + BRT_CUDA_CHECK(cudaDeviceGetAttribute( + &shared_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); + + if (shared_optin > 49152 && (*shared_size) > 49152) { BRT_CU_CHECK(cuFuncSetCacheConfig(hFunc, CU_FUNC_CACHE_PREFER_SHARED)); - BRT_CU_CHECK(cuFuncSetAttribute(hFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)); + BRT_CU_CHECK(cuFuncSetAttribute( + hFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin)); } void **kernel_args = args + 3; - return BRT_CU_CALL( - cuLaunchKernel(hFunc, - (*grid).x, (*grid).y, (*grid).z, (*block).x, (*block).y, - (*block).z, *shared_size, stream, kernel_args, 0)); + return BRT_CU_CALL(cuLaunchKernel(hFunc, (*grid).x, (*grid).y, (*grid).z, + (*block).x, (*block).y, (*block).z, + *shared_size, stream, kernel_args, 0)); } inline common::Status ComputeHost(const void *func, void **args, diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc index 25256d620..dad8689bf 100644 --- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc +++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc @@ -139,7 +139,7 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr"); } - size_t shared_size=0; + size_t shared_size = 0; int gx = static_cast(info.GetOperation() ->getAttrOfType(GRID_SIZE_X_ATTR) .getInt()), @@ -169,10 +169,11 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info) ->getAttrOfType(BLOCK_SIZE_Z_ATTR) .getInt()); } - if(info.GetOperation()->hasAttrOfType(SHARED_SIZE_ATTR)){ - shared_size=static_cast(info.GetOperation() - ->getAttrOfType(SHARED_SIZE_ATTR) - .getInt()); + if (info.GetOperation()->hasAttrOfType(SHARED_SIZE_ATTR)) { + shared_size = + static_cast(info.GetOperation() + ->getAttrOfType(SHARED_SIZE_ATTR) + .getInt()); } std::vector ranks; From add8f1d5f347dbdd2568960c277b65009223b264 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 27 Jun 2025 09:35:03 +0800 Subject: [PATCH 16/26] [fix] clear files --- .../tritontemplate/compiler/stable_set.py | 100 ------------ .../tritontemplate/compiler/symbolic.py | 154 ------------------ 2 files changed, 254 deletions(-) delete mode 100644 external/TritonTemplate/python/tritontemplate/compiler/stable_set.py delete mode 100644 external/TritonTemplate/python/tritontemplate/compiler/symbolic.py diff --git a/external/TritonTemplate/python/tritontemplate/compiler/stable_set.py b/external/TritonTemplate/python/tritontemplate/compiler/stable_set.py deleted file mode 100644 index 82f945078..000000000 --- a/external/TritonTemplate/python/tritontemplate/compiler/stable_set.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -A stable set is like a Python set which produces deterministic results. -It also tries to preserve the original element order as much as possible, which could -potentially make debugging (e.g. comparison with the original graph, comparison between -AIT GPU trace and other GPU traces) easier. -""" -from collections import abc -from typing import Any, Iterable - - -class StableSet(abc.MutableSet): - def __init__(self, s: Iterable[Any] = None): - if s is None: - s = [] - self._d = {item: None for item in s} - - def add(self, value) -> None: - self._d[value] = None - - def update(self, other) -> None: - for item in other: - self._d[item] = None - - def discard(self, value) -> None: - self._d.pop(value, None) - - def remove(self, value) -> None: - self._d.pop(value) - - def copy(self): - return StableSet(list(self._d)) - - def clear(self): - self._d = {} - - def __sub__(self, other): - res = self.copy() - for item in other: - res.discard(item) - return res - - def __str__(self) -> str: - return str(list(self._d)) - - def __repr__(self) -> str: - return str(list(self._d)) - - def __len__(self) -> int: - return len(self._d) - - def __contains__(self, value: Any) -> int: - return value in self._d - - def __iter__(self): - return list(self._d).__iter__() - - def _type_check(self, other): - if not isinstance(other, StableSet): - raise RuntimeError( - f"A StableSet can only be operated with another StableSet! " - f"Current type: {type(other)}." - ) - - def __eq__(self, other): - self._type_check(other) - return set(other._d) == set(self._d) - - def __le__(self, other): - self._type_check(other) - return set(self._d) <= set(other._d) - - def __lt__(self, other): - self._type_check(other) - return set(self._d) < set(other._d) - - def __ge__(self, other): - self._type_check(other) - return set(self._d) >= set(other._d) - - def __gt__(self, other): - self._type_check(other) - return set(self._d) > set(other._d) - - def __getitem__(self, idx): - return list(self._d)[idx] diff --git a/external/TritonTemplate/python/tritontemplate/compiler/symbolic.py b/external/TritonTemplate/python/tritontemplate/compiler/symbolic.py deleted file mode 100644 index 2edcfb195..000000000 --- a/external/TritonTemplate/python/tritontemplate/compiler/symbolic.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -""" -Symbolic helpers for AITemplate. -AITemplate leverages Sympy to do symbolic computations for shapes. -The core of Sympy is surrounded around the class "Symbol". We could apply operations -on Symbols (i.e. add/mul/power/etc.) Which could help us do basic arithmetic with -unknown values. -The symbolic-ness comes from representation that includes Symbol (i.e. sym_1 + 100.) - -Example Usage: -A = IntVar(...) -sym_A = A.symbolic_value() # equivalent of A._attrs["symbolic_value"] - -# do something about sym_A, some common usage include: -new_sym = sym_A + 100 -new_sym = sym_A - 100 -new_sym = sym_A * 2 -new_sym = sym_A * sym_B - -# We could then assign the symbolic value to a new IntVar. -new_var = IntVar(..., symbolic_value=new_sym) - -For more advanced usage on Sympy, check: https://docs.sympy.org/latest/tutorials/intro-tutorial/intro.html -""" -from __future__ import annotations - -import itertools - -from numbers import Number -from typing import Any, List, Optional, Set - -import sympy - - -_k_symbolic_to_intvar = {} -_k_symbolic_index = 0 -_k_symbolic_value = {} - - -def create_new_symbol( - name: Optional[str] = None, - values: Optional[List[int]] = None, - check_duplicate: bool = False, -) -> sympy.Symbol: - """ - Creates and memoizing symbols. - - Parameters - ---------- - name : Optional[str] - The symbol name that is going to be used. If None is provided, an unused - name would be created. - values : Optional[List[int]] - The values for IntVar, which indicates the range of which the symbol could - represent. - check_duplicate : bool - If set as True and name is provided, we check whether the name and values - provided matches the corresponding symbol recorded. - """ - global _k_symbolic_index - global _k_symbolic_value - - if name is None: - while True: - name = f"_sym_{_k_symbolic_index}" - _k_symbolic_index += 1 - - if name not in _k_symbolic_value: - break - - values = sorted(set(values)) if values is not None else values - if ( - check_duplicate - and name in _k_symbolic_value - and _k_symbolic_value[name] != values - ): - raise ValueError( - f"Symbol ({name}) has different values! New value is {values}, stored value is {_k_symbolic_value[name]}" - ) - - _k_symbolic_value[name] = values - return sympy.Symbol(name) - - -def is_symbol(sym_val: Any) -> bool: - return isinstance(sym_val, sympy.Symbol) - - -def is_symbolic(sym_val: Any) -> bool: - """ - Check whether sym_val is a sympy class. - """ - return isinstance(sym_val, sympy.Basic) - - -def is_integer(sym_val: Any) -> bool: - # We wrap this since None is returned if sympy can't determine the property. - if is_symbolic(sym_val): - return sym_val.is_number and int(sym_val) - sym_val == 0 - elif isinstance(sym_val, Number): - return int(sym_val) - sym_val == 0 - - return False - - -def get_global_symbol_set() -> Set: - global _k_symbolic_value - return set(_k_symbolic_value.keys()) - - -def get_intvar(sym_name: str): - global _k_symbolic_to_intvar - - return _k_symbolic_to_intvar.get(sym_name, None) - - -def store_intvar(sym_name: str, int_var) -> None: - global _k_symbolic_to_intvar - - _k_symbolic_to_intvar[sym_name] = int_var - - -def simplify_intvar_values(sym_val: sympy.Basic): - """ - Given a symbolic value, resolve the symbol's value range. - - Example: - 'symbol_A' has value range of [10, 20] - simplify_intvar_values(symbol_A * 3 + 4) returns [34, 64] - """ - global _k_symbolic_value - - symbols = list(sym_val.free_symbols) - symbol_shapes = [_k_symbolic_value[s.name] for s in symbols] - symbol_shapes = [s for s in symbol_shapes if s is not None] - shape_perms = list(itertools.product(*symbol_shapes)) - - new_shape = [int(sym_val.subs(zip(symbols, s))) for s in shape_perms] - new_shape = sorted(set(new_shape)) - - return new_shape From 4abc79686df4eaf94f4f72fc2048da57fbfdeef8 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Mon, 7 Jul 2025 15:14:40 +0800 Subject: [PATCH 17/26] [feat] gemm rrr add --- .../backend/cuda/gemm/__init__.py | 3 +- .../backend/cuda/gemm/gemm_rcr.py | 9 +- .../backend/cuda/gemm/gemm_rrr.py | 131 ++++++++++++++++++ .../backend/cuda/utils/__init__.py | 0 .../backend/cuda/utils/activation.py | 9 ++ .../tritontemplate/compiler/ops/gemm/gemm.py | 16 ++- .../{test_matmul.py => cuda/test_gemm_rcr.py} | 0 .../testing/cuda/test_gemm_rrr.py | 106 ++++++++++++++ 8 files changed, 269 insertions(+), 5 deletions(-) create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/utils/activation.py rename external/TritonTemplate/python/tritontemplate/testing/{test_matmul.py => cuda/test_gemm_rcr.py} (100%) create mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py index 8bf9daa68..6711c64cc 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py @@ -1 +1,2 @@ -from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr,gemm_rcr_bias,gen_grid_gemm_rcr \ No newline at end of file +from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr,gemm_rcr_bias,gen_grid_gemm_rcr +from tritontemplate.backend.cuda.gemm.gemm_rrr import gemm_rrr,gemm_rrr_bias,gen_grid_gemm_rrr diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py index 38b159283..b55ab29a0 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py @@ -1,8 +1,11 @@ import triton import triton.language as tl +from tritontemplate.backend.cuda.utils.activation import * def gen_grid_gemm_rcr(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): - + """ + Generates the grid for a GEMM kernel. + """ return ( triton.cdiv(M, BLOCK_SIZE_M)*triton.cdiv(N, BLOCK_SIZE_N),1,1) @@ -67,7 +70,7 @@ def gemm_rcr_bias( accumulator = accumulator + bias_vals[None, :] if ACTIVATION == 'relu': - accumulator = tl.maximum(accumulator, 0) + accumulator = relu(accumulator) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -121,7 +124,7 @@ def gemm_rcr( b_ptrs += BLOCK_SIZE_K * stride_bk if ACTIVATION == 'relu': - accumulator = tl.maximum(accumulator, 0) + accumulator = relu(accumulator) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py index e69de29bb..1a53edfc3 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py @@ -0,0 +1,131 @@ +import triton +import triton.language as tl +from tritontemplate.backend.cuda.utils.activation import * + +def gen_grid_gemm_rrr(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + """ + Generates the grid for a GEMM kernel. + """ + return (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1, 1) + +smem_demand_per_stage = { + 'gemm_rrr_bias': 128 * 128 * 2, + 'gemm_rrr': 128 * 128 * 2, +} + +@triton.jit +def gemm_rrr_bias( + # Pointers to matrices + a_ptr, b_ptr, bias_ptr, c_ptr, + # Matrix dimensions + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides for matrices + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bk: tl.constexpr, stride_bn: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + stride_biasn: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ACTIVATION: tl.constexpr, # 'relu' or None + enable_tf32: tl.constexpr, +): + """ + Kernel for GEMM RRR (Row-Row-Row) + Bias + ReLU. + A (M, K) @ B (K, N) + Bias (N) -> C (M, N) + B is stored as (N, K) but accessed as if it's (K, N) for the matmul. + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_SIZE_K < K), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) + + accumulator += tl.dot(a, b, allow_tf32=enable_tf32) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + + if bias_ptr is not None: + offs_bias_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + bias_ptrs = bias_ptr + offs_bias_n * stride_biasn + bias_vals = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + + accumulator = accumulator + bias_vals[None, :] + + + if ACTIVATION == 'relu': + accumulator = relu(accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def gemm_rrr( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides for matrices + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bk: tl.constexpr, stride_bn: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ACTIVATION: tl.constexpr, # 'relu' or None + enable_tf32: tl.constexpr, +): + """ + Kernel for GEMM RRR (Row-Row-Row) + ReLU. + A (M, K) @ B (K,N) -> C (M, N) + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in tl.static_range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_SIZE_K < K), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) + + accumulator += tl.dot(a, b, allow_tf32=enable_tf32) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if ACTIVATION == 'relu': + accumulator = relu(accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/activation.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/activation.py new file mode 100644 index 000000000..44ecea8ff --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/activation.py @@ -0,0 +1,9 @@ +import triton +import triton.language as tl + +__all__ = ['relu'] + +@triton.jit +def relu(x): + return tl.maximum(x, 0) + diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 7f1647081..1361d2e78 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -39,7 +39,6 @@ def __init__( def _induce_output_shape(self): - # TODO: support transpose, by swap A,B if self.layout == 'rcr': M,N,K = self._attrs['inputs'][0].shape[0],self._attrs['inputs'][1].shape[0],self._attrs['inputs'][0].shape[1] elif self.layout == 'rrr': @@ -101,6 +100,21 @@ def _gen_constants(self,enable_tf32): const_metadata['stride_cn']=1 if self.is_bias: const_metadata['stride_biasn']=1 + elif self.layout == 'rrr': + input=self._attrs['inputs'] + M,K,N=input[0].shape[0],input[1].shape[0],input[1].shape[1] + print(M,K,N) + const_metadata['M']=M + const_metadata['N']=N + const_metadata['K']=K + const_metadata['stride_am']=K + const_metadata['stride_ak']=1 + const_metadata['stride_bk']=N + const_metadata['stride_bn']=1 + const_metadata['stride_cm']=N + const_metadata['stride_cn']=1 + if self.is_bias: + const_metadata['stride_biasn']=1 else: raise NotImplementedError(f'layout {self.layout} not supported') diff --git a/external/TritonTemplate/python/tritontemplate/testing/test_matmul.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rcr.py similarity index 100% rename from external/TritonTemplate/python/tritontemplate/testing/test_matmul.py rename to external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rcr.py diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py new file mode 100644 index 000000000..9a16f2a14 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py @@ -0,0 +1,106 @@ +import torch +import pytest +import triton + +from tritontemplate.backend.cuda.gemm.gemm_rrr import gemm_rrr_bias as gemm_rrr_bias_kernel +from tritontemplate.compiler.base import IntImm, Tensor +from tritontemplate.compiler.ops.gemm import Gemm +from tritontemplate.compiler.compiler import compile_kernel + +def gen_gemm_rrr_bias_relu(M, N, K, stype): + """ + Generates an AOT (Ahead-of-Time) compiled kernel for GEMM RRR + Bias + ReLU. + """ + A = Tensor(name='A', dtype=stype, shape=[IntImm(M), IntImm(K)]) + + B = Tensor(name='B', dtype=stype, shape=[IntImm(K), IntImm(N)]) + Bias = Tensor(name='Bias', dtype=stype, shape=[IntImm(N)]) + C = Tensor(name='C', dtype=stype, shape=[IntImm(M), IntImm(N)]) + + gemm_op = Gemm( + inputs=[A, B, Bias], + outputs=[C], + layout='rrr', + is_bias=True, + activation='relu', + ) + + kernel = compile_kernel(gemm_op, device='cuda') + return kernel + +def gemm_rrr_bias_relu_aot(a, b, bias, stype): + """ + Wrapper function to execute the AOT compiled kernel. + """ + M, K = a.shape + K_b, N = b.shape + assert K == K_b, "K dimension mismatch between A and B" + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + a = a.contiguous() + b = b.contiguous() + bias = bias.contiguous() + c = c.contiguous() + kernel = gen_gemm_rrr_bias_relu(M, N, K, stype) + kernel(a, b, bias, c) + return c + +@pytest.mark.parametrize( + 'M, N, K, stype', + [ + (2, 128, 31, 'float32'), + (128, 2, 31, 'float16'), + (128, 128, 31, 'float32'), + (31, 128, 2, 'float16'), + (129, 128, 128, 'float32'), + (128, 257, 512, 'float16'), + (128, 512, 257, 'float32'), + (127, 256, 256, 'float16'), + (128, 511, 512, 'float32'), + (256, 128, 255, 'float16'), + (1, 256, 256, 'float32'), + ], +) +def test_gemm_rrr_bias_relu(M, N, K, stype): + """ + Tests the RRR GEMM kernel against a reference PyTorch implementation and an AOT compiled version. + """ + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + + + A = torch.randn((M, K), dtype=dtype, device='cuda') + B = torch.randn((K, N), dtype=dtype, device='cuda') + Bias = torch.randn((N,), dtype=dtype, device='cuda') + + + triton_aot_result = gemm_rrr_bias_relu_aot(A, B, Bias, stype) + + + c_triton_jit = torch.empty((M, N), device=A.device, dtype=A.dtype) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + + gemm_rrr_bias_kernel[grid]( + A, B, Bias, c_triton_jit, + M, N, K, + A.stride(0), A.stride(1), + B.stride(0), B.stride(1), + c_triton_jit.stride(0), c_triton_jit.stride(1), + Bias.stride(0), + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, + ACTIVATION='relu', enable_tf32=False + ) + + + + pytorch_result = torch.nn.functional.relu(A @ B + Bias) + + assert torch.allclose(c_triton_jit, triton_aot_result, atol=1e-2, rtol=1e-2), \ + f"Outputs mismatch between AOT and JIT for M={M}, N={N}, K={K}\n" + assert torch.allclose(pytorch_result, triton_aot_result, atol=1e-2, rtol=1e-2), \ + f"Outputs mismatch between AOT and PyTorch for M={M}, N={N}, K={K}\n" + From 266f02ad3224886ad7d855572462e30f35a770fb Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 10 Jul 2025 22:06:07 +0800 Subject: [PATCH 18/26] [feat] bmm add --- .../backend/cuda/bmm/__init__.py | 1 + .../tritontemplate/backend/cuda/bmm/bmm.py | 170 ++++++++++++++++++ .../backend/cuda/gemm/gemm_rrr.py | 10 +- .../backend/cuda/utils/utils.py | 8 + .../tritontemplate/compiler/__init__.py | 4 +- .../python/tritontemplate/compiler/base.py | 34 ++++ .../compiler/ops/bmm/__init__.py | 1 + .../tritontemplate/compiler/ops/bmm/bmm.py | 104 +++++++++++ .../tritontemplate/compiler/ops/gemm/gemm.py | 34 +--- .../testing/cuda/test_bmm_add.py | 143 +++++++++++++++ 10 files changed, 471 insertions(+), 38 deletions(-) create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/utils/utils.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm_add.py diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py new file mode 100644 index 000000000..5f8c6587a --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py @@ -0,0 +1 @@ +from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py new file mode 100644 index 000000000..bd94ce5d8 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py @@ -0,0 +1,170 @@ +import triton +import triton.language as tl + +def gen_grid_bmm(batch_size,M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + """ + Generates the grid for a GEMM kernel. + """ + return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1) + +@triton.jit +def bmm_bias( + # Pointers to matrices + a_ptr, b_ptr, bias_ptr, c_ptr, + # Matrix dimensions + BATCH_SIZE: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides for matrices + is_transpose_a: tl.constexpr,stride_a0: tl.constexpr, stride_a1: tl.constexpr, stride_a2: tl.constexpr, + is_transpose_b: tl.constexpr,stride_b0: tl.constexpr, stride_b1: tl.constexpr, stride_b2: tl.constexpr, + stride_bias0: tl.constexpr,stride_bias1: tl.constexpr,stride_bias2: tl.constexpr, + stride_c0: tl.constexpr, stride_c1: tl.constexpr, stride_c2: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + enable_tf32: tl.constexpr, + ): + """ + Kernel for BMM + Bias. + if !is_transpose_a and !is_transpose_b : + A (B, M, K) @ B (B, K, N) + Bias (B, M, N) -> C (B, M, N) + """ + batch_id=tl.program_id(0) + pid=tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M) + offs_bn = pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N) + offs_k = tl.arange(0,BLOCK_SIZE_K) + + if is_transpose_a: + # A(B,K,M) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:] + stride_ak=stride_a1 + else: + # A(B,M,K) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_am[:,None]+stride_a2*offs_k[None,:] + stride_ak=stride_a2 + + if is_transpose_b: + # B(B,N,K) + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:] + stride_bk=stride_b2 + else: + # B(B,K,N) + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_k[:,None]+stride_b2*offs_bn[None,:] + stride_bk=stride_b1 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k C (B, M, N) + """ + batch_id=tl.program_id(0) + pid=tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M) + offs_bn = pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N) + offs_k = tl.arange(0,BLOCK_SIZE_K) + + if is_transpose_a: + # A(B,K,M) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:] + stride_ak=stride_a1 + else: + # A(B,M,K) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_am[:,None]+stride_a2*offs_k[None,:] + stride_ak=stride_a2 + + if is_transpose_b: + # B(B,N,K) + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:] + stride_bk=stride_b2 + else: + # B(B,K,N) + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_k[:,None]+stride_b2*offs_bn[None,:] + stride_bk=stride_b1 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k Tuple[int, ...]: + slen=len(s) + stride=[1]*slen + for i in range(1,slen): + stride[slen-i-1]=stride[slen-i]*s[slen-i] + return tuple(stride) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/__init__.py index aa1659943..def690a0b 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/__init__.py @@ -1,4 +1,4 @@ -from tritontemplate.compiler import base,dtype,op_registry,ops,symbolic +from tritontemplate.compiler import base,dtype,op_registry,ops from tritontemplate.compiler.compiler import compile_kernel -__all__ = ["base", "compile_kernel","dtype","op_registry","ops","symbolic",] \ No newline at end of file +__all__ = ["base", "compile_kernel","dtype","op_registry","ops",] \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py index 3b0a110fe..c3bc5bf70 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/base.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py @@ -2,6 +2,8 @@ from pprint import pformat from typing import Any, Dict, Iterable, List, Optional, Set, Union +from tritontemplate.compiler.dtype import dtype_str_to_triton_signature + class BaseType(ABC): def __init__(self) -> None: super().__init__() @@ -109,3 +111,35 @@ def outputs(self) -> Optional[List[Tensor]]: def compile(self,target_name,workdir): raise NotImplementedError + + def _gen_tensor_signature_divisiability(self,tensors_names:List[str]): + signature_metadata={} + divisiability={1:[],16:[]} + tensor_obj=[] + for tensor_name in tensors_names: + tensor_obj+=self._attrs[tensor_name] + for i,input in enumerate(tensor_obj): + if isinstance(input,Tensor): + try: + sptype='*'+dtype_str_to_triton_signature(input.dtype) + except KeyError: + raise KeyError(f'dtype {input.dtype} not supported') + signature_metadata[i]=sptype + # the ptr from torch is 16-byte aligned + if divisiability.get(16,None) is None: + divisiability[16]=[i] + else: + divisiability[16].append(i) + else: + raise NotImplementedError(f'input {input} not supported') + + return signature_metadata,divisiability + + @staticmethod + def _block_size(x): + if x<=32: + return 32 + elif x<=64: + return 64 + else: + return 128 diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/__init__.py new file mode 100644 index 000000000..f849c3561 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/__init__.py @@ -0,0 +1 @@ +from tritontemplate.compiler.ops.bmm.bmm import Bmm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py new file mode 100644 index 000000000..66de76af3 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py @@ -0,0 +1,104 @@ +from typing import List,Optional +import importlib + +import triton +from tritontemplate.compiler.base import IntImm, Tensor, Operation +from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.kernel import TritonExecutor +from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.backend.cuda.utils.utils import shape2stride + +_supported_layouts = ['rcr','rrr','crr','ccr'] + +_exec_metadata = { + 'num_warps': 4, + 'num_stages': 1, +} + +class Bmm(Operation): + def __init__( + self, + inputs:List[Tensor], + layout:str,is_bias:bool=False, + outputs:Optional[List[Tensor]]=None, + name: Optional[str]=None): + assert layout in _supported_layouts, f"Unsupported layout {layout}" + + super().__init__(inputs, outputs,name) + self.layout = layout + self.is_bias = is_bias + self._attrs['inputs'] = inputs + self._attrs['outputs'] = outputs + self._deduce_output_shape() + + def _deduce_output_shape(self): + BATCH_SIZE = self._attrs['inputs'][0].shape[0] + is_transpose_a=self.layout[0]=='c' + is_transpose_b=self.layout[1]=='c' + M=self._attrs['inputs'][0].shape[2] if is_transpose_a else self._attrs['inputs'][0].shape[1] + K=self._attrs['inputs'][0].shape[1] if is_transpose_a else self._attrs['inputs'][0].shape[2] + N=self._attrs['inputs'][1].shape[1] if is_transpose_b else self._attrs['inputs'][1].shape[2] + + self._attrs['BATCH_SIZE']=BATCH_SIZE + self._attrs['M']=M + self._attrs['N']=N + self._attrs['K']=K + self._attrs['is_transpose_a']=is_transpose_a + self._attrs['is_transpose_b']=is_transpose_b + + res_shape=[BATCH_SIZE,M,N] if self.layout[2]=='r' else [BATCH_SIZE,N,M] + if self._attrs['outputs'] is None: + self._attrs['outputs'] = [Tensor(shape=res_shape,dtype=self._attrs['inputs'][0].dtype)] + else: + assert self._attrs['outputs'][0].shape == res_shape + + def _gen_constants(self,enable_tf32): + const_metadata={} + any_float32=False + for input in self._attrs['inputs']: + if input.dtype == 'float32': + any_float32=True + break + const_metadata['BLOCK_SIZE_M']= self._block_size(self._attrs['M']) + const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N']) + const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K']) + + const_metadata['enable_tf32'] = True if (enable_tf32 and any_float32) else False + input=self._attrs['inputs'] + output=self._attrs['outputs'] + const_metadata['BATCH_SIZE']=self._attrs['BATCH_SIZE'] + const_metadata['M']=self._attrs['M'] + const_metadata['N']=self._attrs['N'] + const_metadata['K']=self._attrs['K'] + + const_metadata['is_transpose_a']=self._attrs['is_transpose_a'] + const_metadata['is_transpose_b']=self._attrs['is_transpose_b'] + const_metadata['stride_a0'],const_metadata['stride_a1'],const_metadata['stride_a2']=shape2stride(input[0].shape) + const_metadata['stride_b0'],const_metadata['stride_b1'],const_metadata['stride_b2']=shape2stride(input[1].shape) + if self.is_bias: + const_metadata['stride_bias0'],const_metadata['stride_bias1'],const_metadata['stride_bias2']=shape2stride(input[2].shape) + + const_metadata['stride_c0'],const_metadata['stride_c1'],const_metadata['stride_c2']=shape2stride(output[0].shape) + + + return const_metadata + + def _gen_exec_metadata(self): + return _exec_metadata.copy() + + def compile(self, target_name, workdir,enable_tf32: bool = False,)->TritonExecutor: + triton_kernel_name=f'bmm'+ ('' if not self.is_bias else '_bias') + triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),triton_kernel_name) + gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),f'gen_grid_bmm') + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) + constants=self._gen_constants(enable_tf32) + exec_metadata=self._gen_exec_metadata() + + num_warps=exec_metadata['num_warps'] + num_stages=exec_metadata['num_stages'] + config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) + + triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) + + exec_grid=gen_grid(constants['BATCH_SIZE'],constants['M'],constants['N'],constants['BLOCK_SIZE_M'],constants['BLOCK_SIZE_N']) + return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 1361d2e78..2df2e1e8e 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -35,10 +35,10 @@ def __init__( self.is_bias= is_bias self._attrs['activation'] = activation self._attrs['inputs'] = inputs - self._attrs['outputs'] = outputs if outputs is not None else self._induce_output_shape() + self._attrs['outputs'] = outputs if outputs is not None else self._deduce_output_shape() - def _induce_output_shape(self): + def _deduce_output_shape(self): if self.layout == 'rcr': M,N,K = self._attrs['inputs'][0].shape[0],self._attrs['inputs'][1].shape[0],self._attrs['inputs'][0].shape[1] elif self.layout == 'rrr': @@ -47,33 +47,6 @@ def _induce_output_shape(self): raise NotImplementedError(f'layout {self.layout} not supported') return [Tensor(shape=[M,N],dtype=self._attrs['inputs'][0].dtype)] - def _gen_signature_divisiability(self): - signature_metadata={} - divisiability={1:[],16:[]} - for i,input in enumerate(self._attrs['inputs']+self._attrs['outputs']): - if isinstance(input,Tensor): - try: - sptype='*'+dtype_str_to_triton_signature(input.dtype) - except KeyError: - raise KeyError(f'dtype {input.dtype} not supported') - signature_metadata[i]=sptype - # the ptr from torch is 16-byte aligned - if divisiability.get(16,None) is None: - divisiability[16]=[i] - else: - divisiability[16].append(i) - else: - raise NotImplementedError(f'input {input} not supported') - - return signature_metadata,divisiability - - @staticmethod - def _block_size(x): - if x>=128: - return 128 - elif x<=32: - return 32 - return triton.next_power_of_2(x) def _gen_constants(self,enable_tf32): const_metadata={} @@ -103,7 +76,6 @@ def _gen_constants(self,enable_tf32): elif self.layout == 'rrr': input=self._attrs['inputs'] M,K,N=input[0].shape[0],input[1].shape[0],input[1].shape[1] - print(M,K,N) const_metadata['M']=M const_metadata['N']=N const_metadata['K']=K @@ -132,7 +104,7 @@ def compile(self,target_name,workdir,enable_tf32: bool = False,)->TritonExecutor triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),triton_kernel_name) gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm_{self.layout}') - signature,divisiability=self._gen_signature_divisiability() + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) constants=self._gen_constants(enable_tf32) exec_metadata=self._gen_exec_metadata() diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm_add.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm_add.py new file mode 100644 index 000000000..11ff54790 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm_add.py @@ -0,0 +1,143 @@ +import torch +import pytest + +import triton +from tritontemplate.compiler.base import IntImm,Tensor +from tritontemplate.compiler.ops.bmm import Bmm +from tritontemplate.compiler.compiler import compile_kernel + +from tritontemplate.backend.cuda.bmm.bmm import bmm_bias as bmm_bias_kernel +from tritontemplate.backend.cuda.bmm.bmm import bmm as bmm_kernel + +def gen_bmm_bias(format, batch_size, M, N, K, stype): + if format[0]=='r': + A=Tensor(name='A',dtype=stype,shape=[batch_size,M,K]) + else: + A=Tensor(name='A',dtype=stype,shape=[batch_size,K,M]) + if format[1]=='r': + B=Tensor(name='B',dtype=stype,shape=[batch_size,K,N]) + else: + B=Tensor(name='B',dtype=stype,shape=[batch_size,N,K]) + Bias=Tensor(name='Bias',dtype=stype,shape=[batch_size,M,N]) + C=Tensor(name='C',dtype=stype,shape=[batch_size,M,N]) + bmm_op=Bmm( + inputs=[A,B,Bias], + outputs=[C], + layout=format, + is_bias=True + ) + kernel = compile_kernel(bmm_op,device='cuda') + return kernel + +def gen_bmm(format, batch_size, M, N, K, stype): + if format[0]=='r': + A=Tensor(name='A',dtype=stype,shape=[batch_size,M,K]) + else: + A=Tensor(name='A',dtype=stype,shape=[batch_size,K,M]) + if format[1]=='r': + B=Tensor(name='B',dtype=stype,shape=[batch_size,K,N]) + else: + B=Tensor(name='B',dtype=stype,shape=[batch_size,N,K]) + C=Tensor(name='C',dtype=stype,shape=[batch_size,M,N]) + bmm_op=Bmm( + inputs=[A,B], + outputs=[C], + layout=format, + is_bias=False + ) + kernel = compile_kernel(bmm_op,device='cuda') + return kernel + + +FORMATS = ['rcr','rrr','ccr','crr'] +MATRIX_PARAMS = [ + (2, 2, 128, 31, 'float32'), + (2, 128, 2, 31, 'float16'), + (2, 128, 128, 31, 'float32'), + (2, 31, 128, 2, 'float16'), + (2, 129, 128, 128, 'float32'), + (2, 128, 257, 512, 'float16'), + (2, 128, 512, 257, 'float32'), + (2, 127, 256, 256, 'float16'), + (2, 128, 511, 512, 'float32'), + (2, 256, 128, 255, 'float16'), + (2, 1, 256, 256, 'float32'), +] + +@pytest.mark.parametrize('format', FORMATS) +@pytest.mark.parametrize( + 'batch_size, M, N, K, stype', + MATRIX_PARAMS +) +def test_bmm_bias(format, batch_size, M, N, K, stype): + torch.manual_seed(0) + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32=False + dtype=torch.float32 + else: + dtype=torch.float16 + + a = torch.randn(batch_size, M, K, dtype=dtype, device='cuda') + b = torch.randn(batch_size, K, N, dtype=dtype, device='cuda') + bias = torch.randn(batch_size,M, N, dtype=dtype, device='cuda') + c_triton = torch.empty(batch_size, M, N, dtype=dtype, device='cuda') + c_ttemplate = torch.empty(batch_size, M, N, dtype=dtype, device='cuda') + + c_torch= torch.bmm(a, b)+bias + + grid = lambda META: ( + META['BATCH_SIZE'],triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + is_trans_a=False + is_trans_b=False + if format[0]=='c': + a=a.transpose(1,2).contiguous() + is_trans_a=True + if format[1]=='c': + b=b.transpose(1,2).contiguous() + is_trans_b=True + test_kernel=gen_bmm_bias(format,batch_size,M,N,K,stype) + test_kernel(a,b,bias,c_ttemplate) + bmm_bias_kernel[grid](a,b,bias,c_triton,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*bias.stride(),*c_triton.stride(),64,64,64,False) + print(*b.stride()) + torch.testing.assert_close(c_ttemplate,c_triton,atol=1e-2,rtol=1e-2) + torch.testing.assert_close(c_ttemplate,c_torch,atol=1e-2,rtol=1e-2) + +@pytest.mark.parametrize('format', FORMATS) +@pytest.mark.parametrize( + 'batch_size, M, N, K, stype', + MATRIX_PARAMS +) +def test_bmm(format, batch_size, M, N, K, stype): + torch.manual_seed(0) + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32=False + dtype=torch.float32 + else: + dtype=torch.float16 + + a = torch.randn(batch_size, M, K, dtype=dtype, device='cuda') + b = torch.randn(batch_size, K, N, dtype=dtype, device='cuda') + c_triton = torch.randn(batch_size, M, N, dtype=dtype, device='cuda') + c_ttemplate = torch.randn(batch_size, M, N, dtype=dtype, device='cuda') + + c_torch= torch.bmm(a, b) + grid = lambda META: ( + META['BATCH_SIZE'],triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + is_trans_a=False + is_trans_b=False + if format[0]=='c': + a=a.transpose(1,2).contiguous() + is_trans_a=True + if format[1]=='c': + b=b.transpose(1,2).contiguous() + is_trans_b=True + + + bmm_kernel[grid](a,b,c_triton,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*c_triton.stride(),64,64,64,False) + kernel=gen_bmm(format,batch_size,M,N,K,stype) + kernel(a,b,c_ttemplate) + torch.testing.assert_close(c_ttemplate,c_triton,atol=1e-2,rtol=1e-2) + torch.testing.assert_close(c_ttemplate,c_torch,atol=1e-2,rtol=1e-2) + From b69d010b5daed54dbb1ee1f7b708d9244efac0e1 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 11 Jul 2025 01:05:10 +0800 Subject: [PATCH 19/26] [fix] reform gemm --- .../tritontemplate/backend/cuda/bmm/bmm.py | 4 +- .../backend/cuda/gemm/__init__.py | 3 +- .../tritontemplate/backend/cuda/gemm/gemm.py | 185 +++++++++++++++++ .../backend/cuda/gemm/gemm_rcr.py | 133 ------------- .../backend/cuda/gemm/gemm_rrr.py | 131 ------------ .../tritontemplate/compiler/ops/bmm/bmm.py | 3 +- .../tritontemplate/compiler/ops/gemm/gemm.py | 89 +++++---- .../cuda/{test_bmm_add.py => test_bmm.py} | 24 +-- .../tritontemplate/testing/cuda/test_gemm.py | 186 ++++++++++++++++++ .../testing/cuda/test_gemm_rcr.py | 80 -------- .../testing/cuda/test_gemm_rrr.py | 106 ---------- 11 files changed, 431 insertions(+), 513 deletions(-) create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py delete mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py delete mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py rename external/TritonTemplate/python/tritontemplate/testing/cuda/{test_bmm_add.py => test_bmm.py} (80%) create mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py delete mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rcr.py delete mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py index bd94ce5d8..38a1813da 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py @@ -3,7 +3,7 @@ def gen_grid_bmm(batch_size,M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): """ - Generates the grid for a GEMM kernel. + Generates the grid for a Batch GEMM kernel. """ return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1) @@ -88,8 +88,6 @@ def bmm_bias( tl.store(c_ptr+c_ptrs_offs,accumulator,mask=c_mask) - - @triton.jit def bmm( diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py index 6711c64cc..ddc868bd3 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py @@ -1,2 +1 @@ -from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr,gemm_rcr_bias,gen_grid_gemm_rcr -from tritontemplate.backend.cuda.gemm.gemm_rrr import gemm_rrr,gemm_rrr_bias,gen_grid_gemm_rrr +from tritontemplate.backend.cuda.gemm.gemm import gemm,gemm_bias,gen_grid_gemm diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py new file mode 100644 index 000000000..f9b3d1b52 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py @@ -0,0 +1,185 @@ +import triton +import triton.language as tl +from tritontemplate.backend.cuda.utils.activation import * + +def gen_grid_gemm(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + """ + Generates the grid for a GEMM kernel. + """ + return ( + triton.cdiv(M, BLOCK_SIZE_M)*triton.cdiv(N, BLOCK_SIZE_N),1,1) + +# smem_size=val*dtype_size*num_stage +smem_demand_per_stage ={ + 'gemm_bias': 128*128*2, + 'gemm': 128*128*2, +} + +@triton.jit +def gemm_bias( + # Pointers to matrices + a_ptr, b_ptr, bias_ptr, c_ptr, + # Matrix dimensions + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides for matrices + is_transpose_a: tl.constexpr, + stride_a0: tl.constexpr, stride_a1: tl.constexpr, + is_transpose_b: tl.constexpr, + stride_b0: tl.constexpr, stride_b1: tl.constexpr, + stride_c0: tl.constexpr, stride_c1: tl.constexpr, + stride_bias0: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ACTIVATION: tl.constexpr, # 'relu' or None + enable_tf32: tl.constexpr, +): + """ + Kernel for GEMM + Bias + ReLU. + A(M,K) @ B(K,N) + Bias -> C(M,N) + """ + # _TF32_ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" + # DTYPE: tl.constexpr = k.dtype.element_ty + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + if is_transpose_a: + # A(K,M) + a_ptrs = a_ptr + offs_k[:, None] * stride_a0 + offs_am[None, :] * stride_a1 + stride_ak = stride_a0 + else: + # A(M,K) + a_ptrs = a_ptr + offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1 + stride_ak = stride_a1 + + if is_transpose_b: + # B(N,K) + b_ptrs = b_ptr + offs_bn[:, None] * stride_b0 + offs_k[None, :] * stride_b1 + stride_bk = stride_b1 + else: + # B(K,N) + b_ptrs = b_ptr + offs_k[:, None] * stride_b0 + offs_bn[None, :] * stride_b1 + stride_bk = stride_b0 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_am[None, :] < M) + else: + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + if is_transpose_b: + b_mask = (offs_bn[:, None] < N) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + else: + b_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_bn[None, :] < N) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + if is_transpose_a: + a = tl.trans(a) + if is_transpose_b: + b = tl.trans(b) + accumulator += tl.dot(a, b, allow_tf32=enable_tf32) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if bias_ptr is not None: + offs_bias_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + bias_ptrs = bias_ptr + offs_bias_n * stride_bias0 + bias_vals = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + accumulator = accumulator + bias_vals[None, :] + + if ACTIVATION == 'relu': + accumulator = relu(accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def gemm( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides for matrices + is_transpose_a: tl.constexpr, stride_a0: tl.constexpr, stride_a1: tl.constexpr, + is_transpose_b: tl.constexpr, stride_b0: tl.constexpr, stride_b1: tl.constexpr, + stride_c0: tl.constexpr, stride_c1: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ACTIVATION: tl.constexpr, # 'relu' or None + enable_tf32: tl.constexpr, +): + """ + Kernel for GEMM + ReLU. + A @ B -> C + This kernel can handle transpositions for A and B. + """ + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + if is_transpose_a: + # A(K,M) + a_ptrs = a_ptr + offs_k[:, None] * stride_a0 + offs_am[None, :] * stride_a1 + stride_ak = stride_a0 + else: + # A(M,K) + a_ptrs = a_ptr + offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1 + stride_ak = stride_a1 + + if is_transpose_b: + # B(N,K) + b_ptrs = b_ptr + offs_bn[:, None] * stride_b0 + offs_k[None, :] * stride_b1 + stride_bk = stride_b1 + else: + # B(K,N) + b_ptrs = b_ptr + offs_k[:, None] * stride_b0 + offs_bn[None, :] * stride_b1 + stride_bk = stride_b0 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_am[None, :] < M) + else: + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + if is_transpose_b: + b_mask = (offs_bn[:, None] < N) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + else: + b_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_bn[None, :] < N) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + if is_transpose_a: + a = tl.trans(a) + if is_transpose_b: + b = tl.trans(b) + accumulator += tl.dot(a, b, allow_tf32=enable_tf32) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if ACTIVATION == 'relu': + accumulator = relu(accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py deleted file mode 100644 index b55ab29a0..000000000 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rcr.py +++ /dev/null @@ -1,133 +0,0 @@ -import triton -import triton.language as tl -from tritontemplate.backend.cuda.utils.activation import * - -def gen_grid_gemm_rcr(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): - """ - Generates the grid for a GEMM kernel. - """ - return ( - triton.cdiv(M, BLOCK_SIZE_M)*triton.cdiv(N, BLOCK_SIZE_N),1,1) - -# smem_size=val*dtype_size*num_stage -smem_demand_per_stage ={ - 'gemm_rcr_bias': 128*128*2, - 'gemm_rcr': 128*128*2, -} - -@triton.jit -def gemm_rcr_bias( - # Pointers to matrices - a_ptr, b_ptr, bias_ptr, c_ptr, - # Matrix dimensions - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - # Strides for matrices - stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_bn: tl.constexpr, stride_bk: tl.constexpr, - stride_cm: tl.constexpr, stride_cn: tl.constexpr, - stride_biasn: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr, # 'relu' or None - enable_tf32: tl.constexpr, -): - """ - Kernel for GEMM RCR (Row-Col-Row) + Bias + ReLU. - A (M, K) @ B (N, K)^T + Bias (N) -> C (M, N) - B is stored as (N, K) but accessed as if it's (K, N) for the matmul. - """ - # _TF32_ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" - # DTYPE: tl.constexpr = k.dtype.element_ty - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_k[None, :] * stride_bk) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_am[:, None] < M), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_bn[:, None] < N), other=0.0) - # if enable_tf32 and DTYPE == tl.float32: - # a = tl.inline_asm_elementwise(_TF32_ASM, "=r, r", [a], dtype=tl.float32, is_pure=True, pack=1) - # b = tl.inline_asm_elementwise(_TF32_ASM, "=r, r", [b], dtype=tl.float32, is_pure=True, pack=1) - accumulator += tl.dot(a, tl.trans(b),allow_tf32=enable_tf32) - - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if bias_ptr is not None: - offs_bias_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - bias_ptrs = bias_ptr + offs_bias_n * stride_biasn - bias_vals = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) - accumulator = accumulator + bias_vals[None, :] - - if ACTIVATION == 'relu': - accumulator = relu(accumulator) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - - -@triton.jit -def gemm_rcr( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - # Strides for matrices - stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_bn: tl.constexpr, stride_bk: tl.constexpr, - stride_cm: tl.constexpr, stride_cn: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr, # 'relu' or None - enable_tf32: tl.constexpr, -): - """ - Kernel for GEMM RCR (Row-Col-Row) + ReLU. - A (M, K) @ B (N, K)^T -> C (M, N) - B is stored as (N, K) but accessed as if it's (K, N) for the matmul. - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_k[None, :] * stride_bk) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in tl.static_range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_am[:, None] < M), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[None, :] + k * BLOCK_SIZE_K < K) & (offs_bn[:, None] < N), other=0.0) - - accumulator += tl.dot(a, tl.trans(b)) - - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if ACTIVATION == 'relu': - accumulator = relu(accumulator) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py deleted file mode 100644 index 07bc76b02..000000000 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm_rrr.py +++ /dev/null @@ -1,131 +0,0 @@ -import triton -import triton.language as tl -from tritontemplate.backend.cuda.utils.activation import * - -def gen_grid_gemm_rrr(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): - """ - Generates the grid for a GEMM kernel. - """ - return (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1, 1) - -smem_demand_per_stage = { - 'gemm_rrr_bias': 128 * 128 * 2, - 'gemm_rrr': 128 * 128 * 2, -} - -@triton.jit -def gemm_rrr_bias( - # Pointers to matrices - a_ptr, b_ptr, bias_ptr, c_ptr, - # Matrix dimensions - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - # Strides for matrices - stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_bk: tl.constexpr, stride_bn: tl.constexpr, - stride_cm: tl.constexpr, stride_cn: tl.constexpr, - stride_biasn: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr, # 'relu' or None - enable_tf32: tl.constexpr, -): - """ - Kernel for GEMM RRR (Row-Row-Row) + Bias + ReLU. - A (M, K) @ B (K, N) + Bias (N) -> C (M, N) - B is stored as (N, K) but accessed as if it's (K, N) for the matmul. - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_SIZE_K < K), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) - - accumulator += tl.dot(a, b, allow_tf32=enable_tf32) - - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - - - offs_bias_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - bias_ptrs = bias_ptr + offs_bias_n * stride_biasn - bias_vals = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) - - accumulator = accumulator + bias_vals[None, :] - - - if ACTIVATION == 'relu': - accumulator = relu(accumulator) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -@triton.jit -def gemm_rrr( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - # Strides for matrices - stride_am: tl.constexpr, stride_ak: tl.constexpr, - stride_bk: tl.constexpr, stride_bn: tl.constexpr, - stride_cm: tl.constexpr, stride_cn: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - ACTIVATION: tl.constexpr, # 'relu' or None - enable_tf32: tl.constexpr, -): - """ - Kernel for GEMM RRR (Row-Row-Row) + ReLU. - A (M, K) @ B (K,N) -> C (M, N) - """ - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - pid_m = pid // num_pid_n - pid_n = pid % num_pid_n - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in tl.static_range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) & (offs_k[None, :] + k * BLOCK_SIZE_K < K), other=0.0) - b = tl.load(b_ptrs, mask=(offs_k[:, None] + k * BLOCK_SIZE_K < K) & (offs_bn[None, :] < N), other=0.0) - - accumulator += tl.dot(a, b, allow_tf32=enable_tf32) - - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if ACTIVATION == 'relu': - accumulator = relu(accumulator) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py index 66de76af3..13ac1b531 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py @@ -50,7 +50,7 @@ def _deduce_output_shape(self): if self._attrs['outputs'] is None: self._attrs['outputs'] = [Tensor(shape=res_shape,dtype=self._attrs['inputs'][0].dtype)] else: - assert self._attrs['outputs'][0].shape == res_shape + assert self._attrs['outputs'][0].shape == res_shape, f"output shape {self._attrs['outputs'][0].shape} not match {res_shape}" def _gen_constants(self,enable_tf32): const_metadata={} @@ -90,6 +90,7 @@ def compile(self, target_name, workdir,enable_tf32: bool = False,)->TritonExecut triton_kernel_name=f'bmm'+ ('' if not self.is_bias else '_bias') triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),triton_kernel_name) gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),f'gen_grid_bmm') + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) constants=self._gen_constants(enable_tf32) exec_metadata=self._gen_exec_metadata() diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 2df2e1e8e..5a8444237 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -7,8 +7,9 @@ from tritontemplate.compiler.dtype import dtype_str_to_triton_signature from tritontemplate.compiler.kernel import TritonExecutor from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.backend.cuda.utils.utils import shape2stride -_supported_layouts = ['rcr','rrr'] +_supported_layouts = ['rcr','rrr','ccr','crr'] _supported_activations = ['relu',None] @@ -35,19 +36,31 @@ def __init__( self.is_bias= is_bias self._attrs['activation'] = activation self._attrs['inputs'] = inputs - self._attrs['outputs'] = outputs if outputs is not None else self._deduce_output_shape() + self._attrs['outputs'] = outputs + self._deduce_output_shape() def _deduce_output_shape(self): - if self.layout == 'rcr': - M,N,K = self._attrs['inputs'][0].shape[0],self._attrs['inputs'][1].shape[0],self._attrs['inputs'][0].shape[1] - elif self.layout == 'rrr': - M,K,N = self._attrs['inputs'][0].shape[0],self._attrs['inputs'][1].shape[0],self._attrs['inputs'][0].shape[1] + + is_transpose_a=self.layout[0]=='c' + is_transpose_b=self.layout[1]=='c' + M=self._attrs['inputs'][0].shape[1] if is_transpose_a else self._attrs['inputs'][0].shape[0] + K=self._attrs['inputs'][0].shape[0] if is_transpose_a else self._attrs['inputs'][0].shape[1] + N=self._attrs['inputs'][1].shape[0] if is_transpose_b else self._attrs['inputs'][1].shape[1] + + + self._attrs['M'] = M + self._attrs['K'] = K + self._attrs['N'] = N + self._attrs['is_transpose_a'] = is_transpose_a + self._attrs['is_transpose_b'] = is_transpose_b + + res_shape=[M,N] if self.layout[2]=='r' else [N,M] + if self._attrs['outputs'] is None: + self._attrs['outputs'] = [Tensor(shape=res_shape,dtype=self._attrs['inputs'][0].dtype)] else: - raise NotImplementedError(f'layout {self.layout} not supported') - return [Tensor(shape=[M,N],dtype=self._attrs['inputs'][0].dtype)] - - + assert self._attrs['outputs'][0].shape == res_shape, f"output shape {self._attrs['outputs'][0].shape} not match {res_shape}" + def _gen_constants(self,enable_tf32): const_metadata={} const_metadata['ACTIVATION'] = self._attrs['activation'] @@ -59,40 +72,26 @@ def _gen_constants(self,enable_tf32): break const_metadata['enable_tf32'] = True if (enable_tf32 and any_float32) else False - if self.layout == 'rcr': - input=self._attrs['inputs'] - M,N,K=input[0].shape[0],input[1].shape[0],input[0].shape[1] - const_metadata['M']=M - const_metadata['N']=N - const_metadata['K']=K - const_metadata['stride_am']=K - const_metadata['stride_ak']=1 - const_metadata['stride_bn']=K - const_metadata['stride_bk']=1 - const_metadata['stride_cm']=N - const_metadata['stride_cn']=1 - if self.is_bias: - const_metadata['stride_biasn']=1 - elif self.layout == 'rrr': - input=self._attrs['inputs'] - M,K,N=input[0].shape[0],input[1].shape[0],input[1].shape[1] - const_metadata['M']=M - const_metadata['N']=N - const_metadata['K']=K - const_metadata['stride_am']=K - const_metadata['stride_ak']=1 - const_metadata['stride_bk']=N - const_metadata['stride_bn']=1 - const_metadata['stride_cm']=N - const_metadata['stride_cn']=1 - if self.is_bias: - const_metadata['stride_biasn']=1 - else: - raise NotImplementedError(f'layout {self.layout} not supported') + + const_metadata['BLOCK_SIZE_M']= self._block_size(self._attrs['M']) + const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N']) + const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K']) + + input=self._attrs['inputs'] + output=self._attrs['outputs'] + const_metadata['M']=self._attrs['M'] + const_metadata['N']=self._attrs['N'] + const_metadata['K']=self._attrs['K'] - const_metadata['BLOCK_SIZE_M']= self._block_size(M) - const_metadata['BLOCK_SIZE_N']= self._block_size(N) - const_metadata['BLOCK_SIZE_K']= self._block_size(K) + const_metadata['is_transpose_a']=self._attrs['is_transpose_a'] + const_metadata['is_transpose_b']=self._attrs['is_transpose_b'] + const_metadata['stride_a0'],const_metadata['stride_a1']=shape2stride(input[0].shape) + const_metadata['stride_b0'],const_metadata['stride_b1']=shape2stride(input[1].shape) + + if self.is_bias: + const_metadata['stride_bias0']=1 + + const_metadata['stride_c0'],const_metadata['stride_c1']=shape2stride(output[0].shape) return const_metadata def _gen_exec_metadata(self): @@ -100,9 +99,9 @@ def _gen_exec_metadata(self): #TODO:enable_tf32 https://github.com/triton-lang/triton/issues/4574 def compile(self,target_name,workdir,enable_tf32: bool = False,)->TritonExecutor: - triton_kernel_name=f'gemm_{self.layout}'+ ('' if not self.is_bias else '_bias') + triton_kernel_name=f'gemm'+ ('' if not self.is_bias else '_bias') triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),triton_kernel_name) - gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm_{self.layout}') + gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm') signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) constants=self._gen_constants(enable_tf32) diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm_add.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py similarity index 80% rename from external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm_add.py rename to external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py index 11ff54790..cc353eacb 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm_add.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py @@ -80,8 +80,8 @@ def test_bmm_bias(format, batch_size, M, N, K, stype): a = torch.randn(batch_size, M, K, dtype=dtype, device='cuda') b = torch.randn(batch_size, K, N, dtype=dtype, device='cuda') bias = torch.randn(batch_size,M, N, dtype=dtype, device='cuda') - c_triton = torch.empty(batch_size, M, N, dtype=dtype, device='cuda') - c_ttemplate = torch.empty(batch_size, M, N, dtype=dtype, device='cuda') + c_triton_jit = torch.empty(batch_size, M, N, dtype=dtype, device='cuda') + c_triton_aot = torch.empty(batch_size, M, N, dtype=dtype, device='cuda') c_torch= torch.bmm(a, b)+bias @@ -97,11 +97,11 @@ def test_bmm_bias(format, batch_size, M, N, K, stype): b=b.transpose(1,2).contiguous() is_trans_b=True test_kernel=gen_bmm_bias(format,batch_size,M,N,K,stype) - test_kernel(a,b,bias,c_ttemplate) - bmm_bias_kernel[grid](a,b,bias,c_triton,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*bias.stride(),*c_triton.stride(),64,64,64,False) + test_kernel(a,b,bias,c_triton_aot) + bmm_bias_kernel[grid](a,b,bias,c_triton_jit,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*bias.stride(),*c_triton_jit.stride(),64,64,64,False) print(*b.stride()) - torch.testing.assert_close(c_ttemplate,c_triton,atol=1e-2,rtol=1e-2) - torch.testing.assert_close(c_ttemplate,c_torch,atol=1e-2,rtol=1e-2) + torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=1e-2,rtol=1e-2) + torch.testing.assert_close(c_triton_aot,c_torch,atol=1e-2,rtol=1e-2) @pytest.mark.parametrize('format', FORMATS) @pytest.mark.parametrize( @@ -118,8 +118,8 @@ def test_bmm(format, batch_size, M, N, K, stype): a = torch.randn(batch_size, M, K, dtype=dtype, device='cuda') b = torch.randn(batch_size, K, N, dtype=dtype, device='cuda') - c_triton = torch.randn(batch_size, M, N, dtype=dtype, device='cuda') - c_ttemplate = torch.randn(batch_size, M, N, dtype=dtype, device='cuda') + c_triton_jit = torch.randn(batch_size, M, N, dtype=dtype, device='cuda') + c_triton_aot = torch.randn(batch_size, M, N, dtype=dtype, device='cuda') c_torch= torch.bmm(a, b) grid = lambda META: ( @@ -135,9 +135,9 @@ def test_bmm(format, batch_size, M, N, K, stype): is_trans_b=True - bmm_kernel[grid](a,b,c_triton,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*c_triton.stride(),64,64,64,False) + bmm_kernel[grid](a,b,c_triton_jit,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*c_triton_jit.stride(),64,64,64,False) kernel=gen_bmm(format,batch_size,M,N,K,stype) - kernel(a,b,c_ttemplate) - torch.testing.assert_close(c_ttemplate,c_triton,atol=1e-2,rtol=1e-2) - torch.testing.assert_close(c_ttemplate,c_torch,atol=1e-2,rtol=1e-2) + kernel(a,b,c_triton_aot) + torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=1e-2,rtol=1e-2) + torch.testing.assert_close(c_triton_aot,c_torch,atol=1e-2,rtol=1e-2) diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py new file mode 100644 index 000000000..664374aae --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py @@ -0,0 +1,186 @@ +import torch +import pytest +import triton + +from tritontemplate.backend.cuda.gemm.gemm import gemm_bias as gemm_bias_kernel +from tritontemplate.backend.cuda.gemm.gemm import gemm as gemm_kernel +from tritontemplate.compiler.base import IntImm, Tensor +from tritontemplate.compiler.ops.gemm import Gemm +from tritontemplate.compiler.compiler import compile_kernel + +def gen_gemm_bias(format, M, N, K, stype): + if format[0]=='r': + A=Tensor(name='A',dtype=stype,shape=[M,K]) + else: + A=Tensor(name='A',dtype=stype,shape=[K,M]) + if format[1]=='r': + B=Tensor(name='B',dtype=stype,shape=[K,N]) + else: + B=Tensor(name='B',dtype=stype,shape=[N,K]) + Bias=Tensor(name='Bias',dtype=stype,shape=[M,N]) + C=Tensor(name='C',dtype=stype,shape=[M,N]) + gemm_op=Gemm( + inputs=[A,B,Bias], + outputs=[C], + layout=format, + is_bias=True, + activation='relu', + ) + kernel = compile_kernel(gemm_op,device='cuda') + return kernel + +def gen_gemm(format, M, N, K, stype): + if format[0]=='r': + A=Tensor(name='A',dtype=stype,shape=[M,K]) + else: + A=Tensor(name='A',dtype=stype,shape=[K,M]) + if format[1]=='r': + B=Tensor(name='B',dtype=stype,shape=[K,N]) + else: + B=Tensor(name='B',dtype=stype,shape=[N,K]) + C=Tensor(name='C',dtype=stype,shape=[M,N]) + gemm_op=Gemm( + inputs=[A,B], + outputs=[C], + layout=format, + is_bias=False, + activation='relu', + ) + kernel = compile_kernel(gemm_op,device='cuda') + return kernel + + + +FORMATS = ['rcr','rrr','ccr','crr'] +MATRIX_PARAMS = [ + (2, 128, 31, 'float32'), + (128, 2, 31, 'float16'), + (128, 128, 31, 'float32'), + (128, 31, 2, 'float16'), + (128, 128, 128, 'float32'), + (128, 257, 512, 'float16'), + (128, 512, 257, 'float32'), + (127, 256, 256, 'float16'), + (128, 511, 512, 'float32'), + (256, 128, 255, 'float16'), + (1, 256, 256, 'float32'), +] + +@pytest.mark.parametrize('format', FORMATS) +@pytest.mark.parametrize( + 'M, N, K, stype', + MATRIX_PARAMS +) +def test_gemm_bias_relu(format, M, N, K, stype): + + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + + + A = torch.randn((M, K), dtype=dtype, device='cuda') + B = torch.randn((K, N), dtype=dtype, device='cuda') + Bias = torch.randn((N,), dtype=dtype, device='cuda') + c_triton_jit = torch.empty((M, N), device=A.device, dtype=A.dtype) + c_triton_aot = torch.empty((M, N), device=A.device, dtype=A.dtype) + + pytorch_result = torch.nn.functional.relu(A @ B + Bias) + + is_trans_a=False + is_trans_b=False + + if format[0] == 'c': + is_trans_a = True + A = A.transpose(1,0).contiguous() + + if format[1] == 'c': + is_trans_b = True + B = B.transpose(1,0).contiguous() + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + + gemm_bias_kernel[grid]( + A, B, Bias, c_triton_jit, + M, N, K, + is_trans_a, *A.stride(), + is_trans_b, *B.stride(), + *c_triton_jit.stride(), + *Bias.stride(), + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, + ACTIVATION='relu', enable_tf32=False + ) + kernel = gen_gemm_bias(format, M, N, K, stype) + kernel(A, B, Bias, c_triton_aot) + + assert torch.allclose(c_triton_aot, c_triton_jit, atol=1e-2, rtol=1e-2) + assert torch.allclose(pytorch_result, c_triton_jit, atol=1e-2, rtol=1e-2) + +FORMATS = ['rcr','rrr','ccr','crr'] +MATRIX_PARAMS = [ + (2, 128, 31, 'float32'), + (128, 2, 31, 'float16'), + (128, 128, 31, 'float32'), + (128, 31, 2, 'float16'), + (128, 128, 128, 'float32'), + (128, 257, 512, 'float16'), + (128, 512, 257, 'float32'), + (127, 256, 256, 'float16'), + (128, 511, 512, 'float32'), + (256, 128, 255, 'float16'), + (1, 256, 256, 'float32'), +] + +@pytest.mark.parametrize('format', FORMATS) +@pytest.mark.parametrize( + 'M, N, K, stype', + MATRIX_PARAMS +) +def test_gemm_relu(format, M, N, K, stype): + + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + + A = torch.randn((M, K), dtype=dtype, device='cuda') + B = torch.randn((K, N), dtype=dtype, device='cuda') + c_triton_jit = torch.empty((M, N), device=A.device, dtype=A.dtype) + c_triton_aot = torch.empty((M, N), device=A.device, dtype=A.dtype) + + pytorch_result = torch.nn.functional.relu(A @ B) + + is_trans_a=False + is_trans_b=False + + if format[0] == 'c': + is_trans_a = True + A = A.transpose(1,0).contiguous() + + if format[1] == 'c': + is_trans_b = True + B = B.transpose(1,0).contiguous() + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + + gemm_kernel[grid]( + A, B, c_triton_jit, + M, N, K, + is_trans_a, *A.stride(), + is_trans_b, *B.stride(), + *c_triton_jit.stride(), + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, + ACTIVATION='relu', enable_tf32=False + ) + kernel = gen_gemm(format, M, N, K, stype) + kernel(A, B, c_triton_aot) + + assert torch.allclose(c_triton_aot, c_triton_jit, atol=1e-2, rtol=1e-2) + assert torch.allclose(pytorch_result, c_triton_jit, atol=1e-2, rtol=1e-2) + diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rcr.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rcr.py deleted file mode 100644 index e681a512e..000000000 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rcr.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -import pytest - -import triton -from tritontemplate.compiler.base import IntImm,Tensor -from tritontemplate.compiler.ops.gemm import Gemm -from tritontemplate.compiler.compiler import compile_kernel -from tritontemplate.backend.cuda.gemm.gemm_rcr import gemm_rcr_bias as gemm_rcr_bias_kernel - -def gen_gemm_rcr_bias_relu(M, N, K, stype): - A = Tensor(name='A', dtype=stype, shape=[IntImm(M), IntImm(K)]) - B = Tensor(name='B', dtype=stype, shape=[IntImm(N), IntImm(K)]) - Bias = Tensor(name='Bias', dtype=stype, shape=[IntImm(N)]) - C = Tensor(name='C', dtype=stype, shape=[IntImm(M), IntImm(N)]) - - gemm_op = Gemm( - inputs=[A, B, Bias], - outputs=[C], - layout='rcr', - is_bias=True, - activation='relu', - ) - - kernel = compile_kernel(gemm_op, device='cuda') - return kernel - -def gemm_rcr_bias_relu(a, b, bias,stype): - M, K = a.shape - N, K_b = b.shape - - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - a=a.contiguous() - b=b.contiguous() - bias=bias.contiguous() - c=c.contiguous() - kernel = gen_gemm_rcr_bias_relu(M, N, K,stype) - kernel(a, b, bias, c) - return c - -@pytest.mark.parametrize( - 'M, N, K, stype', - [ - (2, 128, 31,'float32'), - (128,2,31,'float16'), - (128,128,31,'float32'), - (31,128,2,'float16'), - (129,128,128,'float32'), - (128, 257,512,'float16'), - (128, 512, 257,'float32'), - (127, 256, 256,'float16'), - (128, 511, 512,'float32'), - (256,128,255,'float16'), - (1,256,256,'float32'), - ], -) -def test_gemm_rcr_bias_relu(M, N, K, stype): - if stype == 'float32': - torch.backends.cuda.matmul.allow_tf32=False - dtype=torch.float32 - else: - dtype=torch.float16 - - A = torch.randn((M, K), dtype=dtype, device='cuda') - B = torch.randn((N, K), dtype=dtype, device='cuda') - Bias = torch.randn((N,), dtype=dtype, device='cuda') - - # Triton and PyTorch outputs - c_triton = torch.empty((M, N), device=A.device, dtype=A.dtype) - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - triton_aot=gemm_rcr_bias_relu(A,B,Bias,stype) - gemm_rcr_bias_kernel[grid](A,B,Bias,c_triton, M, N, K,A.stride(0),A.stride(1),B.stride(0),B.stride(1),c_triton.stride(0),c_triton.stride(1),Bias.stride(0),64,64,64,'relu',enable_tf32=False) - - assert torch.allclose(c_triton, triton_aot, atol=1e-2, rtol=1e-2), \ - f"Outputs mismatch between aot and jit for M={M}, N={N}, K={K}\n" - c=torch.nn.functional.relu(torch.nn.functional.linear(A,B,bias=Bias)) - assert torch.allclose(c, triton_aot, atol=1e-2, rtol=1e-2), \ - f"Outputs mismatch standard for M={M}, N={N}, K={K}\n" diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py deleted file mode 100644 index 9a16f2a14..000000000 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm_rrr.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import pytest -import triton - -from tritontemplate.backend.cuda.gemm.gemm_rrr import gemm_rrr_bias as gemm_rrr_bias_kernel -from tritontemplate.compiler.base import IntImm, Tensor -from tritontemplate.compiler.ops.gemm import Gemm -from tritontemplate.compiler.compiler import compile_kernel - -def gen_gemm_rrr_bias_relu(M, N, K, stype): - """ - Generates an AOT (Ahead-of-Time) compiled kernel for GEMM RRR + Bias + ReLU. - """ - A = Tensor(name='A', dtype=stype, shape=[IntImm(M), IntImm(K)]) - - B = Tensor(name='B', dtype=stype, shape=[IntImm(K), IntImm(N)]) - Bias = Tensor(name='Bias', dtype=stype, shape=[IntImm(N)]) - C = Tensor(name='C', dtype=stype, shape=[IntImm(M), IntImm(N)]) - - gemm_op = Gemm( - inputs=[A, B, Bias], - outputs=[C], - layout='rrr', - is_bias=True, - activation='relu', - ) - - kernel = compile_kernel(gemm_op, device='cuda') - return kernel - -def gemm_rrr_bias_relu_aot(a, b, bias, stype): - """ - Wrapper function to execute the AOT compiled kernel. - """ - M, K = a.shape - K_b, N = b.shape - assert K == K_b, "K dimension mismatch between A and B" - c = torch.empty((M, N), device=a.device, dtype=a.dtype) - a = a.contiguous() - b = b.contiguous() - bias = bias.contiguous() - c = c.contiguous() - kernel = gen_gemm_rrr_bias_relu(M, N, K, stype) - kernel(a, b, bias, c) - return c - -@pytest.mark.parametrize( - 'M, N, K, stype', - [ - (2, 128, 31, 'float32'), - (128, 2, 31, 'float16'), - (128, 128, 31, 'float32'), - (31, 128, 2, 'float16'), - (129, 128, 128, 'float32'), - (128, 257, 512, 'float16'), - (128, 512, 257, 'float32'), - (127, 256, 256, 'float16'), - (128, 511, 512, 'float32'), - (256, 128, 255, 'float16'), - (1, 256, 256, 'float32'), - ], -) -def test_gemm_rrr_bias_relu(M, N, K, stype): - """ - Tests the RRR GEMM kernel against a reference PyTorch implementation and an AOT compiled version. - """ - if stype == 'float32': - torch.backends.cuda.matmul.allow_tf32 = False - dtype = torch.float32 - else: - dtype = torch.float16 - - - A = torch.randn((M, K), dtype=dtype, device='cuda') - B = torch.randn((K, N), dtype=dtype, device='cuda') - Bias = torch.randn((N,), dtype=dtype, device='cuda') - - - triton_aot_result = gemm_rrr_bias_relu_aot(A, B, Bias, stype) - - - c_triton_jit = torch.empty((M, N), device=A.device, dtype=A.dtype) - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) - - gemm_rrr_bias_kernel[grid]( - A, B, Bias, c_triton_jit, - M, N, K, - A.stride(0), A.stride(1), - B.stride(0), B.stride(1), - c_triton_jit.stride(0), c_triton_jit.stride(1), - Bias.stride(0), - BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, - ACTIVATION='relu', enable_tf32=False - ) - - - - pytorch_result = torch.nn.functional.relu(A @ B + Bias) - - assert torch.allclose(c_triton_jit, triton_aot_result, atol=1e-2, rtol=1e-2), \ - f"Outputs mismatch between AOT and JIT for M={M}, N={N}, K={K}\n" - assert torch.allclose(pytorch_result, triton_aot_result, atol=1e-2, rtol=1e-2), \ - f"Outputs mismatch between AOT and PyTorch for M={M}, N={N}, K={K}\n" - From 95901f8c0ca2d0885e3d8f2ae29ea84a4a8f5853 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 17 Jul 2025 11:00:33 +0800 Subject: [PATCH 20/26] [feat] supported softmax --- .../backend/cuda/softmax/__init__.py | 1 + .../backend/cuda/softmax/softmax.py | 91 +++++++++++++++++++ .../compiler/ops/softmax/__init__.py | 1 + .../compiler/ops/softmax/softmax.py | 69 ++++++++++++++ .../tritontemplate/testing/cuda/test_bmm.py | 5 +- .../tritontemplate/testing/cuda/test_gemm.py | 4 +- .../testing/cuda/test_softmax.py | 68 ++++++++++++++ 7 files changed, 234 insertions(+), 5 deletions(-) create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py new file mode 100644 index 000000000..2fc667265 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py @@ -0,0 +1 @@ +from tritontemplate.backend.cuda.softmax.softmax import softmax,online_softmax,gen_grid_softmax \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py new file mode 100644 index 000000000..8fea36cea --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py @@ -0,0 +1,91 @@ +import triton +import triton.language as tl + +def gen_grid_softmax(M, BLOCK_SIZE_M): + """ + Generates the grid for a softmax kernel. + """ + return ( + triton.cdiv(M, BLOCK_SIZE_M),1,1) + +@triton.jit +def softmax(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr, stride_y0:tl.constexpr, stride_y1:tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + ''' + softmax function for last dimension. + ''' + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_M + offsets_M = tl.arange(0, BLOCK_SIZE_M) + block_start + + mask_M = offsets_M < M + + max_ele = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + elem = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, mask=mask_M[:, None] & mask_N[None, :], other=-float('inf')) + max_ele = tl.maximum(max_ele,tl.max(elem, axis=1)) + + + y_sum = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + elem = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, mask=mask_M[:, None] & mask_N[None, :], other=-float('inf')) + y = tl.exp(elem - max_ele[:, None]) + y_sum += tl.sum(y, axis=1) + tl.store(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, y,mask=mask_M[:, None] & mask_N[None, :]) + + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + y = tl.load(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, mask=mask_M[:, None] & mask_N[None, :], other=0) + y = y / y_sum[:,None] + tl.store(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, y, mask=mask_M[:, None] & mask_N[None, :]) + + +@triton.jit +def online_softmax(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr, stride_y0:tl.constexpr, stride_y1:tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + ''' + online softmax function for last dimension. + ''' + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_M + offsets_M = tl.arange(0, BLOCK_SIZE_M) + block_start + mask_M = offsets_M < M + + m_i = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - float('inf') + s_i = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i * BLOCK_SIZE_N + mask_N = offsets_N < N + + x = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, + mask=mask_M[:, None] & mask_N[None, :], other=-float('inf')) + + m_block = tl.max(x, axis=1) + m_new = tl.maximum(m_i, m_block) + + s_i *= tl.exp(m_i - m_new) + s_i += tl.sum(tl.exp(x - m_new[:, None]), axis=1) + + m_i = m_new + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i * BLOCK_SIZE_N + mask_N = offsets_N < N + + x = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, + mask=mask_M[:, None] & mask_N[None, :], other=0.) + + y = tl.exp(x - m_i[:, None]) / s_i[:, None] + + tl.store(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, + y, mask=mask_M[:, None] & mask_N[None, :]) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/__init__.py new file mode 100644 index 000000000..480cc116b --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/__init__.py @@ -0,0 +1 @@ +from tritontemplate.compiler.ops.softmax.softmax import Softmax \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py new file mode 100644 index 000000000..4c269d6b0 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -0,0 +1,69 @@ +from typing import List,Optional +import importlib +from math import prod + +import triton + +from tritontemplate.compiler.base import IntImm, Tensor, Operation +from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.kernel import TritonExecutor +from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.backend.cuda.utils.utils import shape2stride + +_exec_metadata = { + 'num_warps': 4, + 'num_stages': 1, +} + +class Softmax(Operation): + def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outputs: Optional[List[Tensor]] = None, name: Optional[str] = None): + super().__init__(inputs, outputs, name) + self._attrs['inputs'] = inputs + self._attrs['outputs'] = outputs + self._attrs['enable_online'] = enable_online + + assert dim == len(inputs[0].shape)-1 + self._deduce_output_shape() + + def _deduce_output_shape(self): + M = prod(self._attrs['inputs'][0].shape[:-1]) + N = self._attrs['inputs'][0].shape[-1] + + self._attrs['M']= M + self._attrs['N']= N + + if self._attrs['outputs'] is None: + self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)] + + def _gen_constants(self): + const_metadata={} + const_metadata['M']= self._attrs['M'] + const_metadata['N']= self._attrs['N'] + const_metadata['stride_x0'] = self._attrs['N'] + const_metadata['stride_x1'] = 1 + const_metadata['stride_y0'] = self._attrs['N'] + const_metadata['stride_y1'] = 1 + + const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M']) + const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N']) + + return const_metadata + + def _gen_exec_metadata(self): + return _exec_metadata.copy() + + def compile(self, target_name, workdir, enable_tf32)->TritonExecutor: + triton_kernel_name= 'online_softmax' if self._attrs['enable_online'] else 'softmax' + triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),triton_kernel_name) + gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),f'gen_grid_softmax') + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) + constants=self._gen_constants() + exec_metadata=self._gen_exec_metadata() + + num_warps=exec_metadata['num_warps'] + num_stages=exec_metadata['num_stages'] + config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) + triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) + + exec_grid=gen_grid(constants['M'],constants['BLOCK_SIZE_M']) + return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py index cc353eacb..7b0e6f21b 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py @@ -6,8 +6,8 @@ from tritontemplate.compiler.ops.bmm import Bmm from tritontemplate.compiler.compiler import compile_kernel -from tritontemplate.backend.cuda.bmm.bmm import bmm_bias as bmm_bias_kernel -from tritontemplate.backend.cuda.bmm.bmm import bmm as bmm_kernel +from tritontemplate.backend.cuda.bmm import bmm_bias as bmm_bias_kernel +from tritontemplate.backend.cuda.bmm import bmm as bmm_kernel def gen_bmm_bias(format, batch_size, M, N, K, stype): if format[0]=='r': @@ -99,7 +99,6 @@ def test_bmm_bias(format, batch_size, M, N, K, stype): test_kernel=gen_bmm_bias(format,batch_size,M,N,K,stype) test_kernel(a,b,bias,c_triton_aot) bmm_bias_kernel[grid](a,b,bias,c_triton_jit,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*bias.stride(),*c_triton_jit.stride(),64,64,64,False) - print(*b.stride()) torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=1e-2,rtol=1e-2) torch.testing.assert_close(c_triton_aot,c_torch,atol=1e-2,rtol=1e-2) diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py index 664374aae..e156a8a6a 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py @@ -2,8 +2,8 @@ import pytest import triton -from tritontemplate.backend.cuda.gemm.gemm import gemm_bias as gemm_bias_kernel -from tritontemplate.backend.cuda.gemm.gemm import gemm as gemm_kernel +from tritontemplate.backend.cuda.gemm import gemm_bias as gemm_bias_kernel +from tritontemplate.backend.cuda.gemm import gemm as gemm_kernel from tritontemplate.compiler.base import IntImm, Tensor from tritontemplate.compiler.ops.gemm import Gemm from tritontemplate.compiler.compiler import compile_kernel diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py new file mode 100644 index 000000000..610cb2ac7 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py @@ -0,0 +1,68 @@ +import torch +import pytest +import triton + +from tritontemplate.backend.cuda.softmax import softmax as kernel_softmax +from tritontemplate.backend.cuda.softmax import online_softmax as kernel_online_softmax +from tritontemplate.compiler.base import IntImm, Tensor +from tritontemplate.compiler.ops.softmax import Softmax +from tritontemplate.compiler.compiler import compile_kernel + +def gen_softmax(is_online,M,N,stype): + A=Tensor(name='A',dtype=stype,shape=[M,N]) + B=Tensor(name='B',dtype=stype,shape=[M,N]) + softmax_op=Softmax( + inputs=[A], + dim=1, + enable_online=is_online, + outputs=[B], + ) + kernel = compile_kernel(softmax_op,device='cuda') + return kernel + +FORMATS = [ + 'softmax', + 'online_softmax', +] +MATRIX_PARAMS = [ + (128, 31, 'float32'), + (2, 31, 'float16'), + (128, 31, 'float32'), + (31, 2, 'float16'), + (128, 128, 'float32'), + (257, 512, 'float16'), + (512, 257, 'float32'), + (256, 256, 'float16'), + (511, 512, 'float32'), + (128, 255, 'float16'), + (256, 256, 'float32'), +] + +@pytest.mark.parametrize('M, N, stype', MATRIX_PARAMS) +@pytest.mark.parametrize('format', FORMATS) +def test_softmax(M, N, stype, format): + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32=False + dtype=torch.float32 + else: + dtype=torch.float16 + a = torch.randn(M, N, dtype=dtype, device='cuda') + b_triton_jit = torch.empty_like(a) + b_triton_aot = torch.empty_like(a) + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']), + ) + + if format == 'online_softmax': + kernel_online_softmax[grid](a,b_triton_jit,M,N,*a.stride(),*b_triton_jit.stride(),64,64) + kernel = gen_softmax(True,M,N,stype) + kernel(a,b_triton_aot) + else: + kernel_softmax[grid](a,b_triton_jit,M,N,*a.stride(),*b_triton_jit.stride(),64,64) + kernel = gen_softmax(False,M,N,stype) + kernel(a,b_triton_aot) + + b_torch = torch.softmax(a, dim=-1) + torch.testing.assert_close(b_triton_jit, b_torch) + torch.testing.assert_close(b_triton_aot, b_torch) + From 8c1cb7d056d3d5c3feb7c7249a086e146b97387f Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 17 Jul 2025 11:31:26 +0800 Subject: [PATCH 21/26] [fix] softmax 4d test --- .../testing/cuda/test_softmax.py | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py index 610cb2ac7..952f62cad 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py @@ -8,12 +8,12 @@ from tritontemplate.compiler.ops.softmax import Softmax from tritontemplate.compiler.compiler import compile_kernel -def gen_softmax(is_online,M,N,stype): - A=Tensor(name='A',dtype=stype,shape=[M,N]) - B=Tensor(name='B',dtype=stype,shape=[M,N]) +def gen_softmax(is_online,batch,num_heads,seqlen,hidden_dim,stype): + A=Tensor(name='A',dtype=stype,shape=[batch,num_heads,seqlen,hidden_dim]) + B=Tensor(name='B',dtype=stype,shape=[batch,num_heads,seqlen,hidden_dim]) softmax_op=Softmax( inputs=[A], - dim=1, + dim=3, enable_online=is_online, outputs=[B], ) @@ -24,29 +24,35 @@ def gen_softmax(is_online,M,N,stype): 'softmax', 'online_softmax', ] -MATRIX_PARAMS = [ - (128, 31, 'float32'), - (2, 31, 'float16'), - (128, 31, 'float32'), - (31, 2, 'float16'), - (128, 128, 'float32'), - (257, 512, 'float16'), - (512, 257, 'float32'), - (256, 256, 'float16'), - (511, 512, 'float32'), - (128, 255, 'float16'), - (256, 256, 'float32'), -] +hidden_dim = [64, 128,] +num_heads = [8, 16,] +batch = [2, 4, 8] +seqlen = [63,66,127,129,255,257] +stype = ['float16', 'float32'] + +# Generate 10 random combinations +import random +test_cases = [] +for _ in range(10): + test_cases.append(( + random.choice(hidden_dim), + random.choice(num_heads), + random.choice(batch), + random.choice(seqlen), + random.choice(stype) + )) -@pytest.mark.parametrize('M, N, stype', MATRIX_PARAMS) +@pytest.mark.parametrize('hidden_dim, num_heads, batch, seqlen, stype', test_cases) @pytest.mark.parametrize('format', FORMATS) -def test_softmax(M, N, stype, format): +def test_softmax(batch,num_heads,seqlen,hidden_dim, stype, format): if stype == 'float32': torch.backends.cuda.matmul.allow_tf32=False dtype=torch.float32 else: dtype=torch.float16 - a = torch.randn(M, N, dtype=dtype, device='cuda') + a = torch.randn(batch,num_heads, seqlen, hidden_dim, dtype=dtype, device='cuda') + M=batch*seqlen*num_heads + N=hidden_dim b_triton_jit = torch.empty_like(a) b_triton_aot = torch.empty_like(a) grid = lambda META: ( @@ -54,15 +60,16 @@ def test_softmax(M, N, stype, format): ) if format == 'online_softmax': - kernel_online_softmax[grid](a,b_triton_jit,M,N,*a.stride(),*b_triton_jit.stride(),64,64) - kernel = gen_softmax(True,M,N,stype) + kernel_online_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,64,64) + kernel = gen_softmax(True,batch,num_heads,seqlen,hidden_dim, stype) kernel(a,b_triton_aot) else: - kernel_softmax[grid](a,b_triton_jit,M,N,*a.stride(),*b_triton_jit.stride(),64,64) - kernel = gen_softmax(False,M,N,stype) + kernel_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,64,64) + kernel = gen_softmax(False,batch,num_heads,seqlen,hidden_dim,stype) kernel(a,b_triton_aot) b_torch = torch.softmax(a, dim=-1) torch.testing.assert_close(b_triton_jit, b_torch) torch.testing.assert_close(b_triton_aot, b_torch) +test_softmax(12,8,144,128,'float32','online_softmax') \ No newline at end of file From 9dbb1318207b65ef629273b4e4d140e27f058fc6 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 17 Jul 2025 17:10:36 +0800 Subject: [PATCH 22/26] [feat] supported layernorm --- .../backend/cuda/layernorm/__init__.py | 1 + .../backend/cuda/layernorm/layernorm.py | 98 +++++++++++++++++++ .../compiler/ops/layernorm/__init__.py | 1 + .../compiler/ops/layernorm/layernorm.py | 83 ++++++++++++++++ .../compiler/ops/softmax/softmax.py | 3 +- .../tritontemplate/testing/cuda/test_bmm.py | 16 ++- .../tritontemplate/testing/cuda/test_gemm.py | 6 +- .../testing/cuda/test_layernorm.py | 84 ++++++++++++++++ .../testing/cuda/test_softmax.py | 63 +++++------- 9 files changed, 311 insertions(+), 44 deletions(-) create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py new file mode 100644 index 000000000..9bcd2adac --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py @@ -0,0 +1 @@ +from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py new file mode 100644 index 000000000..054680b5c --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py @@ -0,0 +1,98 @@ +import triton +import triton.language as tl + +def gen_grid_layernorm(M,BLOCK_SIZE_M): + grid = (triton.cdiv(M, BLOCK_SIZE_M),1,1) + return grid + +@triton.jit +def layernorm(x_ptr,y_ptr,M:tl.constexpr,N:tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr,stride_y0:tl.constexpr,stride_y1:tl.constexpr,BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,eps:tl.constexpr=1e-5): + ''' + layernorm function for last dimension. + ''' + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + + pid = tl.program_id(axis=0) + block_start = pid*BLOCK_SIZE_M + offsets_M = tl.arange(0,BLOCK_SIZE_M) + block_start + mask_M = offsets_M < M + ex = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + varx = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :]* stride_x1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0).to(tl.float32) + ex += tl.sum(x, axis=1) + varx += tl.sum(x * x, axis=1) + + ex = ex / N + varx = varx / N + varx -= (ex * ex) + normized = tl.sqrt(varx + eps) + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :]* stride_x1 + y_offsets = offsets_M[:, None] * stride_y0 + offsets_N[None, :]* stride_y1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0).to(tl.float32) + y = (x - ex[:, None]) / normized[:, None] + tl.store(y_ptr + y_offsets, y, mask=mask) + + +@triton.jit +def layernorm_weight_bias( + x_ptr, + bias_ptr, + weight_ptr, + y_ptr, + M: tl.constexpr, + N: tl.constexpr, + stride_x0: tl.constexpr, + stride_x1: tl.constexpr, + stride_y0: tl.constexpr, + stride_y1: tl.constexpr, + stride_weight: tl.constexpr, + stride_bias: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + eps: tl.constexpr = 1e-5, +): + ''' + layernorm function with weight and bias for last dimension. + ''' + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_M + offsets_M = tl.arange(0, BLOCK_SIZE_M) + block_start + mask_M = offsets_M < M + ex = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + varx = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0) + ex += tl.sum(x, axis=1) + varx += tl.sum(x * x, axis=1) + + ex = ex / N + varx = varx / N + varx -= (ex * ex) + normized = tl.sqrt(varx + eps) + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1 + y_offsets = offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0) + weight = tl.load(weight_ptr + offsets_N*stride_weight, mask=mask_N, other=0.0).to(tl.float32) + bias = tl.load(bias_ptr + offsets_N*stride_bias, mask=mask_N, other=0.0).to(tl.float32) + y = (x - ex[:, None]) * weight[None, :] / normized[:, None] + bias[None, :] + tl.store(y_ptr + y_offsets, y, mask=mask) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/__init__.py new file mode 100644 index 000000000..2e129c98f --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/__init__.py @@ -0,0 +1 @@ +from tritontemplate.compiler.ops.layernorm.layernorm import Layernorm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py new file mode 100644 index 000000000..29f9bf483 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py @@ -0,0 +1,83 @@ +from typing import List,Optional +import importlib +from math import prod + +import triton + +from tritontemplate.compiler.base import IntImm, Tensor, Operation +from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.kernel import TritonExecutor +from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.backend.cuda.utils.utils import shape2stride + +_exec_metadata = { + 'num_warps': 4, + 'num_stages': 1, +} + +class Layernorm(Operation): + def __init__( + self, + inputs: List[Tensor],# [x,bias(beta),weight(gamma)] + axis:int, + eps:float = 1e-5, + outputs: Optional[List[Tensor]] = None, + name: Optional[str] = None, + ) -> None: + super().__init__(inputs, outputs, name) + assert axis == len(inputs[0].shape)-1, f'only support last axis now' + self._attrs['axis'] = axis + self._attrs['eps'] = eps + self._attrs['inputs'] = inputs + self._attrs['outputs'] = outputs + + self._deduce_output_shape() + + def _deduce_output_shape(self): + M = prod(self._attrs['inputs'][0].shape[:-1]) + N = self._attrs['inputs'][0].shape[-1] + self._attrs['M'] = M + self._attrs['N'] = N + self._attrs['with_weight_bias']=len(self._attrs['inputs']) == 3 + if self._attrs['outputs'] is None: + self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)] + + def _gen_constants(self): + const_metadata={} + const_metadata['M']= self._attrs['M'] + const_metadata['N']= self._attrs['N'] + const_metadata['stride_x0'] = self._attrs['N'] + const_metadata['stride_x1'] = 1 + const_metadata['stride_y0'] = self._attrs['N'] + const_metadata['stride_y1'] = 1 + const_metadata['eps'] = self._attrs['eps'] + + const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M']) + const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N']) + + if self._attrs['with_weight_bias']: + const_metadata['stride_weight'] = 1 + const_metadata['stride_bias'] = 1 + + return const_metadata + + def _gen_exec_metadata(self): + return _exec_metadata.copy() + + def compile(self, target_name, workdir,enable_tf32)->TritonExecutor: + triton_kernel_name= 'layernorm_weight_bias' if self._attrs['with_weight_bias'] else 'layernorm' + triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),triton_kernel_name) + gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),f'gen_grid_layernorm') + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) + constants=self._gen_constants() + exec_metadata=self._gen_exec_metadata() + + num_warps=exec_metadata['num_warps'] + num_stages=exec_metadata['num_stages'] + config = config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) + triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) + + exec_grid=gen_grid(constants['M'],constants['BLOCK_SIZE_M']) + return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants) + + \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py index 4c269d6b0..4d662178e 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -18,11 +18,12 @@ class Softmax(Operation): def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outputs: Optional[List[Tensor]] = None, name: Optional[str] = None): super().__init__(inputs, outputs, name) + assert dim == len(inputs[0].shape)-1, f'only support last axis now' + self._attrs['dim'] = dim self._attrs['inputs'] = inputs self._attrs['outputs'] = outputs self._attrs['enable_online'] = enable_online - assert dim == len(inputs[0].shape)-1 self._deduce_output_shape() def _deduce_output_shape(self): diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py index 7b0e6f21b..e7c1e6b09 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py @@ -99,8 +99,12 @@ def test_bmm_bias(format, batch_size, M, N, K, stype): test_kernel=gen_bmm_bias(format,batch_size,M,N,K,stype) test_kernel(a,b,bias,c_triton_aot) bmm_bias_kernel[grid](a,b,bias,c_triton_jit,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*bias.stride(),*c_triton_jit.stride(),64,64,64,False) - torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=1e-2,rtol=1e-2) - torch.testing.assert_close(c_triton_aot,c_torch,atol=1e-2,rtol=1e-2) + + atol = 1e-2 if dtype == torch.float16 else 1e-4 + rtol = 1e-2 if dtype == torch.float16 else 1e-4 + + torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=atol,rtol=rtol) + torch.testing.assert_close(c_triton_aot,c_torch,atol=atol,rtol=rtol) @pytest.mark.parametrize('format', FORMATS) @pytest.mark.parametrize( @@ -137,6 +141,10 @@ def test_bmm(format, batch_size, M, N, K, stype): bmm_kernel[grid](a,b,c_triton_jit,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*c_triton_jit.stride(),64,64,64,False) kernel=gen_bmm(format,batch_size,M,N,K,stype) kernel(a,b,c_triton_aot) - torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=1e-2,rtol=1e-2) - torch.testing.assert_close(c_triton_aot,c_torch,atol=1e-2,rtol=1e-2) + + atol = 1e-2 if dtype == torch.float16 else 1e-4 + rtol = 1e-2 if dtype == torch.float16 else 1e-4 + + torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=atol,rtol=rtol) + torch.testing.assert_close(c_triton_aot,c_torch,atol=atol,rtol=rtol) diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py index e156a8a6a..07aaebbed 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py @@ -181,6 +181,8 @@ def test_gemm_relu(format, M, N, K, stype): kernel = gen_gemm(format, M, N, K, stype) kernel(A, B, c_triton_aot) - assert torch.allclose(c_triton_aot, c_triton_jit, atol=1e-2, rtol=1e-2) - assert torch.allclose(pytorch_result, c_triton_jit, atol=1e-2, rtol=1e-2) + atol = 1e-2 if dtype == torch.float16 else 1e-4 + rtol = 1e-2 if dtype == torch.float16 else 1e-4 + assert torch.allclose(c_triton_aot, c_triton_jit, atol=atol, rtol=rtol) + assert torch.allclose(pytorch_result, c_triton_jit, atol=atol, rtol=rtol) diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py new file mode 100644 index 000000000..766725ab9 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py @@ -0,0 +1,84 @@ +import torch +import pytest +import triton + +from tritontemplate.backend.cuda.layernorm import layernorm as kernel_layernorm +from tritontemplate.backend.cuda.layernorm import layernorm_weight_bias as kernel_layernorm_weight_bias +from tritontemplate.compiler.base import IntImm, Tensor +from tritontemplate.compiler.ops.layernorm import Layernorm +from tritontemplate.compiler.compiler import compile_kernel + +def gen_layernorm(with_weight_bias,batch,seq_len,hidden_size,stype): + X=Tensor(name='X',shape=(batch,seq_len,hidden_size),dtype=stype) + Y=Tensor(name='Y',shape=(batch,seq_len,hidden_size),dtype=stype) + if with_weight_bias: + W=Tensor(name='W',shape=(hidden_size,),dtype=stype) + B=Tensor(name='B',shape=(hidden_size,),dtype=stype) + op=Layernorm( + inputs=[X,W,B], + outputs=[Y], + axis=2, + eps=1e-5) + else: + op=Layernorm( + inputs=[X], + outputs=[Y], + axis=2, + eps=1e-5) + return op + +MATRIX_PARAMS = [ + (2, 128, 31, 'float32'), + (128, 2, 31, 'float16'), + (128, 128, 31, 'float32'), + (128, 31, 32, 'float16'), + (128, 128, 128, 'float32'), + (128, 257, 512, 'float16'), + (128, 512, 257, 'float32'), + (127, 256, 256, 'float16'), + (128, 511, 512, 'float32'), + (256, 128, 255, 'float16'), + (1, 256, 256, 'float32'), +] +FORMATS = ['layernorm','layernorm_weight_bias'] + +@pytest.mark.parametrize('batch,seq_len,hidden_size,stype',MATRIX_PARAMS) +@pytest.mark.parametrize('format',FORMATS) +def test_layernorm(batch,seq_len,hidden_size,stype,format): + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + x = torch.randn(batch,seq_len,hidden_size,dtype=dtype,device='cuda') + y_triton_jit = torch.empty(batch,seq_len,hidden_size,dtype=dtype,device='cuda') + y_triton_aot = torch.empty(batch,seq_len,hidden_size,dtype=dtype,device='cuda') + + M= batch*seq_len + N = hidden_size + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']), + ) + + if format == 'layernorm': + y_torch=torch.nn.functional.layer_norm(x,(N,),eps=1e-5) + kernel_layernorm[grid](x,y_triton_jit,M,N,N,1,N,1,64,64,1e-5) + kernel = gen_layernorm(False,batch,seq_len,hidden_size,stype) + kernel_aot = compile_kernel(kernel) + kernel_aot(x,y_triton_aot) + + else: + weight = torch.randn(hidden_size,dtype=dtype,device='cuda') + bias = torch.randn(hidden_size,dtype=dtype,device='cuda') + y_torch = torch.nn.functional.layer_norm(x,(N,),weight,bias,eps=1e-5) + kernel_layernorm_weight_bias[grid](x,bias,weight,y_triton_jit,M,N,N,1,N,1,1,1,64,64,1e-5) + kernel = gen_layernorm(True,batch,seq_len,hidden_size,stype) + kernel_aot = compile_kernel(kernel) + kernel_aot(x,bias,weight,y_triton_aot) + + atol = 1e-2 if dtype == torch.float16 else 1e-4 + rtol = 1e-2 if dtype == torch.float16 else 1e-4 + torch.testing.assert_close(y_triton_jit,y_torch,atol=atol,rtol=rtol) + torch.testing.assert_close(y_triton_aot,y_torch,atol=atol,rtol=rtol) + diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py index 952f62cad..9dd63cf44 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py @@ -8,9 +8,9 @@ from tritontemplate.compiler.ops.softmax import Softmax from tritontemplate.compiler.compiler import compile_kernel -def gen_softmax(is_online,batch,num_heads,seqlen,hidden_dim,stype): - A=Tensor(name='A',dtype=stype,shape=[batch,num_heads,seqlen,hidden_dim]) - B=Tensor(name='B',dtype=stype,shape=[batch,num_heads,seqlen,hidden_dim]) +def gen_softmax(is_online,batch,num_heads,seqlen,hidden_dim): + A=Tensor(name='A',dtype='float16',shape=[batch,num_heads,seqlen,hidden_dim]) + B=Tensor(name='B',dtype='float32',shape=[batch,num_heads,seqlen,hidden_dim]) softmax_op=Softmax( inputs=[A], dim=3, @@ -24,52 +24,41 @@ def gen_softmax(is_online,batch,num_heads,seqlen,hidden_dim,stype): 'softmax', 'online_softmax', ] -hidden_dim = [64, 128,] -num_heads = [8, 16,] -batch = [2, 4, 8] -seqlen = [63,66,127,129,255,257] -stype = ['float16', 'float32'] +MATRIX_PARAMS = [ + (128, 16, 8, 255), + (64, 8, 8, 255), + (64, 16, 2, 66), + (128, 16, 8, 257), + (128, 8, 4, 127), + (128, 8, 8, 63), + (64, 8, 4, 129), + (128, 8, 4, 255), + (64, 8, 4, 63), + (64, 8, 2, 255) + ] -# Generate 10 random combinations -import random -test_cases = [] -for _ in range(10): - test_cases.append(( - random.choice(hidden_dim), - random.choice(num_heads), - random.choice(batch), - random.choice(seqlen), - random.choice(stype) - )) - -@pytest.mark.parametrize('hidden_dim, num_heads, batch, seqlen, stype', test_cases) +@pytest.mark.parametrize('hidden_dim, num_heads, batch, seqlen', MATRIX_PARAMS) @pytest.mark.parametrize('format', FORMATS) -def test_softmax(batch,num_heads,seqlen,hidden_dim, stype, format): - if stype == 'float32': - torch.backends.cuda.matmul.allow_tf32=False - dtype=torch.float32 - else: - dtype=torch.float16 - a = torch.randn(batch,num_heads, seqlen, hidden_dim, dtype=dtype, device='cuda') +def test_softmax(batch,num_heads,seqlen,hidden_dim, format): + + a = torch.randn(batch,num_heads, seqlen, hidden_dim, dtype=torch.float16, device='cuda') M=batch*seqlen*num_heads N=hidden_dim - b_triton_jit = torch.empty_like(a) - b_triton_aot = torch.empty_like(a) + b_triton_jit = torch.empty(batch,num_heads, seqlen, hidden_dim, dtype=torch.float32, device='cuda') + b_triton_aot = torch.empty(batch,num_heads, seqlen, hidden_dim, dtype=torch.float32, device='cuda') grid = lambda META: ( triton.cdiv(M, META['BLOCK_SIZE_M']), ) if format == 'online_softmax': kernel_online_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,64,64) - kernel = gen_softmax(True,batch,num_heads,seqlen,hidden_dim, stype) + kernel = gen_softmax(True,batch,num_heads,seqlen,hidden_dim) kernel(a,b_triton_aot) else: kernel_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,64,64) - kernel = gen_softmax(False,batch,num_heads,seqlen,hidden_dim,stype) + kernel = gen_softmax(False,batch,num_heads,seqlen,hidden_dim) kernel(a,b_triton_aot) - b_torch = torch.softmax(a, dim=-1) - torch.testing.assert_close(b_triton_jit, b_torch) - torch.testing.assert_close(b_triton_aot, b_torch) - -test_softmax(12,8,144,128,'float32','online_softmax') \ No newline at end of file + b_torch = torch.softmax(a, dim=-1).to(torch.float32) + torch.testing.assert_close(b_triton_jit, b_torch,atol=1e-2,rtol=1e-2) + torch.testing.assert_close(b_triton_aot, b_torch,atol=1e-2,rtol=1e-2) From bbaa6c9f8898725c3b024abd5bc4635337430272 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 17 Jul 2025 21:01:53 +0800 Subject: [PATCH 23/26] [feat] supported transpose --- .../backend/cuda/transpose/__init__.py | 2 + .../backend/cuda/transpose/transpose_0213.py | 46 ++++++++ .../backend/cuda/transpose/transpose_10.py | 36 ++++++ .../python/tritontemplate/compiler/base.py | 8 +- .../tritontemplate/compiler/ops/bmm/bmm.py | 2 - .../tritontemplate/compiler/ops/gemm/gemm.py | 2 - .../compiler/ops/layernorm/layernorm.py | 2 - .../compiler/ops/softmax/softmax.py | 2 - .../compiler/ops/transpose/__init__.py | 1 + .../compiler/ops/transpose/transpose.py | 108 ++++++++++++++++++ .../testing/cuda/test_transpose.py | 99 ++++++++++++++++ 11 files changed, 295 insertions(+), 13 deletions(-) create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py create mode 100644 external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/__init__.py create mode 100644 external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py create mode 100644 external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py new file mode 100644 index 000000000..da1c07b90 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py @@ -0,0 +1,2 @@ +from tritontemplate.backend.cuda.transpose.transpose_10 import transpose_10,gen_grid_transpose_10 +from tritontemplate.backend.cuda.transpose.transpose_0213 import transpose_0213,gen_grid_transpose_0213 diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py new file mode 100644 index 000000000..87cfe56b5 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py @@ -0,0 +1,46 @@ +import triton +import triton.language as tl + +def gen_grid_transpose_0213(D0, D1, D2, D3, BLOCK_SIZE_D1, BLOCK_SIZE_D2): + """ + Generates the grid for a transpose kernel. + """ + return (D0 * D3, triton.cdiv(D2, BLOCK_SIZE_D2), triton.cdiv(D1, BLOCK_SIZE_D1)) + +#TODO: rewrite for contiguous D3 reading and storing +@triton.jit +def transpose_0213(x, y, + D0:tl.constexpr, D1:tl.constexpr, D2:tl.constexpr, D3:tl.constexpr, + stride_x0:tl.constexpr, stride_x1:tl.constexpr, stride_x2:tl.constexpr, stride_x3:tl.constexpr, + stride_y0:tl.constexpr, stride_y1:tl.constexpr, stride_y2:tl.constexpr, stride_y3:tl.constexpr, + BLOCK_SIZE_D1:tl.constexpr,BLOCK_SIZE_D2:tl.constexpr): + """ + Transpose a matrix in the 0213 layout. + """ + pid_d0d3 = tl.program_id(0) + pid_d2_chunk = tl.program_id(1) + pid_d1_chunk = tl.program_id(2) + + pid_d0 = pid_d0d3 // D3 + pid_d3 = pid_d0d3 % D3 + + offs_d1 = pid_d1_chunk * BLOCK_SIZE_D1 + tl.arange(0, BLOCK_SIZE_D1) + offs_d2 = pid_d2_chunk * BLOCK_SIZE_D2 + tl.arange(0, BLOCK_SIZE_D2) + + x_ptr = x + pid_d0 * stride_x0 + \ + offs_d1[:, None] * stride_x1 + \ + offs_d2[None, :] * stride_x2 + \ + pid_d3 * stride_x3 + + y_ptr = y + pid_d0 * stride_y0 + \ + offs_d2[:, None] * stride_y1 + \ + offs_d1[None, :] * stride_y2 + \ + pid_d3 * stride_y3 + + + mask_x = (offs_d1[:, None] < D1) & (offs_d2[None, :] < D2) + mask_y = (offs_d2[:, None] < D2) & (offs_d1[None, :] < D1) + + tile = tl.load(x_ptr, mask=mask_x, other=0.0) + transposed_tile = tl.trans(tile) + tl.store(y_ptr, transposed_tile, mask=mask_y) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py new file mode 100644 index 000000000..44187666b --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py @@ -0,0 +1,36 @@ +import triton +import triton.language as tl + +def gen_grid_transpose_10(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + """ + Generates the grid for a GEMM kernel. + """ + return ( + triton.cdiv(M, BLOCK_SIZE_M), + triton.cdiv(N, BLOCK_SIZE_N), + 1) + +@triton.jit +def transpose_10( + x_ptr, y_ptr, + M: tl.constexpr, N: tl.constexpr, + stride_x0:tl.constexpr,stride_x1:tl.constexpr, + stride_y0:tl.constexpr,stride_y1:tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + block_start_m = pid_m * BLOCK_SIZE_M + block_start_n = pid_n * BLOCK_SIZE_N + offset_M = tl.arange(0, BLOCK_SIZE_M) + block_start_m + offset_N = tl.arange(0, BLOCK_SIZE_N) + block_start_n + + input_offsets = (offset_M[:, None] * stride_x0 + offset_N[None, :]) * stride_x1 + mask_M = offset_M < M + mask_N = offset_N < N + mask = mask_M[:, None] & mask_N[None, :] + + x = tl.load(x_ptr + input_offsets, mask=mask) + output_offsets = (offset_N[:, None] * stride_y0 + offset_M[None, :]) * stride_y1 + output_mask = mask_M[None, :] & mask_N[:, None] + x = tl.trans(x) + tl.store(y_ptr + output_offsets, x, mask=output_mask) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py index c3bc5bf70..8a69f30df 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/base.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py @@ -85,15 +85,13 @@ class Operation(BaseType): def __init__( self, inputs: List[BaseType], - outputs: Optional[List[BaseType]], + outputs: Optional[List[BaseType]] = None, name: Optional[str] = None, ) -> None: super().__init__() self._attrs['inputs'] = inputs - if name is not None: - self._attrs['name'] = name - if outputs is not None: - self._attrs['outputs'] = outputs + self._attrs['outputs'] = outputs + self._attrs['name'] = name @property def name(self) -> Optional[str]: diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py index 13ac1b531..3563d6e49 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py @@ -27,8 +27,6 @@ def __init__( super().__init__(inputs, outputs,name) self.layout = layout self.is_bias = is_bias - self._attrs['inputs'] = inputs - self._attrs['outputs'] = outputs self._deduce_output_shape() def _deduce_output_shape(self): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 5a8444237..6b88d9c48 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -35,8 +35,6 @@ def __init__( self.layout = layout self.is_bias= is_bias self._attrs['activation'] = activation - self._attrs['inputs'] = inputs - self._attrs['outputs'] = outputs self._deduce_output_shape() diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py index 29f9bf483..68aec0504 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py @@ -28,8 +28,6 @@ def __init__( assert axis == len(inputs[0].shape)-1, f'only support last axis now' self._attrs['axis'] = axis self._attrs['eps'] = eps - self._attrs['inputs'] = inputs - self._attrs['outputs'] = outputs self._deduce_output_shape() diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py index 4d662178e..3fd19fc2d 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -20,8 +20,6 @@ def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outpu super().__init__(inputs, outputs, name) assert dim == len(inputs[0].shape)-1, f'only support last axis now' self._attrs['dim'] = dim - self._attrs['inputs'] = inputs - self._attrs['outputs'] = outputs self._attrs['enable_online'] = enable_online self._deduce_output_shape() diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/__init__.py new file mode 100644 index 000000000..aef6dc273 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/__init__.py @@ -0,0 +1 @@ +from tritontemplate.compiler.ops.transpose.transpose import Transpose \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py new file mode 100644 index 000000000..06a303668 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py @@ -0,0 +1,108 @@ +from typing import List,Optional +import importlib + +import triton + +from tritontemplate.compiler.base import IntImm, Tensor, Operation +from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.kernel import TritonExecutor +from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.backend.cuda.utils.utils import shape2stride + +_supported_methods = ['10','0213'] + +_exec_metadata = { + 'num_warps': 4, + 'num_stages': 1, +} + +class Transpose(Operation): + def __init__(self, + inputs: List[Tensor], + method: str, + outputs: Optional[List[Tensor]] = None, + name: Optional[str] = None): + super().__init__(inputs, outputs, name) + assert method in _supported_methods, f"Unsupported method {method}" + self._attrs['method'] = method + + self._deduce_output_shape() + + def _deduce_output_shape(self): + input_shape = self._attrs['inputs'][0].shape + output_shape = [] + for i in self._attrs['method']: + output_shape.append(input_shape[int(i)]) + if self._attrs['outputs'] is None: + self._attrs['outputs'] = [Tensor(output_shape, self._attrs['inputs'][0].dtype)] + else: + assert self._attrs['outputs'][0].shape == output_shape, f"Transpose op output shape {self._attrs['outputs'][0].shape} does not match expected shape {output_shape}" + + def _gen_constants_10(self): + const_metadata={} + M,N=self._attrs['inputs'][0].shape + + const_metadata['M'] = M + const_metadata['N'] = N + const_metadata['stride_x0'] = N + const_metadata['stride_x1'] = 1 + const_metadata['stride_y0'] = M + const_metadata['stride_y1'] = 1 + + const_metadata['BLOCK_SIZE_M'] = self._block_size(M) + const_metadata['BLOCK_SIZE_N'] = self._block_size(N) + + return const_metadata + + def _gen_grid_10(self,target_name,const_metadata): + gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_10') + return gen_grid(const_metadata['M'],const_metadata['N'],const_metadata['BLOCK_SIZE_M'],const_metadata['BLOCK_SIZE_N']) + + def _gen_constants_0213(self): + const_metadata={} + D0,D1,D2,D3=self._attrs['inputs'][0].shape + const_metadata['D0'] = D0 + const_metadata['D1'] = D1 + const_metadata['D2'] = D2 + const_metadata['D3'] = D3 + + const_metadata['stride_x0'],const_metadata['stride_x1'],const_metadata['stride_x2'],const_metadata['stride_x3'] = shape2stride(self._attrs['inputs'][0].shape) + const_metadata['stride_y0'],const_metadata['stride_y1'],const_metadata['stride_y2'],const_metadata['stride_y3']= shape2stride(self._attrs['outputs'][0].shape) + + const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1) + const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2) + + return const_metadata + + def _gen_grid_0213(self,target_name,const_metadata): + gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_0213') + return gen_grid(const_metadata['D0'],const_metadata['D1'],const_metadata['D2'],const_metadata['D3'],const_metadata['BLOCK_SIZE_D1'],const_metadata['BLOCK_SIZE_D2']) + + + def _gen_exec_metadata(self): + return _exec_metadata.copy() + + + def compile(self, target_name, workdir, enable_tf32)->TritonExecutor: + triton_kernel_name= 'transpose_' + self._attrs['method'] + triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),triton_kernel_name) + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) + exec_metadata=self._gen_exec_metadata() + + if self._attrs['method'] == '10': + constants=self._gen_constants_10() + exec_grid = self._gen_grid_10(target_name,constants) + elif self._attrs['method'] == '0213': + constants=self._gen_constants_0213() + exec_grid = self._gen_grid_0213(target_name,constants) + else: + raise ValueError(f"Unsupported method {self._attrs['method']}") + + num_warps=exec_metadata['num_warps'] + num_stages=exec_metadata['num_stages'] + config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) + + triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) + + + return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py new file mode 100644 index 000000000..f967cb36b --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py @@ -0,0 +1,99 @@ +import torch +import pytest +import triton + +from tritontemplate.backend.cuda.transpose import transpose_10 as kernel_transpose_10 +from tritontemplate.backend.cuda.transpose import transpose_0213 as kernel_transpose_0213 +from tritontemplate.backend.cuda.transpose import gen_grid_transpose_10, gen_grid_transpose_0213 +from tritontemplate.compiler.base import IntImm, Tensor +from tritontemplate.compiler.ops.transpose import Transpose +from tritontemplate.compiler.compiler import compile_kernel + +FORMATS = [ + 'transpose_10', + 'transpose_0213', +] + +MATRIX_PARAMS = [ + (256,128,'float16'), + (255,257,'float32'), + (127,129,'float16'), + (2304,768,'float16'), +] + +TENSOR4D_PARAMS = [ + (16, 64, 128, 32, 'float16'), + (4, 127, 255, 63, 'float32'), + (8, 256, 512, 64, 'float16'), + (8, 1024,8,96, 'float16'), +] + +def gen_transpose_10(M,N,stype): + X = Tensor([M, N], stype) + Y = Tensor([N, M], stype) + op = Transpose([X], '10', [Y]) + return op + +@pytest.mark.parametrize( + 'M, N, stype', + MATRIX_PARAMS +) +def test_transpose10(M, N, stype): + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + x = torch.randn(M, N, dtype=dtype, device='cuda') + y_triton_jit = torch.empty(N, M, dtype=dtype, device='cuda') + y_triton_aot = torch.empty(N, M, dtype=dtype, device='cuda') + BLOCK_M=64 + BLOCK_N=64 + y=x.transpose(0,1).contiguous() + grid = gen_grid_transpose_10(M, N, BLOCK_M,BLOCK_N) + kernel_transpose_10[grid](x, y_triton_jit, M, N, *x.stride(),*y_triton_jit.stride(), BLOCK_M, BLOCK_N) + + kernel = gen_transpose_10(M,N,stype) + kernel_aot = compile_kernel(kernel) + kernel_aot(x, y_triton_aot) + + torch.testing.assert_close(y_triton_jit, y) + torch.testing.assert_close(y_triton_aot, y) + +def gen_transpose_0213(D0,D1,D2,D3,stype): + X = Tensor([D0, D1, D2, D3], stype) + Y = Tensor([D0, D2, D1, D3], stype) + op = Transpose([X], '0213', [Y]) + return op + +@pytest.mark.parametrize( + 'D0, D1, D2, D3, stype', + TENSOR4D_PARAMS +) +def test_transpose0213(D0,D1,D2,D3, stype): + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + x = torch.randn(D0,D1,D2,D3, dtype=dtype, device='cuda') + y_triton_jit = torch.empty(D0,D2,D1,D3, dtype=dtype, device='cuda') + y_triton_aot = torch.empty(D0,D2,D1,D3, dtype=dtype, device='cuda') + + BLOCK_D1=32 + BLOCK_D2=32 + + y = x.permute(0,2,1,3).contiguous() + grid = gen_grid_transpose_0213(D0, D1, D2, D3, BLOCK_D1, BLOCK_D2) + kernel_transpose_0213[grid](x, y_triton_jit, + D0, D1, D2, D3, + *x.stride(), *y_triton_jit.stride(), + BLOCK_D1, BLOCK_D2) + + kernel = gen_transpose_0213(D0,D1,D2,D3,stype) + kernel_aot = compile_kernel(kernel) + kernel_aot(x, y_triton_aot) + + torch.testing.assert_close(y_triton_jit, y) + torch.testing.assert_close(y_triton_aot, y) + From 8f8bcf87c29bda56fb381c62933700d4c48133fc Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 18 Jul 2025 10:35:33 +0800 Subject: [PATCH 24/26] [bug] TritonTemplate align to byteir --- .../tritontemplate/compiler/ops/__init__.py | 6 +++++- .../compiler/ops/layernorm/layernorm.py | 6 +++--- .../compiler/ops/softmax/softmax.py | 4 +++- .../compiler/ops/transpose/transpose.py | 18 +++++++++--------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py index 3b1fb08a3..a421c9891 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/__init__.py @@ -1 +1,5 @@ -from tritontemplate.compiler.ops.gemm import Gemm \ No newline at end of file +from tritontemplate.compiler.ops.gemm import Gemm +from tritontemplate.compiler.ops.bmm import Bmm +from tritontemplate.compiler.ops.transpose import Transpose +from tritontemplate.compiler.ops.softmax import Softmax +from tritontemplate.compiler.ops.layernorm import Layernorm diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py index 68aec0504..4e2641884 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py @@ -19,14 +19,14 @@ class Layernorm(Operation): def __init__( self, inputs: List[Tensor],# [x,bias(beta),weight(gamma)] - axis:int, + axises:List[int], eps:float = 1e-5, outputs: Optional[List[Tensor]] = None, name: Optional[str] = None, ) -> None: super().__init__(inputs, outputs, name) - assert axis == len(inputs[0].shape)-1, f'only support last axis now' - self._attrs['axis'] = axis + assert len(axises)==1 and axises[0] == len(inputs[0].shape)-1, f'Only last axis normalization is supported (axis={axis}, input shape={inputs[0].shape})' + self._attrs['axis'] = axises[0] self._attrs['eps'] = eps self._deduce_output_shape() diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py index 3fd19fc2d..028947e61 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -32,7 +32,9 @@ def _deduce_output_shape(self): self._attrs['N']= N if self._attrs['outputs'] is None: - self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)] + # Return float32 + self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype='float32')] + # self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)] def _gen_constants(self): const_metadata={} diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py index 06a303668..1a37495c5 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py @@ -9,7 +9,7 @@ from tritontemplate.compiler.utils import get_warpsize from tritontemplate.backend.cuda.utils.utils import shape2stride -_supported_methods = ['10','0213'] +_supported_permutations = ['10','0213'] _exec_metadata = { 'num_warps': 4, @@ -19,19 +19,19 @@ class Transpose(Operation): def __init__(self, inputs: List[Tensor], - method: str, + permutation: str, outputs: Optional[List[Tensor]] = None, name: Optional[str] = None): super().__init__(inputs, outputs, name) - assert method in _supported_methods, f"Unsupported method {method}" - self._attrs['method'] = method + assert permutation in _supported_permutations, f"Unsupported permutation {permutation}" + self._attrs['permutation'] = permutation self._deduce_output_shape() def _deduce_output_shape(self): input_shape = self._attrs['inputs'][0].shape output_shape = [] - for i in self._attrs['method']: + for i in self._attrs['permutation']: output_shape.append(input_shape[int(i)]) if self._attrs['outputs'] is None: self._attrs['outputs'] = [Tensor(output_shape, self._attrs['inputs'][0].dtype)] @@ -84,19 +84,19 @@ def _gen_exec_metadata(self): def compile(self, target_name, workdir, enable_tf32)->TritonExecutor: - triton_kernel_name= 'transpose_' + self._attrs['method'] + triton_kernel_name= 'transpose_' + self._attrs['permutation'] triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),triton_kernel_name) signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) exec_metadata=self._gen_exec_metadata() - if self._attrs['method'] == '10': + if self._attrs['permutation'] == '10': constants=self._gen_constants_10() exec_grid = self._gen_grid_10(target_name,constants) - elif self._attrs['method'] == '0213': + elif self._attrs['permutation'] == '0213': constants=self._gen_constants_0213() exec_grid = self._gen_grid_0213(target_name,constants) else: - raise ValueError(f"Unsupported method {self._attrs['method']}") + raise ValueError(f"Unsupported permutation {self._attrs['permutation']}") num_warps=exec_metadata['num_warps'] num_stages=exec_metadata['num_stages'] From e6ed72a88eed7f9a918dc72ca1bd551a07bcdfb2 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 18 Jul 2025 10:36:00 +0800 Subject: [PATCH 25/26] [bug] TritonTemplate align to byteir --- .../cat/ir_translator/backend/tit_registry.py | 93 ++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py index da3746312..11e622c76 100644 --- a/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py +++ b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py @@ -36,4 +36,95 @@ def _dispatch_cat_gemm_rcr_bias_relu(op, inputs): @TRITONTemplateIRTranslator.register("cat.gemm_rcr_bias") def _dispatch_cat_gemm_rcr_bias(op, inputs): Y = tit_ops.Gemm(inputs=inputs, layout="rcr", is_bias=True) - return [Y] \ No newline at end of file + return [Y] + +@TRITONTemplateIRTranslator.register("cat.gemm_rcr_relu") +def _dispatch_cat_gemm_rcr_relu(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rcr", activation="relu") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.gemm_rcr") +def _dispatch_cat_gemm_rcr(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rcr") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.gemm_rrr_bias_relu") +def _dispatch_cat_gemm_rrr_bias_relu(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rrr", is_bias=True, activation="relu") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.gemm_rrr_bias") +def _dispatch_cat_gemm_rrr_bias(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rrr", is_bias=True) + return [Y] + +@TRITONTemplateIRTranslator.register("cat.gemm_rrr_relu") +def _dispatch_cat_gemm_rrr_relu(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rrr", activation="relu") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.gemm_rrr") +def _dispatch_cat_gemm_rrr(op, inputs): + Y = tit_ops.Gemm(inputs=inputs, layout="rrr") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_rrr") +def _dispatch_cat_bmm_rrr(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="rrr") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_rrr_add") +def _dispatch_cat_bmm_rrr_add(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="rrr", is_bias=True) + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_rcr") +def _dispatch_cat_bmm_rcr(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="rcr") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_rcr_add") +def _dispatch_cat_bmm_rcr_add(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="rcr", is_bias=True) + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_crr") +def _dispatch_cat_bmm_crr(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="crr") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_crr_add") +def _dispatch_cat_bmm_crr_add(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="crr", is_bias=True) + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_ccr") +def _dispatch_cat_bmm_ccr(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="ccr") + return [Y] + +@TRITONTemplateIRTranslator.register("cat.bmm_ccr_add") +def _dispatch_cat_bmm_ccr_add(op, inputs): + Y = tit_ops.Bmm(inputs=inputs, layout="ccr", is_bias=True) + return [Y] + +@TRITONTemplateIRTranslator.register("cat.softmax") +def _dispatch_cat_softmax(op, inputs): + dim = mlir_attr_to_pyobj(op.attributes["dim"]) + Y = tit_ops.Softmax(inputs=inputs,dim=dim,enable_online=True) + return [Y] + +@TRITONTemplateIRTranslator.register("cat.layernorm") +def _dispatch_cat_layernorm(op, inputs): + axises = mlir_attr_to_pyobj(op.attributes["axis"]) + eps = mlir_attr_to_pyobj(op.attributes["epsilon"]) + Y=tit_ops.Layernorm(inputs=inputs,axises=axises,eps=eps) + return [Y] + +@TRITONTemplateIRTranslator.register("mhlo.transpose") +def _dispatch_mhlo_transpose(op, inputs): + dims = mlir_attr_to_pyobj(op.attributes["permutation"]) + dims = dims.tolist() + dims_str = ''.join(map(str, dims)) + Y=tit_ops.Transpose(inputs=inputs,permutation=dims_str) + return [Y] From 1e1ac178632160f8f915d90a744fd845e7c09579 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 18 Jul 2025 13:58:51 +0800 Subject: [PATCH 26/26] [bug] fixed e2e nanogpt bug --- compiler/lib/Conversion/ToTIT/GenTITConfig.cpp | 2 +- .../dialects/cat/ir_translator/backend/tit_registry.py | 6 +++++- .../tritontemplate/backend/cuda/layernorm/layernorm.py | 2 +- .../tritontemplate/backend/cuda/softmax/softmax.py | 1 - .../tritontemplate/compiler/ops/softmax/softmax.py | 3 +-- .../python/tritontemplate/testing/cuda/test_bmm.py | 4 ++-- .../python/tritontemplate/testing/cuda/test_gemm.py | 4 ++-- .../tritontemplate/testing/cuda/test_layernorm.py | 4 ++-- .../python/tritontemplate/testing/cuda/test_softmax.py | 10 ++++++---- .../tritontemplate/testing/cuda/test_transpose.py | 4 ++-- 10 files changed, 22 insertions(+), 18 deletions(-) diff --git a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp index 93c793da1..7c72a0ede 100644 --- a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp +++ b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp @@ -68,7 +68,7 @@ static LogicalResult AttachTITConfigToAttr( if (kv.second.getAsInteger(0, val)) { return func.emitError("Invalid integer format for ") << kv.first(); } - if (val <= 0) { + if (val < 0) { return func.emitError("Value must be positive for ") << kv.first(); } titConfig[kv.first()] = opBuilder.getI32IntegerAttr(val); diff --git a/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py index 11e622c76..75dce58c2 100644 --- a/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py +++ b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py @@ -110,8 +110,12 @@ def _dispatch_cat_bmm_ccr_add(op, inputs): @TRITONTemplateIRTranslator.register("cat.softmax") def _dispatch_cat_softmax(op, inputs): + shaped_type = ir.ShapedType(op.result.type) + shape = shaped_type.shape + dtype = mlir_type_to_torch_str(shaped_type.element_type) + outputs=[Tensor(shape=shape,dtype=dtype,name='output_0')] dim = mlir_attr_to_pyobj(op.attributes["dim"]) - Y = tit_ops.Softmax(inputs=inputs,dim=dim,enable_online=True) + Y = tit_ops.Softmax(inputs=inputs,dim=dim,outputs=outputs,enable_online=True) return [Y] @TRITONTemplateIRTranslator.register("cat.layernorm") diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py index 054680b5c..9c7d8541b 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py @@ -46,8 +46,8 @@ def layernorm(x_ptr,y_ptr,M:tl.constexpr,N:tl.constexpr,stride_x0:tl.constexpr,s @triton.jit def layernorm_weight_bias( x_ptr, - bias_ptr, weight_ptr, + bias_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr, diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py index 8fea36cea..393cbe217 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py @@ -28,7 +28,6 @@ def softmax(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr,stride_x0:tl.constexp elem = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, mask=mask_M[:, None] & mask_N[None, :], other=-float('inf')) max_ele = tl.maximum(max_ele,tl.max(elem, axis=1)) - y_sum = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) for i in tl.static_range(NUM_BLOCK_N): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py index 028947e61..fb0373f82 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -12,7 +12,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 1, + 'num_stages': 3, } class Softmax(Operation): @@ -21,7 +21,6 @@ def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outpu assert dim == len(inputs[0].shape)-1, f'only support last axis now' self._attrs['dim'] = dim self._attrs['enable_online'] = enable_online - self._deduce_output_shape() def _deduce_output_shape(self): diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py index e7c1e6b09..83a45fc33 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py @@ -22,7 +22,7 @@ def gen_bmm_bias(format, batch_size, M, N, K, stype): C=Tensor(name='C',dtype=stype,shape=[batch_size,M,N]) bmm_op=Bmm( inputs=[A,B,Bias], - outputs=[C], + outputs=None, layout=format, is_bias=True ) @@ -41,7 +41,7 @@ def gen_bmm(format, batch_size, M, N, K, stype): C=Tensor(name='C',dtype=stype,shape=[batch_size,M,N]) bmm_op=Bmm( inputs=[A,B], - outputs=[C], + outputs=None, layout=format, is_bias=False ) diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py index 07aaebbed..b1fdd93d9 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py @@ -21,7 +21,7 @@ def gen_gemm_bias(format, M, N, K, stype): C=Tensor(name='C',dtype=stype,shape=[M,N]) gemm_op=Gemm( inputs=[A,B,Bias], - outputs=[C], + outputs=None, layout=format, is_bias=True, activation='relu', @@ -41,7 +41,7 @@ def gen_gemm(format, M, N, K, stype): C=Tensor(name='C',dtype=stype,shape=[M,N]) gemm_op=Gemm( inputs=[A,B], - outputs=[C], + outputs=None, layout=format, is_bias=False, activation='relu', diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py index 766725ab9..aba410aae 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py @@ -16,13 +16,13 @@ def gen_layernorm(with_weight_bias,batch,seq_len,hidden_size,stype): B=Tensor(name='B',shape=(hidden_size,),dtype=stype) op=Layernorm( inputs=[X,W,B], - outputs=[Y], + outputs=None, axis=2, eps=1e-5) else: op=Layernorm( inputs=[X], - outputs=[Y], + outputs=None, axis=2, eps=1e-5) return op diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py index 9dd63cf44..ba8c7b188 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py @@ -9,7 +9,7 @@ from tritontemplate.compiler.compiler import compile_kernel def gen_softmax(is_online,batch,num_heads,seqlen,hidden_dim): - A=Tensor(name='A',dtype='float16',shape=[batch,num_heads,seqlen,hidden_dim]) + A=Tensor(name='A',dtype='float32',shape=[batch,num_heads,seqlen,hidden_dim]) B=Tensor(name='B',dtype='float32',shape=[batch,num_heads,seqlen,hidden_dim]) softmax_op=Softmax( inputs=[A], @@ -41,7 +41,7 @@ def gen_softmax(is_online,batch,num_heads,seqlen,hidden_dim): @pytest.mark.parametrize('format', FORMATS) def test_softmax(batch,num_heads,seqlen,hidden_dim, format): - a = torch.randn(batch,num_heads, seqlen, hidden_dim, dtype=torch.float16, device='cuda') + a = torch.randn(batch,num_heads, seqlen, hidden_dim, dtype=torch.float32, device='cuda') M=batch*seqlen*num_heads N=hidden_dim b_triton_jit = torch.empty(batch,num_heads, seqlen, hidden_dim, dtype=torch.float32, device='cuda') @@ -51,7 +51,7 @@ def test_softmax(batch,num_heads,seqlen,hidden_dim, format): ) if format == 'online_softmax': - kernel_online_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,64,64) + kernel_online_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,128,128) kernel = gen_softmax(True,batch,num_heads,seqlen,hidden_dim) kernel(a,b_triton_aot) else: @@ -60,5 +60,7 @@ def test_softmax(batch,num_heads,seqlen,hidden_dim, format): kernel(a,b_triton_aot) b_torch = torch.softmax(a, dim=-1).to(torch.float32) - torch.testing.assert_close(b_triton_jit, b_torch,atol=1e-2,rtol=1e-2) + torch.testing.assert_close(b_triton_jit, b_triton_aot,atol=1e-2,rtol=1e-2) torch.testing.assert_close(b_triton_aot, b_torch,atol=1e-2,rtol=1e-2) + +test_softmax(16,256,256,128,'online_softmax') \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py index f967cb36b..159867512 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py @@ -31,7 +31,7 @@ def gen_transpose_10(M,N,stype): X = Tensor([M, N], stype) Y = Tensor([N, M], stype) - op = Transpose([X], '10', [Y]) + op = Transpose([X], '10', outputs=None) return op @pytest.mark.parametrize( @@ -63,7 +63,7 @@ def test_transpose10(M, N, stype): def gen_transpose_0213(D0,D1,D2,D3,stype): X = Tensor([D0, D1, D2, D3], stype) Y = Tensor([D0, D2, D1, D3], stype) - op = Transpose([X], '0213', [Y]) + op = Transpose([X], '0213', None) return op @pytest.mark.parametrize(