Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5ff62eb
[feat] add triton builder
6somehow Jun 11, 2025
2430968
[feat] add tit attr
6somehow Jun 11, 2025
8ae5ae3
[feat] tit demo
6somehow Jun 12, 2025
584341d
[fix] kernel_name bug
6somehow Jun 12, 2025
e343abe
[fix] compile print bug
6somehow Jun 12, 2025
b8facfa
[fix] pm print
6somehow Jun 12, 2025
b489b68
[module] add tritontemplate
6somehow Jun 16, 2025
396edb1
[feat] ShareMemSize info
6somehow Jun 17, 2025
f9c6370
[feat] runtime share memory supported
6somehow Jun 18, 2025
5145026
[feat] tit e2e
6somehow Jun 18, 2025
997dc4c
[fix] tritontemplate example
6somehow Jun 18, 2025
a7c01f1
[fix] tiriton template mlp complement
6somehow Jun 18, 2025
2b49e12
[feat] tf32 option
6somehow Jun 20, 2025
0e7ae32
[feat] Add size under 32 support
6somehow Jun 24, 2025
1658f0e
[format] after clang-format
6somehow Jun 24, 2025
36b62e2
Merge branch 'main' into triton-backend
6somehow Jun 24, 2025
add8f1d
[fix] clear files
6somehow Jun 27, 2025
41108a3
Merge branch 'triton-backend' of github.com:6somehow/byteir into trit…
6somehow Jun 27, 2025
4abc796
[feat] gemm rrr add
6somehow Jul 7, 2025
266f02a
[feat] bmm add
6somehow Jul 10, 2025
b69d010
[fix] reform gemm
6somehow Jul 10, 2025
95901f8
[feat] supported softmax
6somehow Jul 17, 2025
8c1cb7d
[fix] softmax 4d test
6somehow Jul 17, 2025
9dbb131
[feat] supported layernorm
6somehow Jul 17, 2025
bbaa6c9
[feat] supported transpose
6somehow Jul 17, 2025
8f8bcf8
[bug] TritonTemplate align to byteir
6somehow Jul 18, 2025
e6ed72a
[bug] TritonTemplate align to byteir
6somehow Jul 18, 2025
1e1ac17
[bug] fixed e2e nanogpt bug
6somehow Jul 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/include/byteir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "byteir/Conversion/ToLLVM/ToLLVM.h"
#include "byteir/Conversion/ToLinalg/ToLinalg.h"
#include "byteir/Conversion/ToPTX/ToPTX.h"
#include "byteir/Conversion/ToTIT/ToTIT.h"

namespace mlir {

Expand Down
29 changes: 29 additions & 0 deletions compiler/include/byteir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,35 @@ def GenAITConfig : Pass<"gen-ait-config", "func::FuncOp"> {
];
}

//===----------------------------------------------------------------------===//
// ToTIT
//===----------------------------------------------------------------------===//

// Declares the `gen-tit-config` pass, which runs on `func::FuncOp`.
//
// Every option below is a list. The pass implementation reads entry `i` of
// each list together with entry `i` of `func-names`, so all lists are expected
// to have the same number of entries and to be ordered consistently.
// The size/grid/block values are passed as strings and parsed as integers by
// the pass.
def GenTITConfig : Pass<"gen-tit-config", "func::FuncOp"> {
let summary = "Generate TIT configuration";
let constructor = "mlir::createGenTITConfigPass()";
let options = [
ListOption<"funcNames", "func-names", "std::string",
"names of all cat func for TIT backends.">,
ListOption<"titPtxPaths", "tit-ptx-paths", "std::string",
"paths to all TIT-generated .ptx files">,
ListOption<"smemsizeArgs", "smemsize-args", "std::string",
"smemsize args for TIT backends.">,
ListOption<"gridsizeXArgs", "gridsize-x-args", "std::string",
"gridsize x args for TIT backends.">,
ListOption<"gridsizeYArgs", "gridsize-y-args", "std::string",
"gridsize y args for TIT backends.">,
ListOption<"gridsizeZArgs", "gridsize-z-args", "std::string",
"gridsize z args for TIT backends.">,
ListOption<"blocksizeXArgs", "blocksize-x-args", "std::string",
"blocksize x args for TIT backends.">,
ListOption<"blocksizeYArgs", "blocksize-y-args", "std::string",
"blocksize y args for TIT backends.">,
ListOption<"blocksizeZArgs", "blocksize-z-args", "std::string",
"blocksize z args for TIT backends.">,
];
}

