diff --git a/compiler/include/byteir/Conversion/Passes.h b/compiler/include/byteir/Conversion/Passes.h
index 1bf88fff0..2d29ef4ea 100644
--- a/compiler/include/byteir/Conversion/Passes.h
+++ b/compiler/include/byteir/Conversion/Passes.h
@@ -35,6 +35,7 @@
 #include "byteir/Conversion/ToLLVM/ToLLVM.h"
 #include "byteir/Conversion/ToLinalg/ToLinalg.h"
 #include "byteir/Conversion/ToPTX/ToPTX.h"
+#include "byteir/Conversion/ToTIT/ToTIT.h"
 
 namespace mlir {
 
diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td
index 92269e369..760c4542f 100755
--- a/compiler/include/byteir/Conversion/Passes.td
+++ b/compiler/include/byteir/Conversion/Passes.td
@@ -155,6 +155,35 @@ def GenAITConfig : Pass<"gen-ait-config", "func::FuncOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ToTIT
+//===----------------------------------------------------------------------===//
+
+def GenTITConfig : Pass<"gen-tit-config", "func::FuncOp"> {
+  let summary = "Generate TIT configuration";
+  let constructor = "mlir::createGenTITConfigPass()";
+  let options = [
+    ListOption<"funcNames", "func-names", "std::string",
+               "names of all cat func for TIT backends.">,
+    ListOption<"titPtxPaths", "tit-ptx-paths", "std::string",
+               "paths to all TIT-generated .ptx files">,
+    ListOption<"smemsizeArgs", "smemsize-args", "std::string",
+               "smemsize args for TIT backends.">,
+    ListOption<"gridsizeXArgs", "gridsize-x-args", "std::string",
+               "gridsize x args for TIT backends.">,
+    ListOption<"gridsizeYArgs", "gridsize-y-args", "std::string",
+               "gridsize y args for TIT backends.">,
+    ListOption<"gridsizeZArgs", "gridsize-z-args", "std::string",
+               "gridsize z args for TIT backends.">,
+    ListOption<"blocksizeXArgs", "blocksize-x-args", "std::string",
+               "blocksize x args for TIT backends.">,
+    ListOption<"blocksizeYArgs", "blocksize-y-args", "std::string",
+               "blocksize y args for TIT backends.">,
+    ListOption<"blocksizeZArgs", "blocksize-z-args", "std::string",
+               "blocksize z args for TIT backends.">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // ToByre
 //===----------------------------------------------------------------------===//
diff --git a/compiler/include/byteir/Conversion/ToTIT/ToTIT.h b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h
new file mode 100644
index 000000000..308023a26
--- /dev/null
+++ b/compiler/include/byteir/Conversion/ToTIT/ToTIT.h
@@ -0,0 +1,50 @@
+//===- ToTIT.h ------------------------------------------------*--- C++ -*-===//
+//
+// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTEIR_CONVERSION_TOTIT_H
+#define BYTEIR_CONVERSION_TOTIT_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+#include <string>
+
+namespace mlir {
+namespace func {
+class FuncOp;
+} // namespace func
+class ModuleOp;
+
+/// Creates the `gen-tit-config` pass, which attaches TIT kernel launch
+/// configuration to functions as Byre attributes.
+///
+/// The lists are parallel: entry `i` of every list belongs to the function
+/// named `funcNames[i]`. The shared-memory, grid and block size arguments
+/// are integers encoded as strings.
+std::unique_ptr<OperationPass<func::FuncOp>>
+createGenTITConfigPass(ArrayRef<std::string> funcNames = {""},
+                       ArrayRef<std::string> titPtxPaths = {""},
+                       ArrayRef<std::string> smemsizeArgs = {""},
+                       ArrayRef<std::string> gridsizeXArgs = {""},
+                       ArrayRef<std::string> gridsizeYArgs = {""},
+                       ArrayRef<std::string> gridsizeZArgs = {""},
+                       ArrayRef<std::string> blocksizeXArgs = {""},
+                       ArrayRef<std::string> blocksizeYArgs = {""},
+                       ArrayRef<std::string> blocksizeZArgs = {""});
+
+} // namespace mlir
+
+#endif // BYTEIR_CONVERSION_TOTIT_H
diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt
index a47af1137..4d6f77fa1 100755
--- a/compiler/lib/Conversion/CMakeLists.txt
+++ b/compiler/lib/Conversion/CMakeLists.txt
@@ -7,6 +7,7 @@ add_subdirectory(HloToTensor)
 add_subdirectory(MemrefToByre)
 add_subdirectory(ToAce)
 add_subdirectory(ToAIT)
+add_subdirectory(ToTIT)
 add_subdirectory(ToByre)
 add_subdirectory(ToGPU)
 add_subdirectory(ToHlo)
diff --git a/compiler/lib/Conversion/ToTIT/CMakeLists.txt b/compiler/lib/Conversion/ToTIT/CMakeLists.txt
new file mode 100644
index 000000000..2309c1fc3
--- /dev/null
+++ b/compiler/lib/Conversion/ToTIT/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_byteir_conversion_library(ByteIRToTIT
+  GenTITConfig.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${BYTEIR_SRC_INCLUDE_DIR}/byteir/Conversion/ToTIT
+
+  DEPENDS
+  ByteIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRBufferizationTransforms
+  ByteIRUtils
+  )
diff --git a/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp
new file mode 100644
index 000000000..7c72a0ede
--- /dev/null
+++ b/compiler/lib/Conversion/ToTIT/GenTITConfig.cpp
@@ -0,0 +1,133 @@
+//===- GenTITConfig.cpp ---------------------------------------*--- C++ -*-===//
+//
+// Copyright 2022 ByteDance Ltd.
and/or its affiliates. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#include "byteir/Conversion/ToTIT/ToTIT.h"
+#include "byteir/Dialect/Byre/Common.h"
+#include "byteir/Dialect/mhlo/Transforms/HloFuser.h"
+#include "byteir/Utils/FuncUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/PassManager.h"
+#include "llvm/Support/Debug.h"
+
+#include "../PassDetail.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Attaches the TIT kernel configuration for one function as Byre attributes.
+///
+/// The target device and Byre kernel name are derived from \p titPtxPath:
+/// only paths containing ".ptx" are recognised (device "cuda", kernel
+/// "PTXOp"). The shared-memory, grid and block arguments are strings that
+/// must parse as non-negative integers; each is stored as an i32 attribute.
+/// Every attribute name is prefixed with byre::getByrePrefix().
+///
+/// \returns failure, after emitting an error on \p func, when the path is
+/// not recognised or an argument is not a non-negative integer.
+static LogicalResult AttachTITConfigToAttr(
+    func::FuncOp func, const std::string &titPtxPath,
+    const std::string &smemsizeArg, const std::string &gridsizeXArg,
+    const std::string &gridsizeYArg, const std::string &gridsizeZArg,
+    const std::string &blocksizeXArg, const std::string &blocksizeYArg,
+    const std::string &blocksizeZArg) {
+
+  std::string deviceName;
+  std::string byreKernelName;
+  if (titPtxPath.find(".ptx") != std::string::npos) {
+    deviceName = "cuda";
+    byreKernelName = "PTXOp";
+  }
+
+  if (deviceName.empty() || byreKernelName.empty()) {
+    return func.emitError("Invalid device type for TIT configuration");
+  }
+  addGenericFuncAttrs(func, byreKernelName);
+
+  mlir::OpBuilder opBuilder(func);
+  llvm::StringMap<Attribute> titConfig;
+
+  // Attach the Byre Tensor Info
+  titConfig["call_convention"] = opBuilder.getStringAttr("bare_ptr");
+  titConfig["device"] = opBuilder.getStringAttr(deviceName);
+  titConfig["device_file_name"] = opBuilder.getStringAttr(titPtxPath);
+
+  llvm::StringMap<StringRef> gpuLaunchArgs = {
+      {"SharedMemorySize", smemsizeArg}, {"BlockSize.x", blocksizeXArg},
+      {"BlockSize.y", blocksizeYArg},    {"BlockSize.z", blocksizeZArg},
+      {"GridSize.x", gridsizeXArg},      {"GridSize.y", gridsizeYArg},
+      {"GridSize.z", gridsizeZArg}};
+
+  for (auto &kv : gpuLaunchArgs) {
+    int val = 0;
+    // getAsInteger returns true when the string is not a valid integer.
+    if (kv.second.getAsInteger(0, val)) {
+      return func.emitError("Invalid integer format for ") << kv.first();
+    }
+    // Zero is accepted (e.g. no shared memory); only negatives are rejected.
+    if (val < 0) {
+      return func.emitError("Value must be non-negative for ") << kv.first();
+    }
+    titConfig[kv.first()] = opBuilder.getI32IntegerAttr(val);
+  }
+
+  for (auto &kv : titConfig) {
+    func->setAttr(byre::getByrePrefix() + kv.first().str(), kv.second);
+  }
+
+  return success();
+}
+
+/// Pass that attaches TIT launch configuration to cat-fusion functions.
+///
+/// The nine list options are parallel: entry `i` of each list describes the
+/// function named `funcNames[i]`.
+struct GenTITConfigPass : public GenTITConfigBase<GenTITConfigPass> {
+  GenTITConfigPass(ArrayRef<std::string> funcNames,
+                   ArrayRef<std::string> titPtxPaths,
+                   ArrayRef<std::string> smemsizeArgs,
+                   ArrayRef<std::string> gridsizeXArgs,
+                   ArrayRef<std::string> gridsizeYArgs,
+                   ArrayRef<std::string> gridsizeZArgs,
+                   ArrayRef<std::string> blocksizeXArgs,
+                   ArrayRef<std::string> blocksizeYArgs,
+                   ArrayRef<std::string> blocksizeZArgs)
+      : GenTITConfigBase() {
+    this->funcNames = funcNames;
+    this->titPtxPaths = titPtxPaths;
+    this->smemsizeArgs = smemsizeArgs;
+    this->gridsizeXArgs = gridsizeXArgs;
+    this->gridsizeYArgs = gridsizeYArgs;
+    this->gridsizeZArgs = gridsizeZArgs;
+    this->blocksizeXArgs = blocksizeXArgs;
+    this->blocksizeYArgs = blocksizeYArgs;
+    this->blocksizeZArgs = blocksizeZArgs;
+  }
+
+  /// Attaches the configuration to the current function when it carries the
+  /// cat-fusion attribute and its symbol name appears in `funcNames`.
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    if (!func->hasAttr(getByteIRCatFusionAttrName()))
+      return;
+    for (size_t i = 0; i < funcNames.size(); ++i) {
+      if (func.getSymName() != funcNames[i])
+        continue;
+
+      // The other lists are indexed with the same `i`; a list shorter than
+      // `funcNames` would otherwise be read out of bounds.
+      const bool hasAllEntries =
+          i < titPtxPaths.size() && i < smemsizeArgs.size() &&
+          i < gridsizeXArgs.size() && i < gridsizeYArgs.size() &&
+          i < gridsizeZArgs.size() && i < blocksizeXArgs.size() &&
+          i < blocksizeYArgs.size() && i < blocksizeZArgs.size();
+      if (!hasAllEntries) {
+        func.emitError("Missing TIT configuration entry for ")
+            << func.getSymName();
+        return signalPassFailure();
+      }
+
+      if (failed(AttachTITConfigToAttr(
+              func, titPtxPaths[i], smemsizeArgs[i], gridsizeXArgs[i],
+              gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i],
+              blocksizeYArgs[i], blocksizeZArgs[i]))) {
+        return signalPassFailure();
+      }
+    }
+  }
+};
+
+} // namespace
+
+/// Factory for the `gen-tit-config` pass; all lists are forwarded unchanged
+/// to the GenTITConfigPass constructor.
+std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGenTITConfigPass(
+    ArrayRef<std::string> funcNames, ArrayRef<std::string> titPtxPaths,
+    ArrayRef<std::string> smemsizeArgs, ArrayRef<std::string> gridsizeXArgs,
+    ArrayRef<std::string> gridsizeYArgs, ArrayRef<std::string> gridsizeZArgs,
+    ArrayRef<std::string> blocksizeXArgs, ArrayRef<std::string> blocksizeYArgs,
+    ArrayRef<std::string> blocksizeZArgs) {
+  return std::make_unique<GenTITConfigPass>(
+      funcNames, titPtxPaths, smemsizeArgs, gridsizeXArgs, gridsizeYArgs,
+      gridsizeZArgs, blocksizeXArgs, blocksizeYArgs, blocksizeZArgs);
+}
diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py
index f53c4c7f7..8852b2ad8 100755
--- a/compiler/python/byteir/compile.py
+++ b/compiler/python/byteir/compile.py
@@ -287,6 +287,121 @@ def _compile_cuda_with_ait(
             raise ValueError("module asm has be changed after byre serialization")
 
 
+@register_byteir_compiler_backend(target="cuda_with_triton", device="cuda")
+def _compile_cuda_with_triton(
+    compile_options: CompileOptions,
+) -> None:
+    """Compile `compile_options.module` for CUDA, using Triton for cat ops.
+
+    Runs the HLO graph and fusion pipelines, lets
+    `IRProcessor.triton_opt_pass` handle the outlined cat functions, lowers
+    the module through the linalg / byre / gpu pipelines, and writes the
+    device PTX plus the host MLIR (and, when requested, MLIRBC) files into
+    `compile_options.output_dir`.
+
+    Raises:
+        ValueError: if the module text differs after a byre
+            serialize/deserialize round trip.
+    """
+    from .dialects.cat import IRProcessor
+
+    target = "cuda"
+    module = compile_options.module
+    entry_func = compile_options.entry_func
+    gpu_arch = compile_options.gpu_arch
+    verbose = compile_options.verbose
+    name = compile_options.name
+    enable_tf32 = compile_options.enable_tf32
+    parallelism = compile_options.parallelism
+    disable_byteir_ait_cache = compile_options.disable_byteir_ait_cache
+
+    output_file_dir = compile_options.output_dir
+    output_file_prefix = compile_options.output_file_prefix
+    output_type = compile_options.output_type
+    useBarePtrCallConv = True # all tensor must have static shapes if True
+
+    context = module.context
+
+    entry_func_str = "entry-func={}".format(entry_func)
+    target_str = "target={}".format(target)
+
+    with context:
+        PassManager().parse("builtin.module(hlo-graph-opt{" + entry_func_str + " " + target_str + "})").run(module.operation)
+        _print_verbose(module, "// IR Dump After Hlo Graph Opt:") if verbose else ...
+
+    processor = IRProcessor(name,
+                            "./workspace",
+                            enable_tf32=enable_tf32,
+                            compile_parallelism=parallelism,
+                            disable_byteir_ait_cache=disable_byteir_ait_cache,
+                            verbose=verbose)
+    processor.module = module
+
+    processor.preprocess_pass()
+    processor.cat_opt_pass(anchor_only=False)
+
+    with context:
+        pm = PassManager().parse("builtin.module(hlo-fusion-opt{outline-single-elemwise-op outline-cat-op})")
+        pm.run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After Hlo Fusion Opt (with Cat):") if verbose else ...
+    # not generate ait lib .so for cat functions
+    processor.triton_opt_pass(output_file_dir)
+    module = processor.module
+
+    with context:
+        PassManager.parse("builtin.module(linalg-tensor-opt)").run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After Linalg Tensor Opt:") if verbose else ...
+    with context:
+        if enable_tf32:
+            PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types enable-tf32 {}}})".format(entry_func_str)).run(processor.module.operation)
+        else:
+            PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types {}}})".format(entry_func_str)).run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After Byre Tensor Opt:") if verbose else ...
+    with context:
+        PassManager.parse("builtin.module(byteir-bufferize-opt)").run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After ByteIR Bufferize Opt:") if verbose else ...
+    with context:
+        PassManager.parse("builtin.module(linalg-memref-opt)").run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After Linalg Memref Opt:") if verbose else ...
+    with context:
+        PassManager.parse("builtin.module(scf-opt)").run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After SCF Opt:") if verbose else ...
+    with context:
+        if useBarePtrCallConv:
+            PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true device-file-name="+ output_file_prefix + ".ptx" + "})").run(module.operation)
+        else:
+            PassManager.parse("builtin.module(gpu-opt{device-file-name=" + output_file_prefix + ".ptx" + "})").run(module.operation)
+        _print_verbose(processor.module, "// IR Dump After GPU Opt:") if verbose else ...
+    with context:
+        PassManager.parse("builtin.module(inline)").run(processor.module.operation)
+        PassManager.parse("builtin.module(func.func(lccl-to-byre))").run(module.operation)
+        PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(processor.module.operation)
+        PassManager.parse("builtin.module(func.func(set-op-space{" + entry_func_str + " space={}".format(target) + "}))").run(processor.module.operation)
+        PassManager.parse("builtin.module(set-arg-space{" + entry_func_str + " all-space={}".format(target) + "})").run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After Set Space Opt:") if verbose else ...
+    with context:
+        PassManager.parse("builtin.module(byre-opt{append-arg-types " + entry_func_str + "})").run(processor.module.operation)
+        _print_verbose(processor.module, "// IR Dump After Byre Opt:") if verbose else ...
+
+    # create device module
+    module_str = processor.module.operation.get_asm(print_generic_op_form=True)
+    device_module = ir.Module.parse(module_str, context)
+    with context:
+        if useBarePtrCallConv:
+            PassManager.parse("builtin.module(nvvm-codegen{use-bare-ptr-memref-call-conv=true " + f" gpu-arch={gpu_arch}" + "})").run(device_module.operation)
+        else:
+            # no space after '=': the option value must follow it directly
+            PassManager.parse("builtin.module(nvvm-codegen{" + f"gpu-arch={gpu_arch}" + "})").run(device_module.operation)
+        _print_verbose(device_module, "// IR Dump After NVVM Codegen:") if verbose else ...
+    # write to output device ptx
+    byteir.translate_to_ptx(device_module, output_file_dir + "/" + output_file_prefix, gpu_arch)
+
+    # create host module
+    with context:
+        PassManager.parse("builtin.module(byre-host)").run(processor.module.operation)
+        PassManager.parse("builtin.module(remove-module-tag{attr-name=gpu.container_module})").run(module.operation)
+        PassManager.parse("builtin.module(remove-module-tag{attr-name=torch.debug_module_name})").run(module.operation)
+        _print_verbose(processor.module, "// IR Dump After Byre Host:") if verbose else ...
+
+    output_host_mlir_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIR.value)
+    output_host_mlirbc_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIRBC.value)
+    # write to output host mlir file
+    with open(output_host_mlir_path, "w") as f:
+        f.write(module.operation.get_asm())
+    if output_type is OutputType.MLIRBC:
+        byteir.serialize_byre(module, compile_options.byre_serial_version, output_host_mlirbc_path)
+        deserialized_module = byteir.deserialize_byre(open(output_host_mlirbc_path, "rb").read(), context)
+        if (module.operation.get_asm() != deserialized_module.operation.get_asm()):
+            raise ValueError("module asm has be changed after byre serialization")
+
+
 @register_byteir_compiler_backend(target="cpu", device="cpu")
 def _compile_cpu(
     compile_options: CompileOptions,
diff --git a/compiler/python/byteir/dialects/cat/ir_processor.py b/compiler/python/byteir/dialects/cat/ir_processor.py
index 23fbb7133..8bd48c846 100644
--- a/compiler/python/byteir/dialects/cat/ir_processor.py
+++ b/compiler/python/byteir/dialects/cat/ir_processor.py
@@ -4,6 +4,7 @@
 from byteir.utils import get_gpu_type
 
 from .ait_cache import AITCache
+from .tit_cache import TITCache
 
 from pathlib import Path
 from shutil import copyfile, copymode
@@ -55,7 +56,8 @@ def __init__(self,
                  enable_tf32 = False,
                  compile_parallelism = 1,
                  disable_byteir_ait_cache = False,
-                 verbose = False):
+                 verbose = False,
+                 device = "cuda"):
         self.job_name = job_name
         self.workdir = workdir
         self.module = None
@@ -65,11 +67,10 @@ def __init__(self,
             self.pool = multiprocessing.Pool(compile_parallelism)
         else:
             self.pool = None
-        self.byteir_cache = AITCache()
         self.verbose = verbose
-        self.disable_byteir_ait_cache = disable_byteir_ait_cache
-        if not disable_byteir_ait_cache:
-            self.byteir_cache.load_or_create_cache()
+        #TODO: signature rename
+        self.disable_byteir_cache = disable_byteir_ait_cache
+        self.device = device
 
     def _get_builder(self, func, subgraph_name, backend="ait"):
         assert func != None
@@ -77,6 +78,9 @@ def _get_builder(self, func, subgraph_name, backend="ait"):
         if backend == "ait":
             from byteir.dialects.cat.ir_translator.ait_builder import AITBuilder
             return AITBuilder(func, workdir=self.workdir, subgraph_name=subgraph_name, enable_tf32=self.enable_tf32)
+        elif backend == "triton":
+            from byteir.dialects.cat.ir_translator.tit_builder import TITBuilder
+            return TITBuilder(func, workdir=self.workdir, subgraph_name=subgraph_name, enable_tf32=self.enable_tf32, device=self.device)
         else:
             raise RuntimeError(f"Unsupported runtime backend {backend}")
 
@@ -100,6 +104,9 @@ def cat_opt_pass(self, anchor_only=False):
         return self.module
 
     def ait_opt_pass(self, output_dir):
+        self.byteir_cache = AITCache()
+        if not self.disable_byteir_cache:
+            self.byteir_cache.load_or_create_cache()
         funcNameArg = []
         aitLibPathArg = []
 
@@ -141,7 +148,7 @@ def ait_opt_pass(self, output_dir):
         print("compilation finished in {}s".format(t_ed-t_st))
 
         # update byteir cache
-        if not self.disable_byteir_ait_cache:
+        if not self.disable_byteir_cache:
             for key, lib_path in libs_to_add_to_cache.items():
                 self.byteir_cache.add(gpu_type, key, lib_path, override=False)
             self.byteir_cache._save()
@@ -153,6 +160,99 @@ def ait_opt_pass(self, output_dir):
             _print_verbose(self.module, "// IR Dump After Gen AIT Config:") if self.verbose else ...
return self.module + + def triton_opt_pass(self, output_dir): + + def decouple_triton_args(triton_args): + func_name_args = [] + ptx_path_args = [] + gridsize_x_args = [] + gridsize_y_args = [] + gridsize_z_args = [] + blocksize_x_args = [] + blocksize_y_args = [] + blocksize_z_args = [] + smemsize_args = [] + for func_name,ptx_path,gridsize,blocksize,smem_size in triton_args: + func_name_args.append(func_name) + ptx_path_args.append(ptx_path) + gridsize_x_args.append(str(gridsize[0])) + gridsize_y_args.append(str(gridsize[1])) + gridsize_z_args.append(str(gridsize[2])) + blocksize_x_args.append(str(blocksize)) + blocksize_y_args.append(str(1)) + blocksize_z_args.append(str(1)) + smemsize_args.append(str(smem_size)) + return func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, blocksize_z_args,smemsize_args + + self.pool=None + + self.byteir_cache = TITCache() + if not self.disable_byteir_cache: + self.byteir_cache.load_or_create_cache() + triton_args = [] + + gpu_type = get_gpu_type() + if gpu_type == None: + raise RuntimeError("No gpu found in this machine! 
cannot perform triton-opt-pass") + work_items = [] # work items of FuncOp + + for func in self.module.body.operations: + if BYTEIR_CAT_ATTR not in func.attributes: + continue + output_ptx_path = os.path.join(output_dir, func.name.value + ".ptx") + hash_str = func_hash_str(func, gpu_type) + # TODO: gridsize order need to be checked + # gridsize form (x,y,z), blocksize form (x,y,z) + cached_argv = self.byteir_cache.find(gpu_type, hash_str) + if cached_argv: + cache_ptx,gridsize,blocksize,smemsize = cached_argv + print(f"func {func.name.value} cache hit") + copyfile(cache_ptx, output_ptx_path) + copymode(cache_ptx, output_ptx_path) + triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize,smemsize)) + else: + work_items.append(func) + + # compile and benchmark + print("compile triton module using {} processes".format(min(len(work_items), self.compile_parallelism))) + print("\n".join([str(func) for func in work_items])) + t_st = time.time() + + new_args = [] + for func in work_items: + output_ptx_path = os.path.join(output_dir, func.name.value + ".ptx") + if self.pool: + new_args.append(self.pool.apply_async(_parallel_tit_compile, + (self.workdir, func, output_ptx_path, self.enable_tf32))) + else: + new_args.append(_parallel_tit_compile(self.workdir, func, output_ptx_path, self.enable_tf32)) + + if self.pool: + self.pool.close() + self.pool.join() + + for func,output_ptx_path,gridsize,blocksize,smemsize in new_args: + triton_args.append((func.name.value,output_ptx_path, gridsize, blocksize,smemsize)) + self.byteir_cache.load_or_create_cache() + self.byteir_cache.add(gpu_type, func_hash_str(func, gpu_type), (output_ptx_path, gridsize, blocksize,smemsize), override=False) + self.byteir_cache._save() + self.byteir_cache.close_cache() + + t_ed = time.time() + print("compilation finished in {}s".format(t_ed-t_st)) + + func_name_args, ptx_path_args, gridsize_x_args, gridsize_y_args, gridsize_z_args, blocksize_x_args, blocksize_y_args, 
blocksize_z_args,smemsize_args = decouple_triton_args(triton_args) + ptx_path_args= [os.path.split(path)[-1] for path in ptx_path_args] + + with self.module.context: + pm_str="builtin.module(func.func(gen-tit-config{{func-names={} tit-ptx-paths={} smemsize-args={} gridsize-x-args={} gridsize-y-args={} gridsize-z-args={} blocksize-x-args={} blocksize-y-args={} blocksize-z-args={}}}))".format(",".join(func_name_args), ",".join(ptx_path_args),",".join(smemsize_args), ",".join(gridsize_x_args), ",".join(gridsize_y_args), ",".join(gridsize_z_args), ",".join(blocksize_x_args), ",".join(blocksize_y_args), ",".join(blocksize_z_args)) + pm = PassManager.parse(pm_str) + pm.run(self.module.operation) + _print_verbose(self.module, "// IR Dump After Gen TIT Config:") if self.verbose else ... + + + return self.module def execute(self, inputs, func_name=None, backend="ait"): if func_name is None: @@ -184,3 +284,14 @@ def _parallel_ait_compile(workdir: str, func: FuncOp, output_lib_path, enable_tf builder.benchmark() copyfile(builder.ait_module_path, output_lib_path) copymode(builder.ait_module_path, output_lib_path) + +def _parallel_tit_compile(workdir: str, func: FuncOp, output_ptx_path, enable_tf32): + + # os.environ["CUDA_VISIBLE_DEVICES"]=str(os.getpid() % available_cuda_device_num) + from byteir.dialects.cat.ir_translator.tit_builder import TITBuilder + builder = TITBuilder(func, workdir=workdir, subgraph_name=func.name.value, enable_tf32=enable_tf32) + builder.compile() + blockSize,gridsize,smemsize=builder.blocksize,builder.gridsize,builder.smemsize + copyfile(builder.tit_module_path, output_ptx_path) + copymode(builder.tit_module_path, output_ptx_path) + return func,output_ptx_path,gridsize,blockSize,smemsize \ No newline at end of file diff --git a/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py new file mode 100644 index 000000000..75dce58c2 --- /dev/null +++ 
b/compiler/python/byteir/dialects/cat/ir_translator/backend/tit_registry.py
@@ -0,0 +1,119 @@
+# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from tritontemplate.compiler.base import Tensor
+import tritontemplate.compiler.ops as tit_ops
+
+from ..translator import IRTranslator
+from byteir import ir
+from byteir.utils import mlir_attr_to_pyobj, mlir_type_to_torch_str
+
+class TRITONTemplateIRTranslator(IRTranslator):
+    # Registry of op translators that produce tritontemplate ops.
+    pass
+
+@TRITONTemplateIRTranslator.register("mhlo.constant")
+def _dispatch_mhlo_constant(op, inputs):
+    # Only the shape and element type of the result are used here.
+    result_type = ir.ShapedType(op.result.type)
+    dims = result_type.shape
+    return [Tensor(dims, dtype=mlir_type_to_torch_str(result_type.element_type))]
+
+# GEMM translators: the op name suffix selects layout, bias and activation.
+@TRITONTemplateIRTranslator.register("cat.gemm_rcr_bias_relu")
+def _dispatch_cat_gemm_rcr_bias_relu(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rcr", is_bias=True, activation="relu")]
+
+@TRITONTemplateIRTranslator.register("cat.gemm_rcr_bias")
+def _dispatch_cat_gemm_rcr_bias(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rcr", is_bias=True)]
+
+@TRITONTemplateIRTranslator.register("cat.gemm_rcr_relu")
+def _dispatch_cat_gemm_rcr_relu(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rcr", activation="relu")]
+
+@TRITONTemplateIRTranslator.register("cat.gemm_rcr")
+def _dispatch_cat_gemm_rcr(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rcr")]
+
+@TRITONTemplateIRTranslator.register("cat.gemm_rrr_bias_relu")
+def _dispatch_cat_gemm_rrr_bias_relu(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rrr", is_bias=True, activation="relu")]
+
+@TRITONTemplateIRTranslator.register("cat.gemm_rrr_bias")
+def _dispatch_cat_gemm_rrr_bias(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rrr", is_bias=True)]
+
+@TRITONTemplateIRTranslator.register("cat.gemm_rrr_relu")
+def _dispatch_cat_gemm_rrr_relu(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rrr", activation="relu")]
+
+@TRITONTemplateIRTranslator.register("cat.gemm_rrr")
+def _dispatch_cat_gemm_rrr(op, inputs):
+    return [tit_ops.Gemm(inputs=inputs, layout="rrr")]
+
+# Batched matmul translators: an "_add" suffix maps to is_bias=True.
+@TRITONTemplateIRTranslator.register("cat.bmm_rrr")
+def _dispatch_cat_bmm_rrr(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="rrr")]
+
+@TRITONTemplateIRTranslator.register("cat.bmm_rrr_add")
+def _dispatch_cat_bmm_rrr_add(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="rrr", is_bias=True)]
+
+@TRITONTemplateIRTranslator.register("cat.bmm_rcr")
+def _dispatch_cat_bmm_rcr(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="rcr")]
+
+@TRITONTemplateIRTranslator.register("cat.bmm_rcr_add")
+def _dispatch_cat_bmm_rcr_add(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="rcr", is_bias=True)]
+
+@TRITONTemplateIRTranslator.register("cat.bmm_crr")
+def _dispatch_cat_bmm_crr(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="crr")]
+
+@TRITONTemplateIRTranslator.register("cat.bmm_crr_add")
+def _dispatch_cat_bmm_crr_add(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="crr", is_bias=True)]
+
+@TRITONTemplateIRTranslator.register("cat.bmm_ccr")
+def _dispatch_cat_bmm_ccr(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="ccr")]
+
+@TRITONTemplateIRTranslator.register("cat.bmm_ccr_add")
+def _dispatch_cat_bmm_ccr_add(op, inputs):
+    return [tit_ops.Bmm(inputs=inputs, layout="ccr", is_bias=True)]
+
+@TRITONTemplateIRTranslator.register("cat.softmax")
+def _dispatch_cat_softmax(op, inputs):
+    out_type = ir.ShapedType(op.result.type)
+    out_tensor = Tensor(
+        shape=out_type.shape,
+        dtype=mlir_type_to_torch_str(out_type.element_type),
+        name='output_0',
+    )
+    softmax_dim = mlir_attr_to_pyobj(op.attributes["dim"])
+    return [tit_ops.Softmax(inputs=inputs, dim=softmax_dim, outputs=[out_tensor], enable_online=True)]
+
+@TRITONTemplateIRTranslator.register("cat.layernorm")
+def _dispatch_cat_layernorm(op, inputs):
+    norm_axes = mlir_attr_to_pyobj(op.attributes["axis"])
+    epsilon = mlir_attr_to_pyobj(op.attributes["epsilon"])
+    return [tit_ops.Layernorm(inputs=inputs, axises=norm_axes, eps=epsilon)]
+
+@TRITONTemplateIRTranslator.register("mhlo.transpose")
+def _dispatch_mhlo_transpose(op, inputs):
+    perm = mlir_attr_to_pyobj(op.attributes["permutation"]).tolist()
+    # The permutation is handed over as the concatenation of its entries.
+    return [tit_ops.Transpose(inputs=inputs, permutation=''.join(map(str, perm)))]
diff --git a/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py
new file mode 100644
index 000000000..e382c2a40
--- /dev/null
+++ b/compiler/python/byteir/dialects/cat/ir_translator/tit_builder.py
@@ -0,0 +1,154 @@
+# Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import random +import torch +import os +import numpy as np +from typing import Set + +from tritontemplate.compiler.base import Tensor, IntImm +from tritontemplate.compiler import compile_kernel + +from .backend.tit_registry import * + +from byteir.utils import mlir_type_to_torch_str, torch_dtype_from_str + +class TITBuilder: + # stores a graph + # use stores mlir.Value to ait.Tensor map + _value2tensor = None + _op2parent_block = None + _im_vals = None + + def __init__(self, func, workdir="./workspace", subgraph_name="model", enable_tf32=False,device="cuda"): + self.func = func + self._value2tensor = {} + self._op2parent_block = {} + self._im_vals = set() + self.inputs : list[Tensor] = [] + self.outputs : list[Tensor] = [] + + self.tit_module_path = None + self.tit_model = None + + self.subgraph_name = subgraph_name + self.ptx_name = subgraph_name + ".ptx" + self.workdir = workdir + self.enable_tf32 = enable_tf32 + self.test_name = "./" + subgraph_name + self.constants = {} + self.constant_idx = 0 + + self.device = device + + + # init arguments + for idx,i in enumerate(self.func.arguments): + shaped_type = ir.ShapedType(i.type) + shape = shaped_type.shape + dtype = mlir_type_to_torch_str(shaped_type.element_type) + self._value2tensor[i] = Tensor(shape=shape, dtype=dtype, name=f"input_tensor_{idx}") + self.inputs.append(self._value2tensor[i]) + + # Note: variable `lib_path` is constructed according to the code in compile_model() + os.makedirs(os.path.join(self.workdir, self.test_name), exist_ok=True) + ptx_path = os.path.join(self.workdir, self.test_name, self.ptx_name) + print("TIT module path {} for {}".format(ptx_path, self.ptx_name)) + self.tit_module_path = ptx_path + + def compile(self): + self._visit_block(self.func.entry_block) + assert self.tit_module_path is not None + 
assert self.tit_kernel is not None + + def _visit_op(self, op): + # analyze here + # call bt APIs to create tensor & op + inputs = list(self._lookup_tensor(i) for i in op.operands) + if op.operation.name == "func.return": + self._gen_tit_kernel(inputs) + return + if hasattr(op, "operands"): + outputs = TRITONTemplateIRTranslator.translate(op, inputs) + if op.operation.name == "mhlo.constant": + # TODO: need to support FP16 + shaped_type = ir.ShapedType(op.result.type) + outputs[0]._attrs["name"] = f"const_tensor_{self.constant_idx}" + np_array = mlir_attr_to_pyobj(op.attributes["value"]) + data = torch.from_numpy(np_array).contiguous().cuda() + data = data.to(torch_dtype_from_str(mlir_type_to_torch_str(shaped_type.element_type))) + self.constants[outputs[0]._attrs["name"]] = data + self.constant_idx += 1 + if op.operation.name != "mhlo.constant": + for value in op.results: + self._im_vals.add(value) + for output, value in zip(outputs, op.results): + self._value2tensor[value] = output + + for region in op.operation.regions: + for block in region.blocks: + self._visit_block(block) + + def _visit_block(self, block): + for i in block.operations: + self._op2parent_block[i] = block + self._visit_op(i) + + def _lookup_tensor(self, val): + # return a bt.graph.Tensor + assert val in self._value2tensor + return self._value2tensor[val] + + def _gen_tit_kernel(self, results): + idx = 0 + for out in results: + out._attrs["name"] = f"output_tensor_{idx}" + idx += 1 + assert len(results) == 1, "only support single cat op" + result=results[0] + self.tit_kernel = compile_kernel( + op=result, + device=self.device, + workdir=self.workdir, + enable_tf32=self.enable_tf32, + ) + # kernel rename + with open(self.tit_module_path, "w") as f: + f.write(self.tit_kernel.kernel_ptx(self.subgraph_name)) + self.gridsize = self.tit_kernel.gridsize + self.blocksize = self.tit_kernel.blocksize + self.smemsize = self.tit_kernel.smemsize + + def _gen_runtime_tensor(self, tensor: Tensor): + shape = 
tensor.shape() + rt_shape = [] + for s in shape: + if isinstance(s, IntImm): + rt_shape.append(s.value()) + dtype = torch_dtype_from_str(tensor.dtype()) + if len(rt_shape) == 0: + return torch.tensor(1).to(torch_dtype_from_str(tensor.dtype())).cuda() + if dtype == torch.bool: + return torch.randint(high=2, size=rt_shape, dtype=dtype, device="cuda") # `high` is exclusive: 2 is needed to draw both False and True + elif dtype in [torch.int8, torch.int, torch.int16, torch.int32, torch.int64]: + return torch.randint(high=100, size=rt_shape, dtype=dtype, device="cuda") + else: + return torch.randn(*rt_shape, device="cuda").to(torch_dtype_from_str(tensor.dtype())) + + def execute(self, np_inputs, num_trials=1, benchmark=False): + raise NotImplementedError("TITBuilder.execute() is not implemented yet") + + def benchmark(self, num_trials=5): + raise NotImplementedError("TITBuilder.benchmark() is not implemented yet") diff --git a/compiler/python/byteir/dialects/cat/tit_cache.py b/compiler/python/byteir/dialects/cat/tit_cache.py new file mode 100644 index 000000000..8c28c8813 --- /dev/null +++ b/compiler/python/byteir/dialects/cat/tit_cache.py @@ -0,0 +1,108 @@ +import os +import re +import json +import fcntl +import tempfile + +from shutil import copyfile, copymode + +HOME_DIR = os.getenv("HOME") +if HOME_DIR is None: + HOME_DIR = tempfile.gettempdir() +DEFAULT_CACHE_DIR = os.path.join(HOME_DIR, ".byteir", "tit_cache") +CACHE_FILE_NAME = "tit_global_cache.json" +IDX_KEY = "byteir_tit_cache_auto_increment_idx" + +class TITCache: + def __init__(self, cache_dir = DEFAULT_CACHE_DIR) -> None: + self.idx = 0 # numeric id used to name the next cached .ptx file + self.cache_dir = cache_dir + self.cache = { IDX_KEY : self.idx } # gpu_type -> {tit op hash str -> (relative .ptx path, gridsize, blocksize, smemsize)} + self.fp = None + + def _open(self): + if not os.path.exists(os.path.join(self.cache_dir, CACHE_FILE_NAME)): + self.fp = open(os.path.join(self.cache_dir, CACHE_FILE_NAME), "w") + fcntl.flock(self.fp.fileno(), fcntl.LOCK_EX) + self.cache = { IDX_KEY : self.idx } + 
json.dump(self.cache, self.fp, indent=2) + fcntl.flock(self.fp.fileno(), fcntl.LOCK_UN) + self.fp.close() + self.fp = open(os.path.join(self.cache_dir, CACHE_FILE_NAME), "r+") + # print("try to acquire file lock...") + # acquire lock + fcntl.flock(self.fp.fileno(), fcntl.LOCK_EX) + # print("file lock acquired.") + + def _close(self): + # release lock + fcntl.flock(self.fp.fileno(), fcntl.LOCK_UN) + self.fp.close() + + def _load(self): + try: + self.cache = json.load(self.fp) + except json.decoder.JSONDecodeError: + self.cache = { IDX_KEY : self.idx } + + def _save(self): + self.fp.truncate(self.fp.seek(0)) # rewind AND drop the old contents, so a shorter dump cannot leave a stale JSON tail behind + json.dump(self.cache, self.fp, indent=2) + + def sync_cache(self): + # sync cache with the files in cache dir + # remove invalid key-value pairs + max_idx = 0 + for gpu_type in self.cache: + if gpu_type == IDX_KEY: + continue + keys_to_check = list(self.cache[gpu_type].keys()) + for key in keys_to_check: + value = self.cache[gpu_type][key] + if not os.path.exists(os.path.join(self.cache_dir, value[0])): + self.cache[gpu_type].pop(key) + continue + if self.get_ptx_idx(value[0]) > max_idx: + max_idx = self.get_ptx_idx(value[0]) + self.idx = max_idx + 1 + self.cache[IDX_KEY] = self.idx + + def load_or_create_cache(self): + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir, exist_ok=True) + self._open() + self._load() + self.sync_cache() + + def add(self, gpu_type, key, argv, override = False): + if gpu_type not in self.cache: + self.cache[gpu_type] = {} + if override or key not in self.cache[gpu_type]: + value = "{:0>16}.ptx".format(self.idx) + self.idx += 1 + self.cache[IDX_KEY] = self.idx + self.cache[gpu_type][key] = (value, argv[1], argv[2],argv[3]) + ptx_path = argv[0] + copyfile(ptx_path, os.path.join(self.cache_dir, value)) + copymode(ptx_path, os.path.join(self.cache_dir, value)) + + def find(self, gpu_type, key): + if gpu_type not in self.cache: + return None + if key in self.cache[gpu_type]: + lib_path, gridsize, blocksize,smemsize = 
self.cache[gpu_type][key] + return os.path.join(self.cache_dir, lib_path), gridsize, blocksize,smemsize + else: + return None + + def get_ptx_idx(self, ptx_name): + try: + idx=re.search(r'(\d+)\.ptx', ptx_name).group(1) + return int(idx) + except: + print("invalid ptx name: {}".format(ptx_name)) + raise RuntimeError("invalid ptx name") + + def close_cache(self): + if self.fp != None: + self._close() \ No newline at end of file diff --git a/compiler/python/byteir/tools/compiler.py b/compiler/python/byteir/tools/compiler.py index 1c01e06bd..03ee2e7c3 100644 --- a/compiler/python/byteir/tools/compiler.py +++ b/compiler/python/byteir/tools/compiler.py @@ -11,7 +11,7 @@ parser.add_argument("--target", type=str, default="cuda", - choices=["cuda", "cuda_with_ait", "cpu"], + choices=["cuda", "cuda_with_ait", "cuda_with_triton", "cpu"], help="target device name") parser.add_argument("--gpu_arch", type=str, diff --git a/compiler/python/byteir/utils.py b/compiler/python/byteir/utils.py index 8d5dc1e4e..715779891 100644 --- a/compiler/python/byteir/utils.py +++ b/compiler/python/byteir/utils.py @@ -164,7 +164,7 @@ def detect_gpu_arch_with_nvidia_smi(): "sm_70": ["V100"], "sm_75": ["T4", "Quadro T2000"], "sm_80": ["PG509", "A100", "A10", "RTX 30", "A30", "RTX 40", "A16"], - "sm_90": ["H100"], + "sm_90": ["H100","H800"], } for sm, names in sm_names.items(): if any(name in stdout for name in names): diff --git a/external/TritonTemplate/.gitignore b/external/TritonTemplate/.gitignore new file mode 100644 index 000000000..8897298b9 --- /dev/null +++ b/external/TritonTemplate/.gitignore @@ -0,0 +1,146 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written 
by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# tmp +tmp/ + +tags + +# macOS dir files +.DS_Store + +# PyCharm files +.idea + +# vscode +.vscode + +# vim temp files +*.swp diff --git a/external/TritonTemplate/README.md b/external/TritonTemplate/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/setup.py b/external/TritonTemplate/python/setup.py new file mode 100644 index 000000000..8b369e707 --- /dev/null +++ b/external/TritonTemplate/python/setup.py @@ -0,0 +1,59 @@ +import os +import shutil + +from setuptools import find_packages, setup + +CURRENT_DIR = os.path.dirname(__file__) +libinfo_py = os.path.join(CURRENT_DIR, "tritontemplate", "_libinfo.py") +libinfo = {} +with open(libinfo_py, "r") as f: + exec(f.read(), libinfo) +__version__ = libinfo["__version__"] + +def gen_file_list(srcs, f_cond): + file_list = [] + for src in srcs: + for root, _, files in os.walk(src): + value = [] + for file in files: + if f_cond(file): + path = os.path.join(root, file) + value.append(path.replace("tritontemplate/", "")) + file_list.extend(value) + return file_list + +def gen_backend_common_file_list(): + srcs = ["tritontemplate/backend"] + f_cond = lambda x: True if x.endswith(".py") else False + return gen_file_list(srcs, f_cond) + +def gen_utils_file_list(): + srcs = ["tritontemplate/utils"] + f_cond = lambda x: True if x.endswith(".py") else False + return gen_file_list(srcs, f_cond) + +def gen_compiler_file_list(): + srcs = ["tritontemplate/compiler"] + f_cond = lambda x: True if x.endswith(".py") else False + return gen_file_list(srcs, f_cond) + 
+setup_kwargs = {} +include_libs = True +wheel_include_libs = True + +setup( + name="tritontemplate", + version=__version__, + description="TritonTemplate: Make Flex Triton Templates for AI", + zip_safe=True, + install_requires=["torch>=2.1.0","triton"], + packages=find_packages(), + package_data={ + "tritontemplate": [] + + gen_utils_file_list() + + gen_backend_common_file_list() + + gen_compiler_file_list() + }, + python_requires=">=3.7, <4", + **setup_kwargs +) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/__init__.py b/external/TritonTemplate/python/tritontemplate/__init__.py new file mode 100644 index 000000000..a1bf64a27 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/__init__.py @@ -0,0 +1,5 @@ +import sys +from tritontemplate import backend,compiler,testing,utils +from tritontemplate._libinfo import __version__ + +__all__ = ["backend", "compiler", "testing", "utils"] diff --git a/external/TritonTemplate/python/tritontemplate/_libinfo.py b/external/TritonTemplate/python/tritontemplate/_libinfo.py new file mode 100644 index 000000000..45c421e85 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/_libinfo.py @@ -0,0 +1 @@ +__version__ = "dev0" \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py new file mode 100644 index 000000000..5f8c6587a --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py @@ -0,0 +1 @@ +from 
tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py new file mode 100644 index 000000000..38a1813da --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py @@ -0,0 +1,168 @@ +import triton +import triton.language as tl + +def gen_grid_bmm(batch_size,M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + """ + Generates the grid for a Batch GEMM kernel. + """ + return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1) + +@triton.jit +def bmm_bias( + # Pointers to matrices + a_ptr, b_ptr, bias_ptr, c_ptr, + # Matrix dimensions + BATCH_SIZE: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides for matrices + is_transpose_a: tl.constexpr,stride_a0: tl.constexpr, stride_a1: tl.constexpr, stride_a2: tl.constexpr, + is_transpose_b: tl.constexpr,stride_b0: tl.constexpr, stride_b1: tl.constexpr, stride_b2: tl.constexpr, + stride_bias0: tl.constexpr,stride_bias1: tl.constexpr,stride_bias2: tl.constexpr, + stride_c0: tl.constexpr, stride_c1: tl.constexpr, stride_c2: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + enable_tf32: tl.constexpr, + ): + """ + Kernel for BMM + Bias. 
+ if !is_transpose_a and !is_transpose_b : + A (B, M, K) @ B (B, K, N) + Bias (B, M, N) -> C (B, M, N) + """ + batch_id=tl.program_id(0) + pid=tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M) + offs_bn = pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N) + offs_k = tl.arange(0,BLOCK_SIZE_K) + + if is_transpose_a: + # A(B,K,M) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:] + stride_ak=stride_a1 + else: + # A(B,M,K) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_am[:,None]+stride_a2*offs_k[None,:] + stride_ak=stride_a2 + + if is_transpose_b: + # B(B,N,K) + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:] + stride_bk=stride_b2 + else: + # B(B,K,N) + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_k[:,None]+stride_b2*offs_bn[None,:] + stride_bk=stride_b1 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k C (B, M, N) + """ + batch_id=tl.program_id(0) + pid=tl.program_id(1) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M) + offs_bn = pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N) + offs_k = tl.arange(0,BLOCK_SIZE_K) + + if is_transpose_a: + # A(B,K,M) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:] + stride_ak=stride_a1 + else: + # A(B,M,K) + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_am[:,None]+stride_a2*offs_k[None,:] + stride_ak=stride_a2 + + if is_transpose_b: + # B(B,N,K) + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:] + stride_bk=stride_b2 + else: + # B(B,K,N) + b_ptrs = 
b_ptr+stride_b0*batch_id+stride_b1*offs_k[:,None]+stride_b2*offs_bn[None,:] + stride_bk=stride_b1 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k C(M,N) + """ + # _TF32_ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" + # DTYPE: tl.constexpr = k.dtype.element_ty + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + if is_transpose_a: + # A(K,M) + a_ptrs = a_ptr + offs_k[:, None] * stride_a0 + offs_am[None, :] * stride_a1 + stride_ak = stride_a0 + else: + # A(M,K) + a_ptrs = a_ptr + offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1 + stride_ak = stride_a1 + + if is_transpose_b: + # B(N,K) + b_ptrs = b_ptr + offs_bn[:, None] * stride_b0 + offs_k[None, :] * stride_b1 + stride_bk = stride_b1 + else: + # B(K,N) + b_ptrs = b_ptr + offs_k[:, None] * stride_b0 + offs_bn[None, :] * stride_b1 + stride_bk = stride_b0 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_am[None, :] < M) + else: + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + if is_transpose_b: + b_mask = (offs_bn[:, None] < N) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + else: + b_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_bn[None, :] < N) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + if is_transpose_a: + a = tl.trans(a) + if is_transpose_b: + b = tl.trans(b) + accumulator += tl.dot(a, b, allow_tf32=enable_tf32) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if 
bias_ptr is not None: + offs_bias_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + bias_ptrs = bias_ptr + offs_bias_n * stride_bias0 + bias_vals = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0) + accumulator = accumulator + bias_vals[None, :] + + if ACTIVATION == 'relu': + accumulator = relu(accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.jit +def gemm( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides for matrices + is_transpose_a: tl.constexpr, stride_a0: tl.constexpr, stride_a1: tl.constexpr, + is_transpose_b: tl.constexpr, stride_b0: tl.constexpr, stride_b1: tl.constexpr, + stride_c0: tl.constexpr, stride_c1: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + ACTIVATION: tl.constexpr, # 'relu' or None + enable_tf32: tl.constexpr, +): + """ + Kernel for GEMM + ReLU. + A @ B -> C + This kernel can handle transpositions for A and B. 
+ """ + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + if is_transpose_a: + # A(K,M) + a_ptrs = a_ptr + offs_k[:, None] * stride_a0 + offs_am[None, :] * stride_a1 + stride_ak = stride_a0 + else: + # A(M,K) + a_ptrs = a_ptr + offs_am[:, None] * stride_a0 + offs_k[None, :] * stride_a1 + stride_ak = stride_a1 + + if is_transpose_b: + # B(N,K) + b_ptrs = b_ptr + offs_bn[:, None] * stride_b0 + offs_k[None, :] * stride_b1 + stride_bk = stride_b1 + else: + # B(K,N) + b_ptrs = b_ptr + offs_k[:, None] * stride_b0 + offs_bn[None, :] * stride_b1 + stride_bk = stride_b0 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if is_transpose_a: + a_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_am[None, :] < M) + else: + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + a = tl.load(a_ptrs, mask=a_mask, other=0.0) + + if is_transpose_b: + b_mask = (offs_bn[:, None] < N) & (offs_k[None, :] + BLOCK_SIZE_K * k < K) + else: + b_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_bn[None, :] < N) + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + + if is_transpose_a: + a = tl.trans(a) + if is_transpose_b: + b = tl.trans(b) + accumulator += tl.dot(a, b, allow_tf32=enable_tf32) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if ACTIVATION == 'relu': + accumulator = relu(accumulator) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + stride_c1 * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) \ No newline at end of file 
diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py new file mode 100644 index 000000000..9bcd2adac --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py @@ -0,0 +1 @@ +from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py new file mode 100644 index 000000000..9c7d8541b --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py @@ -0,0 +1,98 @@ +import triton +import triton.language as tl + +def gen_grid_layernorm(M,BLOCK_SIZE_M): + grid = (triton.cdiv(M, BLOCK_SIZE_M),1,1) + return grid + +@triton.jit +def layernorm(x_ptr,y_ptr,M:tl.constexpr,N:tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr,stride_y0:tl.constexpr,stride_y1:tl.constexpr,BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,eps:tl.constexpr=1e-5): + ''' + layernorm function for last dimension. 
+ ''' + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + + pid = tl.program_id(axis=0) + block_start = pid*BLOCK_SIZE_M + offsets_M = tl.arange(0,BLOCK_SIZE_M) + block_start + mask_M = offsets_M < M + ex = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + varx = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :]* stride_x1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0).to(tl.float32) + ex += tl.sum(x, axis=1) + varx += tl.sum(x * x, axis=1) + + ex = ex / N + varx = varx / N + varx -= (ex * ex) + normized = tl.sqrt(varx + eps) + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :]* stride_x1 + y_offsets = offsets_M[:, None] * stride_y0 + offsets_N[None, :]* stride_y1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0).to(tl.float32) + y = (x - ex[:, None]) / normized[:, None] + tl.store(y_ptr + y_offsets, y, mask=mask) + + +@triton.jit +def layernorm_weight_bias( + x_ptr, + weight_ptr, + bias_ptr, + y_ptr, + M: tl.constexpr, + N: tl.constexpr, + stride_x0: tl.constexpr, + stride_x1: tl.constexpr, + stride_y0: tl.constexpr, + stride_y1: tl.constexpr, + stride_weight: tl.constexpr, + stride_bias: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + eps: tl.constexpr = 1e-5, +): + ''' + layernorm function with weight and bias for last dimension. 
+ ''' + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_M + offsets_M = tl.arange(0, BLOCK_SIZE_M) + block_start + mask_M = offsets_M < M + ex = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + varx = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0).to(tl.float32) # accumulate sum and sum-of-squares in fp32, as the plain layernorm kernel does + ex += tl.sum(x, axis=1) + varx += tl.sum(x * x, axis=1) + + ex = ex / N + varx = varx / N + varx -= (ex * ex) + normized = tl.sqrt(varx + eps) + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + x_offsets = offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1 + y_offsets = offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1 + mask = mask_M[:, None] & mask_N[None, :] + x = tl.load(x_ptr + x_offsets, mask=mask, other=0.0).to(tl.float32) + weight = tl.load(weight_ptr + offsets_N*stride_weight, mask=mask_N, other=0.0).to(tl.float32) + bias = tl.load(bias_ptr + offsets_N*stride_bias, mask=mask_N, other=0.0).to(tl.float32) + y = (x - ex[:, None]) * weight[None, :] / normized[:, None] + bias[None, :] + tl.store(y_ptr + y_offsets, y, mask=mask) diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py new file mode 100644 index 000000000..2fc667265 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py @@ -0,0 +1 @@ +from tritontemplate.backend.cuda.softmax.softmax import softmax,online_softmax,gen_grid_softmax \ No newline at end of file diff --git 
a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py new file mode 100644 index 000000000..393cbe217 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py @@ -0,0 +1,90 @@ +import triton +import triton.language as tl + +def gen_grid_softmax(M, BLOCK_SIZE_M): + """ + Generates the grid for a softmax kernel. + """ + return ( + triton.cdiv(M, BLOCK_SIZE_M),1,1) + +@triton.jit +def softmax(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr, stride_y0:tl.constexpr, stride_y1:tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + ''' + softmax function for last dimension. + ''' + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_M + offsets_M = tl.arange(0, BLOCK_SIZE_M) + block_start + + mask_M = offsets_M < M + + max_ele = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) - float('inf') # start the running max at -inf (as online_softmax does); starting at 0 makes every exp underflow for strongly negative rows + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + elem = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, mask=mask_M[:, None] & mask_N[None, :], other=-float('inf')) + max_ele = tl.maximum(max_ele,tl.max(elem, axis=1)) + + y_sum = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + elem = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, mask=mask_M[:, None] & mask_N[None, :], other=-float('inf')) + y = tl.exp(elem - max_ele[:, None]) + y_sum += tl.sum(y, axis=1) + tl.store(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, y,mask=mask_M[:, None] & mask_N[None, :]) + + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, 
BLOCK_SIZE_N) + i*BLOCK_SIZE_N + mask_N = offsets_N < N + y = tl.load(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, mask=mask_M[:, None] & mask_N[None, :], other=0) + y = y / y_sum[:,None] + tl.store(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, y, mask=mask_M[:, None] & mask_N[None, :]) + + +@triton.jit +def online_softmax(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr, stride_y0:tl.constexpr, stride_y1:tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + ''' + online softmax function for last dimension. + ''' + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE_M + offsets_M = tl.arange(0, BLOCK_SIZE_M) + block_start + mask_M = offsets_M < M + + m_i = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - float('inf') + s_i = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + NUM_BLOCK_N:tl.constexpr = (N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i * BLOCK_SIZE_N + mask_N = offsets_N < N + + x = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, + mask=mask_M[:, None] & mask_N[None, :], other=-float('inf')) + + m_block = tl.max(x, axis=1) + m_new = tl.maximum(m_i, m_block) + + s_i *= tl.exp(m_i - m_new) + s_i += tl.sum(tl.exp(x - m_new[:, None]), axis=1) + + m_i = m_new + + for i in tl.static_range(NUM_BLOCK_N): + offsets_N = tl.arange(0, BLOCK_SIZE_N) + i * BLOCK_SIZE_N + mask_N = offsets_N < N + + x = tl.load(x_ptr + offsets_M[:, None] * stride_x0 + offsets_N[None, :] * stride_x1, + mask=mask_M[:, None] & mask_N[None, :], other=0.) 
+ + y = tl.exp(x - m_i[:, None]) / s_i[:, None] + + tl.store(y_ptr + offsets_M[:, None] * stride_y0 + offsets_N[None, :] * stride_y1, + y, mask=mask_M[:, None] & mask_N[None, :]) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py new file mode 100644 index 000000000..da1c07b90 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py @@ -0,0 +1,2 @@ +from tritontemplate.backend.cuda.transpose.transpose_10 import transpose_10,gen_grid_transpose_10 +from tritontemplate.backend.cuda.transpose.transpose_0213 import transpose_0213,gen_grid_transpose_0213 diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py new file mode 100644 index 000000000..87cfe56b5 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py @@ -0,0 +1,46 @@ +import triton +import triton.language as tl + +def gen_grid_transpose_0213(D0, D1, D2, D3, BLOCK_SIZE_D1, BLOCK_SIZE_D2): + """ + Generates the grid for a transpose kernel. + """ + return (D0 * D3, triton.cdiv(D2, BLOCK_SIZE_D2), triton.cdiv(D1, BLOCK_SIZE_D1)) + +#TODO: rewrite for contiguous D3 reading and storing +@triton.jit +def transpose_0213(x, y, + D0:tl.constexpr, D1:tl.constexpr, D2:tl.constexpr, D3:tl.constexpr, + stride_x0:tl.constexpr, stride_x1:tl.constexpr, stride_x2:tl.constexpr, stride_x3:tl.constexpr, + stride_y0:tl.constexpr, stride_y1:tl.constexpr, stride_y2:tl.constexpr, stride_y3:tl.constexpr, + BLOCK_SIZE_D1:tl.constexpr,BLOCK_SIZE_D2:tl.constexpr): + """ + Transpose a matrix in the 0213 layout. 
+ """ + pid_d0d3 = tl.program_id(0) + pid_d2_chunk = tl.program_id(1) + pid_d1_chunk = tl.program_id(2) + + pid_d0 = pid_d0d3 // D3 + pid_d3 = pid_d0d3 % D3 + + offs_d1 = pid_d1_chunk * BLOCK_SIZE_D1 + tl.arange(0, BLOCK_SIZE_D1) + offs_d2 = pid_d2_chunk * BLOCK_SIZE_D2 + tl.arange(0, BLOCK_SIZE_D2) + + x_ptr = x + pid_d0 * stride_x0 + \ + offs_d1[:, None] * stride_x1 + \ + offs_d2[None, :] * stride_x2 + \ + pid_d3 * stride_x3 + + y_ptr = y + pid_d0 * stride_y0 + \ + offs_d2[:, None] * stride_y1 + \ + offs_d1[None, :] * stride_y2 + \ + pid_d3 * stride_y3 + + + mask_x = (offs_d1[:, None] < D1) & (offs_d2[None, :] < D2) + mask_y = (offs_d2[:, None] < D2) & (offs_d1[None, :] < D1) + + tile = tl.load(x_ptr, mask=mask_x, other=0.0) + transposed_tile = tl.trans(tile) + tl.store(y_ptr, transposed_tile, mask=mask_y) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py new file mode 100644 index 000000000..44187666b --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py @@ -0,0 +1,36 @@ +import triton +import triton.language as tl + +def gen_grid_transpose_10(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): + """ + Generates the grid for a GEMM kernel. 
+ """ + return ( + triton.cdiv(M, BLOCK_SIZE_M), + triton.cdiv(N, BLOCK_SIZE_N), + 1) + +@triton.jit +def transpose_10( + x_ptr, y_ptr, + M: tl.constexpr, N: tl.constexpr, + stride_x0:tl.constexpr,stride_x1:tl.constexpr, + stride_y0:tl.constexpr,stride_y1:tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + block_start_m = pid_m * BLOCK_SIZE_M + block_start_n = pid_n * BLOCK_SIZE_N + offset_M = tl.arange(0, BLOCK_SIZE_M) + block_start_m + offset_N = tl.arange(0, BLOCK_SIZE_N) + block_start_n + + input_offsets = offset_M[:, None] * stride_x0 + offset_N[None, :] * stride_x1 # x[m, n] is at m*stride_x0 + n*stride_x1; each index is scaled by its own stride + mask_M = offset_M < M + mask_N = offset_N < N + mask = mask_M[:, None] & mask_N[None, :] + + x = tl.load(x_ptr + input_offsets, mask=mask) + output_offsets = offset_N[:, None] * stride_y0 + offset_M[None, :] * stride_y1 # y[n, m] is at n*stride_y0 + m*stride_y1 + output_mask = mask_M[None, :] & mask_N[:, None] + x = tl.trans(x) + tl.store(y_ptr + output_offsets, x, mask=output_mask) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/activation.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/activation.py new file mode 100644 index 000000000..44ecea8ff --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/activation.py @@ -0,0 +1,9 @@ +import triton +import triton.language as tl + +__all__ = ['relu'] + +@triton.jit +def relu(x): + return tl.maximum(x, 0) + diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/utils.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/utils.py new file mode 100644 index 000000000..cfe5466c4 --- /dev/null +++ 
def shape2stride(s: Union[List[int], Tuple[int, ...]]) -> Tuple[int, ...]:
    """Return the strides of a densely packed, row-major array of shape ``s``.

    The last dimension has stride 1 and every earlier dimension's stride is
    the product of all extents to its right. An empty shape yields ``()``.
    """
    if len(s) == 0:
        return ()
    # Walk the extents from the last one down to index 1, growing the
    # running product; the first extent never contributes to any stride.
    reversed_strides = [1]
    for extent in s[:0:-1]:
        reversed_strides.append(reversed_strides[-1] * extent)
    reversed_strides.reverse()
    return tuple(reversed_strides)
class Operation(BaseType):
    """Base class for operators that can be compiled into a kernel.

    The input tensors, output tensors and operator name are kept in
    ``self._attrs``; concrete operators implement :meth:`compile`.
    """

    def __init__(
        self,
        inputs: List[BaseType],
        outputs: Optional[List[BaseType]] = None,
        name: Optional[str] = None,
    ) -> None:
        super().__init__()
        self._attrs['inputs'] = inputs
        self._attrs['outputs'] = outputs
        self._attrs['name'] = name

    @property
    def name(self) -> Optional[str]:
        return self._attrs.get('name')

    @property
    def inputs(self) -> List[Tensor]:
        return self._attrs['inputs']

    @property
    def outputs(self) -> Optional[List[Tensor]]:
        return self._attrs['outputs']

    @abstractmethod
    def compile(self, target_name, workdir, enable_tf32: bool = False):
        """Compile this operator for ``target_name``.

        ``enable_tf32`` is part of the signature because ``compile_kernel``
        always forwards it as the third positional argument.
        """
        raise NotImplementedError

    def _gen_tensor_signature_divisiability(self, tensors_names: List[str]):
        """Build the pointer signature and divisibility hints for the kernel.

        Parameters
        ----------
        tensors_names: List[str]
            Keys of ``self._attrs`` (e.g. ``'inputs'``, ``'outputs'``) whose
            tensor lists are concatenated, in order, to form the kernel's
            pointer arguments.

        Returns
        -------
        (dict, dict)
            ``signature`` maps argument position to ``'*<triton dtype>'``;
            ``divisiability`` is ``{1: [], 16: [positions...]}``, where every
            tensor position is listed under 16.

        Raises
        ------
        NotImplementedError
            If an entry is not a ``Tensor``.
        ValueError
            Propagated from ``dtype_str_to_triton_signature`` for an
            unsupported dtype.
        """
        signature_metadata = {}
        divisiability = {1: [], 16: []}
        tensor_obj = []
        for tensor_name in tensors_names:
            tensor_obj += self._attrs[tensor_name]
        for position, tensor in enumerate(tensor_obj):
            if not isinstance(tensor, Tensor):
                raise NotImplementedError(f'input {tensor} not supported')
            try:
                sptype = '*' + dtype_str_to_triton_signature(tensor.dtype)
            except KeyError as err:
                raise KeyError(f'dtype {tensor.dtype} not supported') from err
            signature_metadata[position] = sptype
            # the ptr from torch is 16-byte aligned
            divisiability[16].append(position)

        return signature_metadata, divisiability

    @staticmethod
    def _block_size(x):
        """Pick a tile size for extent ``x``: 32, 64, or 128 for anything above 64."""
        if x <= 32:
            return 32
        if x <= 64:
            return 64
        return 128
def get_dtype_size(dtype: str) -> int:
    """Look up how many bytes one element of ``dtype`` occupies.

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    int
        Size (in bytes) of this dtype.

    Raises
    ----------
    KeyError
        If ``dtype`` has no entry in the size table.
    """
    if dtype in _DTYPE2BYTE:
        return _DTYPE2BYTE[dtype]
    raise KeyError(f"Unknown dtype: {dtype}. Expected one of {_DTYPE2BYTE.keys()}")
def dtype_str_to_triton_signature(dtype: str) -> str:
    """Returns the Triton signature type name for the given dtype str
    (for example ``"float16"`` maps to ``"fp16"``).

    Parameters
    ----------
    dtype: str
        A data type string.

    Returns
    ----------
    str
        The Triton signature type name for this dtype.

    Raises
    ----------
    ValueError
        If ``dtype`` has no Triton signature mapping.
    """
    if dtype not in _DTYPETRITONSIGNATURE:
        raise ValueError(
            f"Got unsupported input dtype {dtype}! Supported dtypes are: {list(_DTYPETRITONSIGNATURE.keys())}"
        )
    return _DTYPETRITONSIGNATURE[dtype]
class TritonExecutor:
    """Bundles a compiled Triton kernel with its launch configuration.

    Attributes set from the compiled kernel: ``name`` (``metadata['name']``),
    ``smemsize`` (``shared``) and ``blocksize`` (``num_warps * warp_size``).
    ``gridsize``, ``warpsize`` and ``call_constants`` are stored as given.
    """

    def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_size:Sequence[int],warp_size:int=32,constants:dict=None):
        self.call_constants = constants
        self.triton_kernel = triton_kernel
        self.gridsize = grid_size
        self.blocksize = triton_kernel.num_warps * warp_size
        self.warpsize = warp_size
        self.name = triton_kernel.metadata['name']
        self.smemsize = triton_kernel.shared

        # Query the device name once; the helper spawns a subprocess, so the
        # duplicate call that used to sit inside the try block is dropped.
        self.device_name = get_cuda_device_name()
        try:
            max_smem = get_device_max_shared_memory(self.device_name)
        except KeyError as e:
            # Unknown device: no limit to check against, so log and continue.
            import logging
            logging.warning(f"Unsupported device detected: {str(e)}. Continuing with default configuration.")
            self.device_name = "unknown"
        else:
            assert self.smemsize <= max_smem, \
                f'kernel {self.name} smem size {self.smemsize} exceeds device {self.device_name} max smem size {max_smem}'

    def __call__(self, *args, **kwds):
        """Launch the kernel on the stored grid with the given arguments."""
        return self.triton_kernel[self.gridsize](*args, **kwds)

    def kernel_ptx(self,func_name:str):
        """Return the kernel's PTX with every occurrence of the compiled
        kernel name replaced by ``func_name``."""
        ptx = self.triton_kernel.asm['ptx']
        return ptx.replace(self.name, func_name)
class Bmm(Operation):
    """Batched matrix multiply operator, optionally with a bias input.

    ``layout`` is a three-letter string. The first two letters say whether
    A and B are used as given ('r') or transposed ('c'); the third selects
    the output shape: ``[B, M, N]`` for 'r', ``[B, N, M]`` otherwise.
    """

    def __init__(
        self,
        inputs:List[Tensor],
        layout:str,is_bias:bool=False,
        outputs:Optional[List[Tensor]]=None,
        name: Optional[str]=None):
        assert layout in _supported_layouts, f"Unsupported layout {layout}"

        super().__init__(inputs, outputs, name)
        self.is_bias = is_bias
        self.layout = layout
        self._deduce_output_shape()

    def _deduce_output_shape(self):
        """Derive BATCH_SIZE/M/N/K from the input shapes and create or check the output."""
        attrs = self._attrs
        lhs_shape = attrs['inputs'][0].shape
        batch = lhs_shape[0]
        a_transposed = self.layout[0] == 'c'
        b_transposed = self.layout[1] == 'c'
        if a_transposed:
            m_dim, k_dim = lhs_shape[2], lhs_shape[1]
        else:
            m_dim, k_dim = lhs_shape[1], lhs_shape[2]
        rhs_shape = attrs['inputs'][1].shape
        n_dim = rhs_shape[1] if b_transposed else rhs_shape[2]

        for key, value in (
            ('BATCH_SIZE', batch),
            ('M', m_dim),
            ('N', n_dim),
            ('K', k_dim),
            ('is_transpose_a', a_transposed),
            ('is_transpose_b', b_transposed),
        ):
            attrs[key] = value

        if self.layout[2] == 'r':
            expected = [batch, m_dim, n_dim]
        else:
            expected = [batch, n_dim, m_dim]

        declared = attrs['outputs']
        if declared is None:
            attrs['outputs'] = [Tensor(shape=expected, dtype=attrs['inputs'][0].dtype)]
            return
        assert declared[0].shape == expected, f"output shape {declared[0].shape} not match {expected}"

    def _gen_constants(self,enable_tf32):
        """Collect the compile-time constants passed to the kernel."""
        attrs = self._attrs
        operands = attrs['inputs']
        results = attrs['outputs']
        has_fp32_input = any(t.dtype == 'float32' for t in operands)

        consts = {}
        for dim in ('M', 'N', 'K'):
            consts['BLOCK_SIZE_' + dim] = self._block_size(attrs[dim])

        # tf32 is only switched on when requested and some input is float32
        consts['enable_tf32'] = bool(enable_tf32 and has_fp32_input)

        for key in ('BATCH_SIZE', 'M', 'N', 'K', 'is_transpose_a', 'is_transpose_b'):
            consts[key] = attrs[key]

        def put_strides(tag, shape):
            # every operand here is rank 3; unpacking enforces that
            s0, s1, s2 = shape2stride(shape)
            consts[f'stride_{tag}0'] = s0
            consts[f'stride_{tag}1'] = s1
            consts[f'stride_{tag}2'] = s2

        put_strides('a', operands[0].shape)
        put_strides('b', operands[1].shape)
        if self.is_bias:
            put_strides('bias', operands[2].shape)
        put_strides('c', results[0].shape)

        return consts

    def _gen_exec_metadata(self):
        return dict(_exec_metadata)

    def compile(self, target_name, workdir,enable_tf32: bool = False,)->TritonExecutor:
        """Compile the (optionally biased) bmm kernel for ``target_name`` and wrap it in an executor."""
        backend = importlib.import_module(f'tritontemplate.backend.{target_name}.bmm')
        kernel_fn = getattr(backend, 'bmm_bias' if self.is_bias else 'bmm')
        grid_fn = getattr(backend, 'gen_grid_bmm')

        signature, divisiability = self._gen_tensor_signature_divisiability(['inputs','outputs'])
        constants = self._gen_constants(enable_tf32)
        launch = self._gen_exec_metadata()

        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
        compiled = triton.compile(
            fn=kernel_fn,
            signature=signature,
            constants=constants,
            num_warps=launch['num_warps'],
            num_stages=launch['num_stages'],
            configs=[config],
            debug=False,
        )

        grid = grid_fn(
            constants['BATCH_SIZE'],
            constants['M'],
            constants['N'],
            constants['BLOCK_SIZE_M'],
            constants['BLOCK_SIZE_N'],
        )
        return TritonExecutor(compiled, grid, get_warpsize(target_name), constants)
class Gemm(Operation):
    """Matrix multiply operator with optional bias input and activation.

    ``layout`` is a three-letter string. The first two letters say whether
    A and B are used as given ('r') or transposed ('c'); the third selects
    the output shape: ``[M, N]`` for 'r', ``[N, M]`` otherwise.
    """

    def __init__(
        self,
        inputs: List[Tensor],
        layout: str,
        is_bias: bool = False,
        outputs: Optional[List[Tensor]] = None,
        activation: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        assert layout in _supported_layouts, f'layout {layout} not supported'
        assert activation in _supported_activations, f'activation {activation} not supported'

        super().__init__(inputs, outputs, name)
        self.is_bias = is_bias
        self.layout = layout
        self._attrs['activation'] = activation
        self._deduce_output_shape()

    def _deduce_output_shape(self):
        """Derive M/K/N from the input shapes and create or check the output."""
        attrs = self._attrs
        a_transposed = self.layout[0] == 'c'
        b_transposed = self.layout[1] == 'c'
        lhs_shape = attrs['inputs'][0].shape
        if a_transposed:
            m_dim, k_dim = lhs_shape[1], lhs_shape[0]
        else:
            m_dim, k_dim = lhs_shape[0], lhs_shape[1]
        rhs_shape = attrs['inputs'][1].shape
        n_dim = rhs_shape[0] if b_transposed else rhs_shape[1]

        for key, value in (
            ('M', m_dim),
            ('K', k_dim),
            ('N', n_dim),
            ('is_transpose_a', a_transposed),
            ('is_transpose_b', b_transposed),
        ):
            attrs[key] = value

        if self.layout[2] == 'r':
            expected = [m_dim, n_dim]
        else:
            expected = [n_dim, m_dim]

        declared = attrs['outputs']
        if declared is None:
            attrs['outputs'] = [Tensor(shape=expected, dtype=attrs['inputs'][0].dtype)]
            return
        assert declared[0].shape == expected, f"output shape {declared[0].shape} not match {expected}"

    def _gen_constants(self,enable_tf32):
        """Collect the compile-time constants passed to the kernel."""
        attrs = self._attrs
        operands = attrs['inputs']
        results = attrs['outputs']

        consts = {'ACTIVATION': attrs['activation']}

        # tf32 is only switched on when requested and some input is float32
        has_fp32_input = any(t.dtype == 'float32' for t in operands)
        consts['enable_tf32'] = bool(enable_tf32 and has_fp32_input)

        for dim in ('M', 'N', 'K'):
            consts['BLOCK_SIZE_' + dim] = self._block_size(attrs[dim])

        for key in ('M', 'N', 'K', 'is_transpose_a', 'is_transpose_b'):
            consts[key] = attrs[key]

        def put_strides(tag, shape):
            # operands are rank 2; unpacking enforces that
            s0, s1 = shape2stride(shape)
            consts[f'stride_{tag}0'] = s0
            consts[f'stride_{tag}1'] = s1

        put_strides('a', operands[0].shape)
        put_strides('b', operands[1].shape)

        if self.is_bias:
            consts['stride_bias0'] = 1

        put_strides('c', results[0].shape)
        return consts

    def _gen_exec_metadata(self):
        return dict(_exec_metadata)

    #TODO:enable_tf32 https://github.com/triton-lang/triton/issues/4574
    def compile(self,target_name,workdir,enable_tf32: bool = False,)->TritonExecutor:
        """Compile the (optionally biased) gemm kernel for ``target_name`` and wrap it in an executor."""
        backend = importlib.import_module(f'tritontemplate.backend.{target_name}.gemm')
        kernel_fn = getattr(backend, 'gemm_bias' if self.is_bias else 'gemm')
        grid_fn = getattr(backend, 'gen_grid_gemm')

        signature, divisiability = self._gen_tensor_signature_divisiability(['inputs','outputs'])
        constants = self._gen_constants(enable_tf32)
        launch = self._gen_exec_metadata()

        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
        compiled = triton.compile(
            fn=kernel_fn,
            signature=signature,
            constants=constants,
            num_warps=launch['num_warps'],
            num_stages=launch['num_stages'],
            configs=[config],
            debug=False,
        )

        grid = grid_fn(
            constants['M'],
            constants['N'],
            constants['BLOCK_SIZE_M'],
            constants['BLOCK_SIZE_N'],
        )
        return TritonExecutor(compiled, grid, get_warpsize(target_name), constants)
class Layernorm(Operation):
    """Layer normalization over the last axis of the first input.

    The input is treated as an ``M x N`` matrix, where ``N`` is the last
    extent and ``M`` the product of all others. When exactly three inputs
    are given, the weight/bias kernel variant is selected.
    """

    def __init__(
        self,
        inputs: List[Tensor],# [x,bias(beta),weight(gamma)]
        axises:List[int],
        eps:float = 1e-5,
        outputs: Optional[List[Tensor]] = None,
        name: Optional[str] = None,
    ) -> None:
        super().__init__(inputs, outputs, name)
        # The message must only use names that exist here: it previously
        # referenced an undefined `axis`, so a failing check raised NameError
        # instead of AssertionError.
        assert len(axises)==1 and axises[0] == len(inputs[0].shape)-1, f'Only last axis normalization is supported (axis={axises}, input shape={inputs[0].shape})'
        self._attrs['axis'] = axises[0]
        self._attrs['eps'] = eps

        self._deduce_output_shape()

    def _deduce_output_shape(self):
        """Record M, N and the weight/bias flag; default the output to the input's shape and dtype."""
        M = prod(self._attrs['inputs'][0].shape[:-1])
        N = self._attrs['inputs'][0].shape[-1]
        self._attrs['M'] = M
        self._attrs['N'] = N
        self._attrs['with_weight_bias']=len(self._attrs['inputs']) == 3
        if self._attrs['outputs'] is None:
            self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)]

    def _gen_constants(self):
        """Collect the compile-time constants passed to the kernel."""
        const_metadata={}
        const_metadata['M']= self._attrs['M']
        const_metadata['N']= self._attrs['N']
        # x and y are addressed as contiguous M x N matrices
        const_metadata['stride_x0'] = self._attrs['N']
        const_metadata['stride_x1'] = 1
        const_metadata['stride_y0'] = self._attrs['N']
        const_metadata['stride_y1'] = 1
        const_metadata['eps'] = self._attrs['eps']

        const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M'])
        const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N'])

        if self._attrs['with_weight_bias']:
            const_metadata['stride_weight'] = 1
            const_metadata['stride_bias'] = 1

        return const_metadata

    def _gen_exec_metadata(self):
        return _exec_metadata.copy()

    def compile(self, target_name, workdir, enable_tf32: bool = False)->TritonExecutor:
        """Compile the layernorm kernel for ``target_name`` and wrap it in an executor.

        ``enable_tf32`` is accepted for signature parity with the other
        operators and is not used here.
        """
        triton_kernel_name= 'layernorm_weight_bias' if self._attrs['with_weight_bias'] else 'layernorm'
        backend = importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm')
        triton_kernel=getattr(backend,triton_kernel_name)
        gen_grid=getattr(backend,'gen_grid_layernorm')
        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
        constants=self._gen_constants()
        exec_metadata=self._gen_exec_metadata()

        num_warps=exec_metadata['num_warps']
        num_stages=exec_metadata['num_stages']
        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)

        exec_grid=gen_grid(constants['M'],constants['BLOCK_SIZE_M'])
        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