//===----------------------------------------------------------------------===//
// ToByre
//===----------------------------------------------------------------------===//
Expand Down
43 changes: 43 additions & 0 deletions compiler/include/byteir/Conversion/ToTIT/ToTIT.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//===- ToTIT.h ------------------------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#ifndef BYTEIR_CONVERSION_TOTIT_H
#define BYTEIR_CONVERSION_TOTIT_H

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func
class ModuleOp;

/// Creates the `gen-tit-config` pass.
///
/// All parameters are parallel lists: entry `i` of every list describes the
/// function named `funcNames[i]`.
///
/// \param funcNames      names of the functions to annotate.
/// \param titPtxPaths    path of the generated .ptx file for each function.
/// \param smemsizeArgs   shared-memory size for each function, as a string.
/// \param gridsizeXArgs  grid size (x) for each function, as a string.
/// \param gridsizeYArgs  grid size (y) for each function, as a string.
/// \param gridsizeZArgs  grid size (z) for each function, as a string.
/// \param blocksizeXArgs block size (x) for each function, as a string.
/// \param blocksizeYArgs block size (y) for each function, as a string.
/// \param blocksizeZArgs block size (z) for each function, as a string.
///
/// Each list defaults to a single empty string.
std::unique_ptr<OperationPass<func::FuncOp>>
createGenTITConfigPass(ArrayRef<std::string> funcNames = {""},
ArrayRef<std::string> titPtxPaths = {""},
ArrayRef<std::string> smemsizeArgs = {""},
ArrayRef<std::string> gridsizeXArgs = {""},
ArrayRef<std::string> gridsizeYArgs = {""},
ArrayRef<std::string> gridsizeZArgs = {""},
ArrayRef<std::string> blocksizeXArgs = {""},
ArrayRef<std::string> blocksizeYArgs = {""},
ArrayRef<std::string> blocksizeZArgs = {""});

} // namespace mlir

#endif // BYTEIR_CONVERSION_TOTIT_H
1 change: 1 addition & 0 deletions compiler/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_subdirectory(HloToTensor)
add_subdirectory(MemrefToByre)
add_subdirectory(ToAce)
add_subdirectory(ToAIT)
add_subdirectory(ToTIT)
add_subdirectory(ToByre)
add_subdirectory(ToGPU)
add_subdirectory(ToHlo)
Expand Down
14 changes: 14 additions & 0 deletions compiler/lib/Conversion/ToTIT/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Conversion library ByteIRToTIT, built from GenTITConfig.cpp.
# NOTE(review): ByteIRConversionPassIncGen is presumably the target that
# generates the pass declarations from Passes.td - confirm.
add_byteir_conversion_library(ByteIRToTIT
GenTITConfig.cpp

ADDITIONAL_HEADER_DIRS
${BYTEIR_SRC_INCLUDE_DIR}/byteir/Conversion/ToTIT

DEPENDS
ByteIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRBufferizationTransforms
ByteIRUtils
)
133 changes: 133 additions & 0 deletions compiler/lib/Conversion/ToTIT/GenTITConfig.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//===- GenTITConfig.cpp ---------------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include "byteir/Conversion/ToTIT/ToTIT.h"
#include "byteir/Dialect/Byre/Common.h"
#include "byteir/Dialect/mhlo/Transforms/HloFuser.h"
#include "byteir/Utils/FuncUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/Debug.h"

#include "../PassDetail.h"

using namespace mlir;