class Softmax(Operation):
    """Softmax over the last axis of the first input.

    The input is treated as an ``M x N`` matrix, where ``N`` is the last
    extent and ``M`` the product of all others. ``enable_online`` selects
    between the ``online_softmax`` and ``softmax`` kernels.
    """

    def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outputs: Optional[List[Tensor]] = None, name: Optional[str] = None):
        super().__init__(inputs, outputs, name)
        last_axis = len(inputs[0].shape) - 1
        assert dim == last_axis, 'only support last axis now'
        self._attrs.update(dim=dim, enable_online=enable_online)
        self._deduce_output_shape()

    def _deduce_output_shape(self):
        """Record M and N; when no output was given, create a float32 one with the input's shape."""
        attrs = self._attrs
        src_shape = attrs['inputs'][0].shape
        rows = prod(src_shape[:-1])
        cols = attrs['inputs'][0].shape[-1]

        attrs['M'] = rows
        attrs['N'] = cols

        if attrs['outputs'] is not None:
            return
        # the default output is always float32, regardless of the input dtype
        attrs['outputs'] = [Tensor(shape=attrs['inputs'][0].shape, dtype='float32')]

    def _gen_constants(self):
        """Collect the compile-time constants passed to the kernel."""
        rows = self._attrs['M']
        cols = self._attrs['N']
        # x and y are addressed as contiguous M x N matrices
        return {
            'M': rows,
            'N': cols,
            'stride_x0': cols,
            'stride_x1': 1,
            'stride_y0': cols,
            'stride_y1': 1,
            'BLOCK_SIZE_M': self._block_size(rows),
            'BLOCK_SIZE_N': self._block_size(cols),
        }

    def _gen_exec_metadata(self):
        return dict(_exec_metadata)

    def compile(self, target_name, workdir, enable_tf32)->TritonExecutor:
        """Compile the selected softmax kernel for ``target_name`` and wrap it in an executor."""
        if self._attrs['enable_online']:
            kernel_name = 'online_softmax'
        else:
            kernel_name = 'softmax'
        backend = importlib.import_module(f'tritontemplate.backend.{target_name}.softmax')
        kernel_fn = getattr(backend, kernel_name)
        grid_fn = getattr(backend, 'gen_grid_softmax')

        signature, divisiability = self._gen_tensor_signature_divisiability(['inputs','outputs'])
        constants = self._gen_constants()
        launch = self._gen_exec_metadata()

        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
        compiled = triton.compile(
            fn=kernel_fn,
            signature=signature,
            constants=constants,
            num_warps=launch['num_warps'],
            num_stages=launch['num_stages'],
            configs=[config],
            debug=False,
        )

        grid = grid_fn(constants['M'], constants['BLOCK_SIZE_M'])
        return TritonExecutor(compiled, grid, get_warpsize(target_name), constants)
class Transpose(Operation):
    """Transpose operation restricted to the permutations in ``_supported_permutations``.

    ``permutation`` is a string of axis digits: output axis ``k`` takes its
    extent from input axis ``int(permutation[k])`` (``'10'`` swaps the two
    axes of a 2-D tensor, ``'0213'`` swaps the two middle axes of a 4-D one).
    """

    def __init__(self,
                 inputs: List[Tensor],
                 permutation: str,
                 outputs: Optional[List[Tensor]] = None,
                 name: Optional[str] = None):
        """Record the permutation and derive (or validate) the output tensor.

        Args:
            inputs: operand list; only ``inputs[0]`` is read by this class.
            permutation: one of ``_supported_permutations``.
            outputs: optional pre-declared output; when given, its shape must
                equal the permuted input shape.
            name: forwarded unchanged to ``Operation.__init__``.

        Raises:
            AssertionError: unsupported permutation, input rank different from
                the permutation length, or mismatching declared output shape.
        """
        super().__init__(inputs, outputs, name)
        assert permutation in _supported_permutations, f"Unsupported permutation {permutation}"
        self._attrs['permutation'] = permutation

        self._deduce_output_shape()

    def _deduce_output_shape(self):
        """Compute the permuted shape and create or check ``outputs[0]``."""
        input_shape = self._attrs['inputs'][0].shape
        permutation = self._attrs['permutation']
        # Without this check a higher-rank input is silently truncated to
        # len(permutation) axes and the error only surfaces in compile().
        assert len(input_shape) == len(permutation), (
            f"Transpose op input rank {len(input_shape)} does not match "
            f"permutation {permutation}"
        )
        output_shape = [input_shape[int(axis)] for axis in permutation]
        if self._attrs['outputs'] is None:
            self._attrs['outputs'] = [Tensor(output_shape, self._attrs['inputs'][0].dtype)]
        else:
            declared_shape = self._attrs['outputs'][0].shape
            # Compare as lists so a tuple-valued shape is not rejected purely
            # because of its container type.
            assert list(declared_shape) == output_shape, f"Transpose op output shape {declared_shape} does not match expected shape {output_shape}"

    def _gen_constants_10(self):
        """Build the compile-time constants for the 2-D ``'10'`` kernel."""
        const_metadata={}
        M,N=self._attrs['inputs'][0].shape

        const_metadata['M'] = M
        const_metadata['N'] = N
        # Strides of a contiguous row-major [M, N] input and [N, M] output.
        const_metadata['stride_x0'] = N
        const_metadata['stride_x1'] = 1
        const_metadata['stride_y0'] = M
        const_metadata['stride_y1'] = 1

        # _block_size is not defined in this class; presumably inherited from
        # Operation -- TODO confirm.
        const_metadata['BLOCK_SIZE_M'] = self._block_size(M)
        const_metadata['BLOCK_SIZE_N'] = self._block_size(N)

        return const_metadata

    def _gen_grid_10(self,target_name,const_metadata):
        """Look up the backend's ``gen_grid_transpose_10`` and call it with M, N and both block sizes."""
        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_10')
        return gen_grid(const_metadata['M'],const_metadata['N'],const_metadata['BLOCK_SIZE_M'],const_metadata['BLOCK_SIZE_N'])

    def _gen_constants_0213(self):
        """Build the compile-time constants for the 4-D ``'0213'`` kernel."""
        const_metadata={}
        D0,D1,D2,D3=self._attrs['inputs'][0].shape
        const_metadata['D0'] = D0
        const_metadata['D1'] = D1
        const_metadata['D2'] = D2
        const_metadata['D3'] = D3

        # Input and output strides both come from shape2stride applied to the
        # respective tensor shapes.
        const_metadata['stride_x0'],const_metadata['stride_x1'],const_metadata['stride_x2'],const_metadata['stride_x3'] = shape2stride(self._attrs['inputs'][0].shape)
        const_metadata['stride_y0'],const_metadata['stride_y1'],const_metadata['stride_y2'],const_metadata['stride_y3']= shape2stride(self._attrs['outputs'][0].shape)

        # Only the two swapped axes (D1, D2) get a block size.
        const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1)
        const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2)

        return const_metadata

    def _gen_grid_0213(self,target_name,const_metadata):
        """Look up the backend's ``gen_grid_transpose_0213`` and call it with D0..D3 and both block sizes."""
        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_0213')
        return gen_grid(const_metadata['D0'],const_metadata['D1'],const_metadata['D2'],const_metadata['D3'],const_metadata['BLOCK_SIZE_D1'],const_metadata['BLOCK_SIZE_D2'])


    def _gen_exec_metadata(self):
        """Return a copy of the module-level ``_exec_metadata`` so callers cannot mutate the shared dict."""
        return _exec_metadata.copy()


    def compile(self, target_name, workdir, enable_tf32)->TritonExecutor:
        """Compile the kernel matching the permutation and wrap it in a ``TritonExecutor``.

        ``workdir`` and ``enable_tf32`` are accepted but not used by this
        method.

        Raises:
            ValueError: the stored permutation has no constants/grid generator.
        """
        triton_kernel_name= 'transpose_' + self._attrs['permutation']
        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),triton_kernel_name)
        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
        exec_metadata=self._gen_exec_metadata()

        if self._attrs['permutation'] == '10':
            constants=self._gen_constants_10()
            exec_grid = self._gen_grid_10(target_name,constants)
        elif self._attrs['permutation'] == '0213':
            constants=self._gen_constants_0213()
            exec_grid = self._gen_grid_0213(target_name,constants)
        else:
            # Defensive: __init__ already rejects anything outside
            # _supported_permutations.
            raise ValueError(f"Unsupported permutation {self._attrs['permutation']}")

        num_warps=exec_metadata['num_warps']
        num_stages=exec_metadata['num_stages']
        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])

        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)


        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