namespace {

/// Attaches the byre attributes describing one TIT-generated kernel to `func`.
///
/// The device is derived from `titPtxPath`: a path containing ".ptx" selects
/// the "cuda" device and the "PTXOp" byre kernel name; any other path is
/// rejected. Every launch argument arrives as a string, is parsed as an
/// integer (radix auto-detected by StringRef::getAsInteger) and is stored as
/// an i32 attribute. All attribute names are prefixed with
/// byre::getByrePrefix().
///
/// \returns failure, after emitting an error on `func`, when the device cannot
/// be determined or when a launch argument is malformed or negative. `func` is
/// not modified in that case.
static LogicalResult AttachTITConfigToAttr(
    func::FuncOp func, const std::string &titPtxPath,
    const std::string &smemsizeArg, const std::string &gridsizeXArg,
    const std::string &gridsizeYArg, const std::string &gridsizeZArg,
    const std::string &blocksizeXArg, const std::string &blocksizeYArg,
    const std::string &blocksizeZArg) {

  std::string device_name;
  std::string byreKernelName;
  if (titPtxPath.find(".ptx") != std::string::npos) {
    device_name = "cuda";
    byreKernelName = "PTXOp";
  }

  if (device_name.empty() || byreKernelName.empty()) {
    return func.emitError("Invalid device type for TIT configuration");
  }

  mlir::OpBuilder opBuilder(func);
  llvm::StringMap<mlir::Attribute> titConfig;

  // Calling convention and device information.
  titConfig["call_convention"] = opBuilder.getStringAttr("bare_ptr");
  titConfig["device"] = opBuilder.getStringAttr(device_name);
  titConfig["device_file_name"] = opBuilder.getStringAttr(titPtxPath);

  llvm::StringMap<llvm::StringRef> gpuLaunchArgs = {
      {"SharedMemorySize", smemsizeArg}, {"BlockSize.x", blocksizeXArg},
      {"BlockSize.y", blocksizeYArg},    {"BlockSize.z", blocksizeZArg},
      {"GridSize.x", gridsizeXArg},      {"GridSize.y", gridsizeYArg},
      {"GridSize.z", gridsizeZArg}};

  for (auto &kv : gpuLaunchArgs) {
    int val;
    // getAsInteger returns true on failure.
    if (kv.second.getAsInteger(0, val)) {
      return func.emitError("Invalid integer format for ") << kv.first();
    }
    // Zero passes this check; only negative values are rejected.
    if (val < 0) {
      return func.emitError("Value must be non-negative for ") << kv.first();
    }
    titConfig[kv.first()] = opBuilder.getI32IntegerAttr(val);
  }

  // Touch `func` only after every argument has been validated, so a failed
  // run does not leave a partially annotated function behind.
  addGenericFuncAttrs(func, byreKernelName);
  for (auto &kv : titConfig) {
    func->setAttr(byre::getByrePrefix() + kv.first().str(), kv.second);
  }

  return success();
}

/// Pass that attaches TIT kernel configuration to cat-fusion functions.
///
/// The option lists are parallel: entry `i` of every list belongs to the
/// function named `funcNames[i]`. Functions without the cat-fusion attribute,
/// or whose symbol name is not listed in `funcNames`, are left untouched.
struct GenTITConfigPass : public GenTITConfigBase<GenTITConfigPass> {
  /// Copies every argument list into the corresponding pass option.
  GenTITConfigPass(ArrayRef<std::string> funcNames,
                   ArrayRef<std::string> titPtxPaths,
                   ArrayRef<std::string> smemsizeArgs,
                   ArrayRef<std::string> gridsizeXArgs,
                   ArrayRef<std::string> gridsizeYArgs,
                   ArrayRef<std::string> gridsizeZArgs,
                   ArrayRef<std::string> blocksizeXArgs,
                   ArrayRef<std::string> blocksizeYArgs,
                   ArrayRef<std::string> blocksizeZArgs)
      : GenTITConfigBase() {
    this->funcNames = funcNames;
    this->titPtxPaths = titPtxPaths;
    this->smemsizeArgs = smemsizeArgs;
    this->gridsizeXArgs = gridsizeXArgs;
    this->gridsizeYArgs = gridsizeYArgs;
    this->gridsizeZArgs = gridsizeZArgs;
    this->blocksizeXArgs = blocksizeXArgs;
    this->blocksizeYArgs = blocksizeYArgs;
    this->blocksizeZArgs = blocksizeZArgs;
  }

  /// Annotates the current function if it is a cat-fusion function whose name
  /// appears in `funcNames`; signals pass failure on invalid configuration.
  void runOnOperation() override {
    func::FuncOp func = getOperation();
    if (!func->hasAttr(getByteIRCatFusionAttrName()))
      return;

    // All lists are indexed with the same `i` below. Reject mismatched lengths
    // up front instead of reading past the end of a shorter list.
    const size_t numFuncs = funcNames.size();
    if (titPtxPaths.size() != numFuncs || smemsizeArgs.size() != numFuncs ||
        gridsizeXArgs.size() != numFuncs || gridsizeYArgs.size() != numFuncs ||
        gridsizeZArgs.size() != numFuncs ||
        blocksizeXArgs.size() != numFuncs ||
        blocksizeYArgs.size() != numFuncs ||
        blocksizeZArgs.size() != numFuncs) {
      func.emitError("gen-tit-config: every option list must have the same "
                     "number of entries as func-names");
      return signalPassFailure();
    }

    for (size_t i = 0; i < numFuncs; ++i)
      if (func.getSymName() == funcNames[i]) {
        if (failed(AttachTITConfigToAttr(
                func, titPtxPaths[i], smemsizeArgs[i], gridsizeXArgs[i],
                gridsizeYArgs[i], gridsizeZArgs[i], blocksizeXArgs[i],
                blocksizeYArgs[i], blocksizeZArgs[i]))) {
          return signalPassFailure();
        }
      }
  }
};

} // namespace