+@triton.jit +def mul(x, y): + return x * y +""" + +kernel_src = """ +import triton +import triton.language as tl +import kernel_utils + +@triton.jit +def kernel(C, A, B, + stride_cm, stride_cn, + stride_am, stride_ak, + stride_bk, stride_bn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + ms = tl.arange(0, BLOCK_M) + ns = tl.arange(0, BLOCK_N) + ks = tl.arange(0, BLOCK_K) + a = tl.load(A + ms[:, None] * stride_am + ks[None, :] * stride_ak) + b = tl.load(B + ks[:, None] * stride_bk + ns[None, :] * stride_bn) + c = tl.dot(a, b) + c = kernel_utils.mul(c, c) + tl.store(C + ms[:, None] * stride_cm + ns[None, :] * stride_cn, c) +""" + +test_src = """ +#include +#include +#include +#include +#include "kernel.h" + +static void write_buffer_to_csv(char *filename, int32_t *buffer, int size) { + FILE *file = fopen(filename, "w"); + if (file == NULL) { + printf(\"Could not open file %s\\n\", filename); + return; + } + for (int i = 0; i < size; i++) { + fprintf(file, "%d", buffer[i]); + if (i < size - 1) { + fprintf(file, ","); + } + } + fclose(file); +} + +static void read_csv_to_buffer(char *filename, int16_t *buffer, int size) { + FILE *file = fopen(filename, "r"); + if (file == NULL) { + printf(\"Could not open file %s\\n\", filename); + return; + } + int index = 0; + while (fscanf(file, "%hd,", &buffer[index]) != EOF && index < size) { + index++; + } + fclose(file); +} + +int main(int argc, char **argv) { + int M = 16, N = 16, K = 16; + int BM = 16, BN = 16, BK = 16; + + // initialize CUDA handles + CUdevice dev; + CUcontext ctx; + CUstream stream; + CUdeviceptr A, B, C; + CUresult err = 0; + cuInit(0); + cuDeviceGet(&dev, 0); + cuCtxCreate(&ctx, 0, dev); + cuMemAlloc(&A, M * K * 2); + cuMemAlloc(&B, K * N * 2); + cuMemAlloc(&C, M * N * 4); + cuStreamCreate(&stream, 0); + load_matmul_fp16xfp16_16x16x16(); + + // initialize input data + int16_t hA[M*K]; + int16_t hB[K*N]; + memset(hA, 0, M*K*2); + memset(hB, 0, K*N*2); + 
def test_compile_link_matmul():
    """End-to-end AOT check: compile the Triton kernel, link it, build the C driver and compare.

    Steps, all inside a temporary directory:
      1. write ``kernel_src`` / ``kernel_utils_src`` to disk;
      2. run triton's ``compile.py`` once per pointer-hint combination;
      3. run triton's ``link.py`` over the generated headers;
      4. build ``test_src`` with gcc against the CUDA driver library;
      5. run the binary on random fp16 inputs exchanged through CSV files;
      6. compare against ``(a @ b) ** 2`` computed by NumPy in fp32.
    """
    np.random.seed(3)

    with tempfile.TemporaryDirectory() as tmp_dir:
        kernel_path = os.path.join(tmp_dir, "kernel.py")
        with open(kernel_path, "w") as file:
            file.write(kernel_src)

        kernel_utils_path = os.path.join(tmp_dir, "kernel_utils.py")
        with open(kernel_utils_path, "w") as file:
            file.write(kernel_utils_src)

        compiler_path = os.path.join(triton.tools.__path__[0], "compile.py")
        linker_path = os.path.join(triton.tools.__path__[0], "link.py")

        dtype = "fp16"
        M, N, K = 16, 16, 16
        BM, BN, BK = 16, 16, 16

        # compile all desired configs
        hints = [":16", ""]
        for ha in hints:
            for hb in hints:
                sig = f'*fp32:16, *{dtype}:16, *{dtype}:16, i32{ha}, 1, i32{hb}, 1, i32:16, 1, {BM}, {BN}, {BK}'
                name = f"matmul_{dtype}x{dtype}_{BM}x{BN}x{BK}"
                subprocess.run([sys.executable, compiler_path, "-n", "kernel", "--signature", sig, "--out-name", name, "-o", name, "-w", "1", kernel_path], check=True, cwd=tmp_dir)

        # link all desired configs
        h_files = glob.glob(os.path.join(tmp_dir, "*.h"))
        subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=tmp_dir)

        # compile test case
        with open(os.path.join(tmp_dir, "test.c"), "w") as file:
            file.write(test_src)
        c_files = glob.glob(os.path.join(tmp_dir, "*.c"))
        subprocess.run(["gcc"] + c_files + ["-I", cuda_include_dir(),
                                            "-L", libcuda_dirs()[0],
                                            "-l", "cuda",
                                            "-o", "test"], check=True, cwd=tmp_dir)

        # initialize test data
        a = np.random.randn(M * K).astype(np.float16).reshape((M, K))
        # b is K x N, so draw K * N samples; the previous M * K only reshaped
        # correctly because every dimension happens to be 16.
        b = np.random.randn(K * N).astype(np.float16).reshape((K, N))
        a_path = os.path.join(tmp_dir, "a.csv")
        b_path = os.path.join(tmp_dir, "b.csv")
        c_path = os.path.join(tmp_dir, "c.csv")
        for x, path in [(a, a_path), (b, b_path)]:
            # The C driver reads int16 values, so ship the raw fp16 bit patterns.
            x.view(np.int16).ravel().tofile(path, sep=",")

        # run test case
        subprocess.run(["./test", a_path, b_path, c_path], check=True, cwd=tmp_dir)

        # read data and compare against reference
        c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32)
        # The driver dumps the fp32 result as int32 bit patterns; reinterpret.
        c_tri = c.reshape((M, N)).view(np.float32)
        c_ref = np.matmul(a.astype(np.float32), b.astype(np.float32))
        # kernel_src squares the dot product through kernel_utils.mul(c, c).
        np.testing.assert_allclose(c_tri, c_ref * c_ref, atol=1e-4, rtol=0.)