/// Factory for the gen-tit-config pass: forwards every option list, unchanged
/// and in declaration order, to the GenTITConfigPass constructor.
std::unique_ptr<OperationPass<func::FuncOp>> mlir::createGenTITConfigPass(
    ArrayRef<std::string> funcNames, ArrayRef<std::string> titPtxPaths,
    ArrayRef<std::string> smemsizeArgs, ArrayRef<std::string> gridsizeXArgs,
    ArrayRef<std::string> gridsizeYArgs, ArrayRef<std::string> gridsizeZArgs,
    ArrayRef<std::string> blocksizeXArgs, ArrayRef<std::string> blocksizeYArgs,
    ArrayRef<std::string> blocksizeZArgs) {
  std::unique_ptr<OperationPass<func::FuncOp>> configPass =
      std::make_unique<GenTITConfigPass>(
          funcNames, titPtxPaths, smemsizeArgs, gridsizeXArgs, gridsizeYArgs,
          gridsizeZArgs, blocksizeXArgs, blocksizeYArgs, blocksizeZArgs);
  return configPass;
}
115 changes: 115 additions & 0 deletions compiler/python/byteir/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,121 @@ def _compile_cuda_with_ait(
raise ValueError("module asm has be changed after byre serialization")


@register_byteir_compiler_backend(target="cuda_with_triton", device="cuda")
def _compile_cuda_with_triton(
    compile_options: CompileOptions,
) -> None:
    """Compile ``compile_options.module`` for CUDA using the Triton path for cat ops.

    The module is run through the HLO graph/fusion pipelines, handed to
    ``IRProcessor.triton_opt_pass`` for the outlined cat functions, and then
    lowered through the linalg / byre / bufferize / scf / gpu / byre pipelines.

    Outputs, all under ``compile_options.output_dir``:
      * the device PTX, written by ``byteir.translate_to_ptx`` using
        ``output_file_prefix``;
      * the host module as textual MLIR (always written);
      * the host module as serialized byre, only when ``output_type`` is
        ``OutputType.MLIRBC``.

    Raises:
        ValueError: if the module printed after a byre serialize/deserialize
            round trip differs from the in-memory module.
    """
    from .dialects.cat import IRProcessor

    target = "cuda"
    module = compile_options.module
    entry_func = compile_options.entry_func
    gpu_arch = compile_options.gpu_arch
    verbose = compile_options.verbose
    name = compile_options.name
    enable_tf32 = compile_options.enable_tf32
    parallelism = compile_options.parallelism
    disable_byteir_ait_cache = compile_options.disable_byteir_ait_cache

    output_file_dir = compile_options.output_dir
    output_file_prefix = compile_options.output_file_prefix
    output_type = compile_options.output_type
    useBarePtrCallConv = True  # all tensor must have static shapes if True

    context = module.context

    entry_func_str = "entry-func={}".format(entry_func)
    target_str = "target={}".format(target)

    with context:
        PassManager.parse("builtin.module(hlo-graph-opt{" + entry_func_str + " " + target_str + "})").run(module.operation)
        if verbose:
            _print_verbose(module, "// IR Dump After Hlo Graph Opt:")

    processor = IRProcessor(name,
                            "./workspace",
                            enable_tf32=enable_tf32,
                            compile_parallelism=parallelism,
                            disable_byteir_ait_cache=disable_byteir_ait_cache,
                            verbose=verbose)
    processor.module = module

    processor.preprocess_pass()
    processor.cat_opt_pass(anchor_only=False)

    with context:
        PassManager.parse("builtin.module(hlo-fusion-opt{outline-single-elemwise-op outline-cat-op})").run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After Hlo Fusion Opt (with Cat):")
        # Unlike the AIT backend, no AIT .so library is generated for the cat
        # functions; they are handed to the Triton pass instead.
        processor.triton_opt_pass(output_file_dir)
    module = processor.module

    with context:
        PassManager.parse("builtin.module(linalg-tensor-opt)").run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After Linalg Tensor Opt:")
    with context:
        if enable_tf32:
            PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types enable-tf32 {}}})".format(entry_func_str)).run(processor.module.operation)
        else:
            PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types {}}})".format(entry_func_str)).run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After Byre Tensor Opt:")
    with context:
        PassManager.parse("builtin.module(byteir-bufferize-opt)").run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After ByteIR Bufferize Opt:")
    with context:
        PassManager.parse("builtin.module(linalg-memref-opt)").run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After Linalg Memref Opt:")
    with context:
        PassManager.parse("builtin.module(scf-opt)").run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After SCF Opt:")
    with context:
        if useBarePtrCallConv:
            PassManager.parse("builtin.module(gpu-opt{use-bare-ptr-memref-call-conv=true device-file-name=" + output_file_prefix + ".ptx" + "})").run(module.operation)
        else:
            PassManager.parse("builtin.module(gpu-opt{device-file-name=" + output_file_prefix + ".ptx" + "})").run(module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After GPU Opt:")
    with context:
        PassManager.parse("builtin.module(inline)").run(processor.module.operation)
        PassManager.parse("builtin.module(func.func(lccl-to-byre))").run(module.operation)
        PassManager.parse("builtin.module(func.func(gpu-launch-func-to-byre))").run(processor.module.operation)
        PassManager.parse("builtin.module(func.func(set-op-space{" + entry_func_str + " space={}".format(target) + "}))").run(processor.module.operation)
        PassManager.parse("builtin.module(set-arg-space{" + entry_func_str + " all-space={}".format(target) + "})").run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After Set Space Opt:")
    with context:
        PassManager.parse("builtin.module(byre-opt{append-arg-types " + entry_func_str + "})").run(processor.module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After Byre Opt:")

    # create device module: a re-parsed copy, so device codegen does not
    # modify the host module
    module_str = processor.module.operation.get_asm(print_generic_op_form=True)
    device_module = ir.Module.parse(module_str, context)
    with context:
        if useBarePtrCallConv:
            PassManager.parse("builtin.module(nvvm-codegen{use-bare-ptr-memref-call-conv=true " + f" gpu-arch={gpu_arch}" + "})").run(device_module.operation)
        else:
            # No space after '=': a space would separate the option name from
            # its value in the textual pipeline.
            PassManager.parse("builtin.module(nvvm-codegen{" + f" gpu-arch={gpu_arch}" + "})").run(device_module.operation)
        if verbose:
            _print_verbose(device_module, "// IR Dump After NVVM Codegen:")
    # write to output device ptx
    byteir.translate_to_ptx(device_module, os.path.join(output_file_dir, output_file_prefix), gpu_arch)

    # create host module
    with context:
        PassManager.parse("builtin.module(byre-host)").run(processor.module.operation)
        PassManager.parse("builtin.module(remove-module-tag{attr-name=gpu.container_module})").run(module.operation)
        PassManager.parse("builtin.module(remove-module-tag{attr-name=torch.debug_module_name})").run(module.operation)
        if verbose:
            _print_verbose(processor.module, "// IR Dump After Byre Host:")

    output_host_mlir_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIR.value)
    output_host_mlirbc_path = os.path.join(output_file_dir, output_file_prefix + "." + OutputType.MLIRBC.value)
    # write to output host mlir file
    with open(output_host_mlir_path, "w") as f:
        f.write(module.operation.get_asm())
    if output_type is OutputType.MLIRBC:
        byteir.serialize_byre(module, compile_options.byre_serial_version, output_host_mlirbc_path)
        # Read the serialized file back through a context manager so the
        # handle is closed deterministically.
        with open(output_host_mlirbc_path, "rb") as f:
            serialized_bytes = f.read()
        deserialized_module = byteir.deserialize_byre(serialized_bytes, context)
        if module.operation.get_asm() != deserialized_module.operation.get_asm():
            raise ValueError("module asm has be changed after byre serialization")


@register_byteir_compiler_backend(target="cpu", device="cpu")
def _compile_cpu(
compile_options: CompileOptions,
Expand Down
Loading