@pytest.mark.parametrize('format', FORMATS)
@pytest.mark.parametrize(
    'batch_size, M, N, K, stype',
    MATRIX_PARAMS
)
def test_bmm_bias(format, batch_size, M, N, K, stype):
    """Check the AOT-compiled batched matmul + bias kernel against the JIT kernel and torch.

    ``format`` is a layout string; a ``'c'`` in position 0 / 1 means operand
    A / B is handed to the kernels in transposed, contiguous form.
    """
    torch.manual_seed(0)
    if stype == 'float32':
        # Disable TF32 so the torch reference is computed in full fp32.
        torch.backends.cuda.matmul.allow_tf32=False
        dtype=torch.float32
    else:
        dtype=torch.float16

    a = torch.randn(batch_size, M, K, dtype=dtype, device='cuda')
    b = torch.randn(batch_size, K, N, dtype=dtype, device='cuda')
    bias = torch.randn(batch_size,M, N, dtype=dtype, device='cuda')
    c_triton_jit = torch.empty(batch_size, M, N, dtype=dtype, device='cuda')
    c_triton_aot = torch.empty(batch_size, M, N, dtype=dtype, device='cuda')

    # Reference is computed before a / b are re-laid-out below.
    c_torch= torch.bmm(a, b)+bias

    # Launch grid: (BATCH_SIZE, number of M tiles * number of N tiles).
    grid = lambda META: (
        META['BATCH_SIZE'],triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    is_trans_a=False
    is_trans_b=False
    if format[0]=='c':
        a=a.transpose(1,2).contiguous()
        is_trans_a=True
    if format[1]=='c':
        b=b.transpose(1,2).contiguous()
        is_trans_b=True
    test_kernel=gen_bmm_bias(format,batch_size,M,N,K,stype)
    test_kernel(a,b,bias,c_triton_aot)
    # Trailing positional args: three literal 64s and False -- presumably the
    # block sizes and the tf32 switch; TODO confirm against the kernel signature.
    bmm_bias_kernel[grid](a,b,bias,c_triton_jit,batch_size,M,N,K,is_trans_a,*a.stride(),is_trans_b,*b.stride(),*bias.stride(),*c_triton_jit.stride(),64,64,64,False)

    # Looser tolerances for fp16.
    atol = 1e-2 if dtype == torch.float16 else 1e-4
    rtol = 1e-2 if dtype == torch.float16 else 1e-4

    torch.testing.assert_close(c_triton_aot,c_triton_jit,atol=atol,rtol=rtol)
    torch.testing.assert_close(c_triton_aot,c_torch,atol=atol,rtol=rtol)
def gen_gemm_bias(format, M, N, K, stype):
    """Build a Gemm op with bias and relu activation and compile it for CUDA.

    Args:
        format: layout string; ``format[0]`` / ``format[1]`` equal to ``'r'``
            declares A as [M, K] / B as [K, N], anything else declares the
            transposed shape.
        M, N, K: problem sizes used for the declared tensor shapes.
        stype: dtype string forwarded to every ``Tensor``.

    Returns:
        Whatever ``compile_kernel`` returns for the op.
    """
    if format[0]=='r':
        A=Tensor(name='A',dtype=stype,shape=[M,K])
    else:
        A=Tensor(name='A',dtype=stype,shape=[K,M])
    if format[1]=='r':
        B=Tensor(name='B',dtype=stype,shape=[K,N])
    else:
        B=Tensor(name='B',dtype=stype,shape=[N,K])
    # NOTE(review): Bias is declared as [M, N] here, while test_gemm_bias_relu
    # passes a 1-D tensor of shape (N,) -- confirm which shape Gemm expects.
    Bias=Tensor(name='Bias',dtype=stype,shape=[M,N])
    # C is constructed but never passed on: the op is created with outputs=None.
    C=Tensor(name='C',dtype=stype,shape=[M,N])
    gemm_op=Gemm(
        inputs=[A,B,Bias],
        outputs=None,
        layout=format,
        is_bias=True,
        activation='relu',
    )
    kernel = compile_kernel(gemm_op,device='cuda')
    return kernel
A=Tensor(name='A',dtype=stype,shape=[K,M]) + if format[1]=='r': + B=Tensor(name='B',dtype=stype,shape=[K,N]) + else: + B=Tensor(name='B',dtype=stype,shape=[N,K]) + C=Tensor(name='C',dtype=stype,shape=[M,N]) + gemm_op=Gemm( + inputs=[A,B], + outputs=None, + layout=format, + is_bias=False, + activation='relu', + ) + kernel = compile_kernel(gemm_op,device='cuda') + return kernel + + + +FORMATS = ['rcr','rrr','ccr','crr'] +MATRIX_PARAMS = [ + (2, 128, 31, 'float32'), + (128, 2, 31, 'float16'), + (128, 128, 31, 'float32'), + (128, 31, 2, 'float16'), + (128, 128, 128, 'float32'), + (128, 257, 512, 'float16'), + (128, 512, 257, 'float32'), + (127, 256, 256, 'float16'), + (128, 511, 512, 'float32'), + (256, 128, 255, 'float16'), + (1, 256, 256, 'float32'), +] + +@pytest.mark.parametrize('format', FORMATS) +@pytest.mark.parametrize( + 'M, N, K, stype', + MATRIX_PARAMS +) +def test_gemm_bias_relu(format, M, N, K, stype): + + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + + + A = torch.randn((M, K), dtype=dtype, device='cuda') + B = torch.randn((K, N), dtype=dtype, device='cuda') + Bias = torch.randn((N,), dtype=dtype, device='cuda') + c_triton_jit = torch.empty((M, N), device=A.device, dtype=A.dtype) + c_triton_aot = torch.empty((M, N), device=A.device, dtype=A.dtype) + + pytorch_result = torch.nn.functional.relu(A @ B + Bias) + + is_trans_a=False + is_trans_b=False + + if format[0] == 'c': + is_trans_a = True + A = A.transpose(1,0).contiguous() + + if format[1] == 'c': + is_trans_b = True + B = B.transpose(1,0).contiguous() + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + + gemm_bias_kernel[grid]( + A, B, Bias, c_triton_jit, + M, N, K, + is_trans_a, *A.stride(), + is_trans_b, *B.stride(), + *c_triton_jit.stride(), + *Bias.stride(), + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, + ACTIVATION='relu', enable_tf32=False + ) + 
kernel = gen_gemm_bias(format, M, N, K, stype) + kernel(A, B, Bias, c_triton_aot) + + assert torch.allclose(c_triton_aot, c_triton_jit, atol=1e-2, rtol=1e-2) + assert torch.allclose(pytorch_result, c_triton_jit, atol=1e-2, rtol=1e-2) + +FORMATS = ['rcr','rrr','ccr','crr'] +MATRIX_PARAMS = [ + (2, 128, 31, 'float32'), + (128, 2, 31, 'float16'), + (128, 128, 31, 'float32'), + (128, 31, 2, 'float16'), + (128, 128, 128, 'float32'), + (128, 257, 512, 'float16'), + (128, 512, 257, 'float32'), + (127, 256, 256, 'float16'), + (128, 511, 512, 'float32'), + (256, 128, 255, 'float16'), + (1, 256, 256, 'float32'), +] + +@pytest.mark.parametrize('format', FORMATS) +@pytest.mark.parametrize( + 'M, N, K, stype', + MATRIX_PARAMS +) +def test_gemm_relu(format, M, N, K, stype): + + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + + A = torch.randn((M, K), dtype=dtype, device='cuda') + B = torch.randn((K, N), dtype=dtype, device='cuda') + c_triton_jit = torch.empty((M, N), device=A.device, dtype=A.dtype) + c_triton_aot = torch.empty((M, N), device=A.device, dtype=A.dtype) + + pytorch_result = torch.nn.functional.relu(A @ B) + + is_trans_a=False + is_trans_b=False + + if format[0] == 'c': + is_trans_a = True + A = A.transpose(1,0).contiguous() + + if format[1] == 'c': + is_trans_b = True + B = B.transpose(1,0).contiguous() + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + ) + + gemm_kernel[grid]( + A, B, c_triton_jit, + M, N, K, + is_trans_a, *A.stride(), + is_trans_b, *B.stride(), + *c_triton_jit.stride(), + BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, + ACTIVATION='relu', enable_tf32=False + ) + kernel = gen_gemm(format, M, N, K, stype) + kernel(A, B, c_triton_aot) + + atol = 1e-2 if dtype == torch.float16 else 1e-4 + rtol = 1e-2 if dtype == torch.float16 else 1e-4 + assert torch.allclose(c_triton_aot, c_triton_jit, atol=atol, rtol=rtol) + 
@pytest.mark.parametrize('batch,seq_len,hidden_size,stype',MATRIX_PARAMS)
@pytest.mark.parametrize('format',FORMATS)
def test_layernorm(batch,seq_len,hidden_size,stype,format):
    """Compare the JIT and AOT layernorm kernels with ``torch.nn.functional.layer_norm``.

    ``format`` selects the plain variant (``'layernorm'``) or, for any other
    value, the variant with weight and bias.
    """
    if stype == 'float32':
        torch.backends.cuda.matmul.allow_tf32 = False
        dtype = torch.float32
    else:
        dtype = torch.float16
    x = torch.randn(batch,seq_len,hidden_size,dtype=dtype,device='cuda')
    y_triton_jit = torch.empty(batch,seq_len,hidden_size,dtype=dtype,device='cuda')
    y_triton_aot = torch.empty(batch,seq_len,hidden_size,dtype=dtype,device='cuda')

    # The kernels are launched on the input flattened to M rows of N elements.
    M= batch*seq_len
    N = hidden_size

    # One program per BLOCK_SIZE_M rows.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']),
    )

    if format == 'layernorm':
        y_torch=torch.nn.functional.layer_norm(x,(N,),eps=1e-5)
        kernel_layernorm[grid](x,y_triton_jit,M,N,N,1,N,1,64,64,1e-5)
        # gen_layernorm returns the op (not a compiled kernel); compile it here.
        kernel = gen_layernorm(False,batch,seq_len,hidden_size,stype)
        kernel_aot = compile_kernel(kernel)
        kernel_aot(x,y_triton_aot)

    else:
        weight = torch.randn(hidden_size,dtype=dtype,device='cuda')
        bias = torch.randn(hidden_size,dtype=dtype,device='cuda')
        y_torch = torch.nn.functional.layer_norm(x,(N,),weight,bias,eps=1e-5)
        # NOTE(review): both launches below pass (bias, weight) in that order,
        # whereas gen_layernorm declares inputs=[X, W, B] -- confirm the
        # kernel's parameter order.
        kernel_layernorm_weight_bias[grid](x,bias,weight,y_triton_jit,M,N,N,1,N,1,1,1,64,64,1e-5)
        kernel = gen_layernorm(True,batch,seq_len,hidden_size,stype)
        kernel_aot = compile_kernel(kernel)
        kernel_aot(x,bias,weight,y_triton_aot)

    # Looser tolerances for fp16.
    atol = 1e-2 if dtype == torch.float16 else 1e-4
    rtol = 1e-2 if dtype == torch.float16 else 1e-4
    torch.testing.assert_close(y_triton_jit,y_torch,atol=atol,rtol=rtol)
    torch.testing.assert_close(y_triton_aot,y_torch,atol=atol,rtol=rtol)
@pytest.mark.parametrize('hidden_dim, num_heads, batch, seqlen', MATRIX_PARAMS)
@pytest.mark.parametrize('format', FORMATS)
def test_softmax(batch,num_heads,seqlen,hidden_dim, format):
    """Compare the JIT and AOT softmax kernels with ``torch.softmax`` over the last axis.

    ``format`` equal to ``'online_softmax'`` selects the online variant; any
    other value selects the plain softmax kernel. pytest passes the
    parametrized values by name, so the differing order between the
    parametrize string and this signature is harmless.
    """

    a = torch.randn(batch,num_heads, seqlen, hidden_dim, dtype=torch.float32, device='cuda')
    # The kernels are launched on the input flattened to M rows of N elements.
    M=batch*seqlen*num_heads
    N=hidden_dim
    b_triton_jit = torch.empty(batch,num_heads, seqlen, hidden_dim, dtype=torch.float32, device='cuda')
    b_triton_aot = torch.empty(batch,num_heads, seqlen, hidden_dim, dtype=torch.float32, device='cuda')
    # One program per BLOCK_SIZE_M rows.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']),
    )

    if format == 'online_softmax':
        kernel_online_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,128,128)
        kernel = gen_softmax(True,batch,num_heads,seqlen,hidden_dim)
        kernel(a,b_triton_aot)
    else:
        kernel_softmax[grid](a,b_triton_jit,M,N,N,1,N,1,64,64)
        kernel = gen_softmax(False,batch,num_heads,seqlen,hidden_dim)
        kernel(a,b_triton_aot)

    b_torch = torch.softmax(a, dim=-1).to(torch.float32)
    torch.testing.assert_close(b_triton_jit, b_triton_aot,atol=1e-2,rtol=1e-2)
    torch.testing.assert_close(b_triton_aot, b_torch,atol=1e-2,rtol=1e-2)


# Manual smoke run. Guarded so that importing the module (which pytest does
# during collection) no longer compiles and launches GPU kernels.
if __name__ == "__main__":
    test_softmax(16,256,256,128,'online_softmax')
a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py new file mode 100644 index 000000000..159867512 --- /dev/null +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py @@ -0,0 +1,99 @@ +import torch +import pytest +import triton + +from tritontemplate.backend.cuda.transpose import transpose_10 as kernel_transpose_10 +from tritontemplate.backend.cuda.transpose import transpose_0213 as kernel_transpose_0213 +from tritontemplate.backend.cuda.transpose import gen_grid_transpose_10, gen_grid_transpose_0213 +from tritontemplate.compiler.base import IntImm, Tensor +from tritontemplate.compiler.ops.transpose import Transpose +from tritontemplate.compiler.compiler import compile_kernel + +FORMATS = [ + 'transpose_10', + 'transpose_0213', +] + +MATRIX_PARAMS = [ + (256,128,'float16'), + (255,257,'float32'), + (127,129,'float16'), + (2304,768,'float16'), +] + +TENSOR4D_PARAMS = [ + (16, 64, 128, 32, 'float16'), + (4, 127, 255, 63, 'float32'), + (8, 256, 512, 64, 'float16'), + (8, 1024,8,96, 'float16'), +] + +def gen_transpose_10(M,N,stype): + X = Tensor([M, N], stype) + Y = Tensor([N, M], stype) + op = Transpose([X], '10', outputs=None) + return op + +@pytest.mark.parametrize( + 'M, N, stype', + MATRIX_PARAMS +) +def test_transpose10(M, N, stype): + if stype == 'float32': + torch.backends.cuda.matmul.allow_tf32 = False + dtype = torch.float32 + else: + dtype = torch.float16 + x = torch.randn(M, N, dtype=dtype, device='cuda') + y_triton_jit = torch.empty(N, M, dtype=dtype, device='cuda') + y_triton_aot = torch.empty(N, M, dtype=dtype, device='cuda') + BLOCK_M=64 + BLOCK_N=64 + y=x.transpose(0,1).contiguous() + grid = gen_grid_transpose_10(M, N, BLOCK_M,BLOCK_N) + kernel_transpose_10[grid](x, y_triton_jit, M, N, *x.stride(),*y_triton_jit.stride(), BLOCK_M, BLOCK_N) + + kernel = gen_transpose_10(M,N,stype) + kernel_aot = compile_kernel(kernel) + 
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,  # Pointers to matrices
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,  # Matrix dimensions
    stride_am: tl.constexpr, stride_ak: tl.constexpr,  # A matrix strides
    stride_bk: tl.constexpr, stride_bn: tl.constexpr,  # B matrix strides
    stride_cm: tl.constexpr, stride_cn: tl.constexpr,  # C matrix strides
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  # Tile sizes
):
    """Tiled matmul: each program writes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C = A @ B.

    Program id axis 0 selects the row tile and axis 1 the column tile. The
    product is accumulated in fp32 over K in steps of BLOCK_SIZE_K.
    """
    # Get program IDs
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Create offsets for the block
    rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    rk = tl.arange(0, BLOCK_SIZE_K)

    # Output mask does not depend on the K position.
    c_mask = (rm[:, None] < M) & (rn[None, :] < N)

    # Initialize accumulator
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Loop over K dimension
    for k in range(0, K, BLOCK_SIZE_K):
        # Advance the K offsets with the loop; previously every iteration
        # reloaded the first K tile, so the sum covered only columns
        # [0, BLOCK_SIZE_K) repeated K / BLOCK_SIZE_K times.
        ks = k + rk

        # Masks for this K tile, to avoid out-of-bounds accesses
        a_mask = (rm[:, None] < M) & (ks[None, :] < K)
        b_mask = (ks[:, None] < K) & (rn[None, :] < N)

        # Load tiles from A and B; masked-out lanes must contribute zero to
        # the dot product, hence other=0.0.
        a = tl.load(a_ptr + rm[:, None] * stride_am + ks[None, :] * stride_ak, mask=a_mask, other=0.0)
        b = tl.load(b_ptr + ks[:, None] * stride_bk + rn[None, :] * stride_bn, mask=b_mask, other=0.0)

        # Compute matrix multiplication
        acc += tl.dot(a, b)

    # Store result
    tl.store(c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn, acc, mask=c_mask)
def wrap_dim(idx, rank):
    """
    Wrap tensor index, idx, if it's negative.

    A negative ``idx`` is interpreted relative to the end, as in Python
    indexing, so the valid input range is ``-rank <= idx < rank``.

    Args:
        idx: axis index, must be an ``int``.
        rank: number of dimensions of the tensor.

    Returns:
        The equivalent index in ``[0, rank)``.

    Raises:
        AssertionError: ``idx`` is not an int or lies outside
            ``[-rank, rank)``.
    """
    assert isinstance(idx, int), "idx must be int, but got {}".format(type(idx))
    if idx < 0:
        idx = idx + rank
    # Check the lower bound too: an idx below -rank is still negative after
    # wrapping and used to be returned as if it were valid.
    assert 0 <= idx < rank, "idx {} out of range; rank {}".format(idx, rank)
    return idx
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Functions for working with torch Tensors. +AITemplate doesn't depend on PyTorch, but it exposes +many APIs that work with torch Tensors anyways. + +The functions in this file may assume that +`import torch` will work. +""" + +import struct + +import torch + +from tritontemplate.compiler.dtype import dtype_str_to_enum, get_dtype_size, normalize_dtype + + +def types_mapping(): + from torch import bfloat16, bool, float16, float32, int32, int64 + + yield (float16, "float16") + yield (bfloat16, "bfloat16") + yield (float32, "float32") + yield (int32, "int32") + yield (int64, "int64") + yield (bool, "bool") + + +def torch_dtype_to_string(dtype): + for (torch_dtype, ait_dtype) in types_mapping(): + if dtype == torch_dtype: + return ait_dtype + raise ValueError( + f"Got unsupported input dtype {dtype}! " + f"Supported dtypes are: {list(types_mapping())}" + ) + + +def string_to_torch_dtype(string_dtype): + if string_dtype is None: + # Many torch functions take optional dtypes, so + # handling None is useful here. + return None + + for (torch_dtype, ait_dtype) in types_mapping(): + if string_dtype == ait_dtype: + return torch_dtype + raise ValueError( + f"Got unsupported ait dtype {string_dtype}! 
" + f"Supported dtypes are: {list(types_mapping())}" + ) + + +def write_tensor_binary(tensor: "torch.Tensor", file_handle) -> None: + tensor = tensor.detach().cpu().contiguous() + endianness = "@" # system endianness + dtype_str = normalize_dtype(torch_dtype_to_string(tensor.dtype)) + dtype_int = dtype_str_to_enum(dtype_str) + sizeof_dtype = get_dtype_size(dtype_str) + num_dims = len(tensor.shape) + file_handle.write(struct.pack(endianness + "I", dtype_int)) # unsigned int + file_handle.write(struct.pack(endianness + "I", sizeof_dtype)) # unsigned int + file_handle.write(struct.pack(endianness + "I", num_dims)) # unsigned int + total_size = sizeof_dtype + for dim in tensor.shape: + file_handle.write(struct.pack(endianness + "N", dim)) # size_t + total_size *= dim + file_handle.write(struct.pack(endianness + "N", total_size)) # size_t + bytedata = tensor.numpy().tobytes() + # just as a safety check + if len(bytedata) != total_size: + raise RuntimeError("Tensor has wrong number of bytes!") + file_handle.write(bytedata) diff --git a/frontends/torch-frontend/examples/inference/tit_mlp.py b/frontends/torch-frontend/examples/inference/tit_mlp.py new file mode 100644 index 000000000..79467e543 --- /dev/null +++ b/frontends/torch-frontend/examples/inference/tit_mlp.py @@ -0,0 +1,56 @@ +import os + +import torch +from torch import nn +import torch_frontend +import byteir + +from brt_backend import BRTBackend + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(256, 512,dtype=torch.float32) + self.linear2 = nn.Linear(512, 256,dtype=torch.float32) + self.linear3 = nn.Linear(256, 128,dtype=torch.float32) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.functional.relu(x) + x = self.linear2(x) + x = torch.nn.functional.relu(x) + x = self.linear3(x) + return x + +workspace = "./workspace" +os.makedirs(workspace, exist_ok=True) +with torch.no_grad(): + torch.backends.cuda.matmul.allow_tf32=False + model = MLP().cuda().eval() 
+    inputs = [torch.randn(128, 256, dtype=torch.float32).cuda()]
+    traced_model = torch.jit.trace(model, inputs)
+
+    stablehlo_file = workspace + "/model.stablehlo.mlir"
+    byre_file = workspace + "/model.byre.mlir"
+    module = torch_frontend.compile(traced_model, inputs, "stablehlo")
+    with open(stablehlo_file, "w") as f:
+        f.write(module.operation.get_asm())
+
+    byteir.compile(stablehlo_file, byre_file, entry_func="forward", target="cuda_with_triton")
+
+    backend = BRTBackend("cuda", byre_file)
+    byteir_outputs = backend.execute(inputs)
+
+    torch_outputs = model(*inputs)
+    torch_jit_outputs = traced_model(*inputs)
+    if len(byteir_outputs) == 1:
+        byteir_outputs = byteir_outputs[0]
+
+    torch.testing.assert_close(torch_outputs, torch_jit_outputs, rtol=1e-3, atol=1e-3)
+    try:
+        torch.testing.assert_close(torch_outputs, byteir_outputs, rtol=1e-3, atol=1e-3)
+    except AssertionError as e:
+        diff=torch.abs(torch_outputs-byteir_outputs)
+        print("diff:",diff)
+        raise e
+    print("byteir tit backend success")
\ No newline at end of file
diff --git a/runtime/lib/backends/cuda/device/cuda_work_queue.cc b/runtime/lib/backends/cuda/device/cuda_work_queue.cc
index 8eecda95e..35ec971f1 100644
--- a/runtime/lib/backends/cuda/device/cuda_work_queue.cc
+++ b/runtime/lib/backends/cuda/device/cuda_work_queue.cc
@@ -16,7 +16,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "brt/backends/cuda/device/cuda_work_queue.h"
-
 #include "brt/backends/cuda/device/common/cuda_call.h"
 #include "brt/core/common/common.h"
 #include "brt/core/ir/ir.h"
@@ -73,11 +72,36 @@ inline common::Status ComputeDrv(const void *func, void **args,
   dim3 *grid = static_cast<dim3 *>(args[0]);
   dim3 *block = static_cast<dim3 *>(args[1]);
   size_t *shared_size = static_cast<size_t *>(args[2]);
+  CUfunction hFunc = reinterpret_cast<CUfunction>(const_cast<void *>(func));
+
+  // Launching with more than the default 48 KiB of dynamic shared memory
+  // requires an explicit opt-in on the function. The device is only queried
+  // in that case, so ordinary launches pay no extra runtime calls.
+  constexpr size_t kDefaultSharedLimit = 48 * 1024;
+  if (*shared_size > kDefaultSharedLimit) {
+    int shared_optin = 0;
+    int device_id = -1;
+    BRT_CUDA_CHECK(cudaGetDevice(&device_id));
+    BRT_CUDA_CHECK(cudaDeviceGetAttribute(
+        &shared_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
+
+    if (shared_optin > static_cast<int>(kDefaultSharedLimit)) {
+      // Static and dynamic shared memory together may not exceed the opt-in
+      // limit, so the statically allocated part is subtracted.
+      int shared_static = 0;
+      BRT_CU_CHECK(cuFuncGetAttribute(
+          &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, hFunc));
+      BRT_CU_CHECK(cuFuncSetCacheConfig(hFunc, CU_FUNC_CACHE_PREFER_SHARED));
+      BRT_CU_CHECK(cuFuncSetAttribute(
+          hFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+          shared_optin - shared_static));
+    }
+  }
+
   void **kernel_args = args + 3;
-  return BRT_CU_CALL(
-      cuLaunchKernel(reinterpret_cast<CUfunction>(const_cast<void *>(func)),
-                     (*grid).x, (*grid).y, (*grid).z, (*block).x, (*block).y,
-                     (*block).z, *shared_size, stream, kernel_args, 0));
+  return BRT_CU_CALL(cuLaunchKernel(hFunc, (*grid).x, (*grid).y, (*grid).z,
+                                    (*block).x, (*block).y, (*block).z,
+                                    *shared_size, stream, kernel_args, 0));
 }
 
 inline common::Status ComputeHost(const void *func, void **args,
diff --git a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc
index dfd6349f6..dad8689bf 100644
--- a/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc
+++ b/runtime/lib/backends/cuda/providers/default/codegen/ptx.cc
@@ -37,6 +37,7 @@ using namespace mlir;
 
 #define FILE_NAME_ATTR "device_file_name"
 #define KERNEL_NAME_ATTR "kernel_name"
+#define SHARED_SIZE_ATTR "SharedMemorySize"
 #define GRID_SIZE_X_ATTR "GridSize.x"
 #define GRID_SIZE_Y_ATTR "GridSize.y"
 #define GRID_SIZE_Z_ATTR "GridSize.z"
@@ -138,6 +139,7 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
     BRT_THROW_EX(std::runtime_error, "no BlockSize.x attr");
   }
 
+  size_t shared_size = 0;
   int gx = static_cast<int>(info.GetOperation()
                                 ->getAttrOfType<IntegerAttr>(GRID_SIZE_X_ATTR)
                                 .getInt()),
@@ -167,6 +169,12 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
                  ->getAttrOfType<IntegerAttr>(BLOCK_SIZE_Z_ATTR)
                  .getInt());
   }
+  if (info.GetOperation()->hasAttrOfType<IntegerAttr>(SHARED_SIZE_ATTR)) {
+    shared_size =
+        static_cast<size_t>(info.GetOperation()
+                                ->getAttrOfType<IntegerAttr>(SHARED_SIZE_ATTR)
+                                .getInt());
+  }
 
   std::vector<int> ranks;
   if (info.GetOperation()->hasAttrOfType<ArrayAttr>(ARG_RANKS_ATTR)) {
@@ -181,7 +189,7 @@ PTXOpKernel::PTXOpKernel(const OpKernelInfo &info)
   auto num_arg = GetOpArgNum(info_);
   impl_->grid = dim3(gx, gy, gz);
   impl_->block = dim3(bx, by, bz);
-  impl_->shared_size = 0;
+  impl_->shared_size = shared_size;
   impl_->arg_reserve_size = 3; // initial 3 for grid/block/shared_size
 
   // store tensor meta