Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/Conversion/BlockedGpuToTriton/BlockedGpuToTriton.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class ConvertGpuFuncToTritonFunc : public OpConversionPattern<mlir::gpu::GPUFunc
rewriter.create<func::ReturnOp>(gpuFunc.getLoc());
newFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
rewriter.getUnitAttr());

newFunc.setArgAttrsAttr(gpuFunc.getArgAttrsAttr());

return success();
}

Expand Down
34 changes: 34 additions & 0 deletions lib/Conversion/GpuToBlockedGpu/GpuToBlockedGpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,40 @@ class ConvertGpuToBlockedGpu: public CometGpuToBlockedGpuBase<ConvertGpuToBlocke
mlir::gpu::GPUFuncOp funcOp = getOperation();
mlir::OpBuilder builder(funcOp);

// Tag each memref-typed kernel argument with the unit attributes "gpu.read"
// and/or "gpu.write", depending on whether a memref.load / memref.store op
// appears among the users of the argument or among the users of any
// memref-typed value transitively produced from it.
for(auto kernelArg : funcOp.getArguments())
{
    if(!mlir::isa<mlir::MemRefType>(kernelArg.getType()))
    {
        continue;
    }

    const auto argIndex = kernelArg.getArgNumber();

    // Worklist of operations still to inspect. It grows while it is being
    // scanned, hence the index-based loop instead of iterators.
    std::vector<Operation*> worklist;
    worklist.insert(worklist.end(), kernelArg.getUsers().begin(), kernelArg.getUsers().end());
    for(size_t next = 0; next < worklist.size(); ++next)
    {
        Operation* current = worklist[next];
        if(mlir::isa<mlir::memref::LoadOp>(current))
        {
            funcOp.setArgAttr(argIndex, "gpu.read", builder.getUnitAttr());
        }
        else if(mlir::isa<mlir::memref::StoreOp>(current))
        {
            funcOp.setArgAttr(argIndex, "gpu.write", builder.getUnitAttr());
        }

        // Any memref-typed result may itself feed a load or a store, so its
        // users are queued for inspection as well.
        for(auto produced : current->getResults())
        {
            if(!isa<MemRefType>(produced.getType()))
            {
                continue;
            }
            worklist.insert(worklist.end(), produced.getUsers().begin(), produced.getUsers().end());
        }
    }
}

for(auto arg: funcOp.getArguments())
{
if(mlir::isa<mlir::MemRefType>(arg.getType()))
Expand Down
122 changes: 112 additions & 10 deletions lib/Conversion/PrepareGpuHost/PrepareGpuHost.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
#include "comet/Conversion/PrepareGpuHost/PrepareGpuHostPass.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -26,24 +31,28 @@ class PrepareGpuHost

std::map<std::string, Value> funcs;
std::map<std::string, std::string> gpu_to_triton_kernel;
std::map<std::string, func::FuncOp> gpu_name_to_funcOp;
std::map<std::string, triton::FuncOp> triton_name_to_triton_func_op;
auto gpuModules = modOp.getOps<gpu::GPUModuleOp>();
for(auto gpuModuleOp: gpuModules)
{
auto funcOps = gpuModuleOp.getOps<mlir::func::FuncOp>();
for(func::FuncOp funcOp: llvm::make_early_inc_range(funcOps))
{
std::map<size_t, std::vector<Attribute>> argsToSet;
if(!funcOp->hasAttr(gpu::GPUDialect::getKernelFuncAttrName()))
{
continue;
}
builder.setInsertionPoint(funcOp);
SmallVector<Type, 4> newTypes;
for(auto argType: funcOp.getArgumentTypes())
for(auto arg: funcOp.getArguments())
{
auto argType = arg.getType();
newTypes.push_back(argType);
if(MemRefType rankedType = dyn_cast<mlir::MemRefType>(argType))
{
argsToSet[newTypes.size() - 1] = {funcOp.getArgAttr(arg.getArgNumber(), "gpu.read"), funcOp.getArgAttr(arg.getArgNumber(), "gpu.write")};
if(rankedType.hasRank())
{
newTypes.push_back(builder.getIndexType());
Expand Down Expand Up @@ -71,6 +80,18 @@ class PrepareGpuHost
builder.create<func::ReturnOp>(funcOp.getLoc());
newFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
// Re-attach the access markers collected in argsToSet to the new function.
// Each entry maps an argument index of newFunc to a two-element vector:
// element 0 is the "gpu.read" attribute and element 1 the "gpu.write"
// attribute read from the original function (either may be null, in which
// case nothing is set). Entries are bound by const reference so the
// attribute vector is not copied on every iteration.
for(const auto& [argNumber, attrs]: argsToSet)
{
    if(attrs[0])
    {
        newFunc.setArgAttr(argNumber, "gpu.read", attrs[0]);
    }
    if(attrs[1])
    {
        newFunc.setArgAttr(argNumber, "gpu.write", attrs[1]);
    }
}
gpu_name_to_funcOp[funcOp.getName().str()] = newFunc;
funcOp->erase();
}
}
Expand Down Expand Up @@ -185,14 +206,24 @@ class PrepareGpuHost
if (mlir::gpu::LaunchFuncOp launchOp =
dyn_cast<mlir::gpu::LaunchFuncOp>(use.getOwner())) {
builder.setInsertionPoint(launchOp);
builder.create<mlir::gpu::MemcpyOp>(launchOp->getLoc(), TypeRange(),
ValueRange(),
gpuAlloc.getMemref(), alloc);
int offset = launchOp->getNumOperands() - launchOp.getNumKernelOperands();
int operNum = use.getOperandNumber() - offset;
if(gpu_name_to_funcOp[launchOp.getKernelName().str()].getArgAttr(operNum, "gpu.read"))
{
auto gpuMemCpy = builder.create<mlir::gpu::MemcpyOp>(launchOp->getLoc(), TypeRange(),
ValueRange(),
gpuAlloc.getMemref(), alloc);
gpuMemCpy->setAttr("gpu.read", builder.getUnitAttr());
}
use.set(gpuAlloc.getMemref());
builder.setInsertionPointAfter(launchOp);
builder.create<mlir::gpu::MemcpyOp>(launchOp->getLoc(), TypeRange(),
ValueRange(), alloc,
gpuAlloc.getMemref());
if(gpu_name_to_funcOp[launchOp.getKernelName().str()].getArgAttr(operNum, "gpu.write"))
{
auto gpuMemCpy = builder.create<mlir::gpu::MemcpyOp>(launchOp->getLoc(), TypeRange(),
ValueRange(), alloc,
gpuAlloc.getMemref());
gpuMemCpy->setAttr("gpu.write", builder.getUnitAttr());
}
}
}
}
Expand All @@ -214,11 +245,82 @@ class PrepareGpuHost
gpuAllocs.push_back(gpuAllocOp);
});

std::vector<mlir::gpu::MemcpyOp> gpuCopies;
modOp->walk([&gpuCopies](mlir::gpu::MemcpyOp gpuCopy) {
gpuCopies.push_back(gpuCopy);

// For every value contained in uniqueGpuAllocs, collect the operations that
// take it as an operand, in the order modOp->walk visits them. The map key is
// the value's opaque pointer. An operation that uses the same tracked value
// through several operands is recorded once per such operand.
std::map<void*, std::vector<Operation*>> memEffects;
modOp->walk([&uniqueGpuAllocs, &memEffects](Operation* visited) {
    for(auto operand : visited->getOperands())
    {
        if(uniqueGpuAllocs.find(operand) == uniqueGpuAllocs.end())
        {
            continue;
        }
        memEffects[operand.getAsOpaquePointer()].push_back(visited);
    }
});

// Prune gpu.memcpy operations that the recorded use sequence of each tracked
// value shows to be unnecessary.
//
// For every tracked value the recorded users are scanned in order:
//  * a memcpy tagged "gpu.read" is erased when no intervening user has
//    re-armed copyIn since the previous tagged memcpy; otherwise it is kept
//    and copyIn is cleared;
//  * a memcpy tagged "gpu.write" clears copyIn and becomes a candidate for
//    deletion;
//  * memref.store and any unrecognised user re-arm copyIn and retain the most
//    recent "gpu.write" candidate;
//  * memref.copy re-arms copyIn only when the tracked value is its target;
//  * memref.load only retains the most recent "gpu.write" candidate;
//  * extract_strided_metadata / dim / rank are ignored.
// Candidates still pending once the scan finishes are erased.
//
// NOTE(review): a "gpu.read" memcpy is erased immediately. The walk that
// fills memEffects records an operation once per matching operand, so if such
// a memcpy uses more than one tracked value (or the same one twice) a stale
// pointer to it would remain in a list that is scanned later - confirm this
// cannot happen for the memcpy ops created by this pass.
for(auto& [memref, effects]: memEffects)
{
    // True while the next "gpu.read" memcpy has to be kept.
    bool copyIn = true;
    // "gpu.write" memcpy ops that no later user has required so far.
    std::vector<Operation*> copyDelete;

    // Takes the most recently recorded "gpu.write" memcpy, if any, off the
    // deletion list.
    auto keepLastCopyOut = [&copyDelete]() {
        if(!copyDelete.empty())
        {
            copyDelete.pop_back();
        }
    };

    for(Operation* effect: effects)
    {
        if(mlir::gpu::MemcpyOp gpuCopy = mlir::dyn_cast<mlir::gpu::MemcpyOp>(effect))
        {
            if(gpuCopy->hasAttr("gpu.read"))
            {
                if(!copyIn)
                {
                    gpuCopy->erase();
                }
                else
                {
                    copyIn = false;
                }
            }
            else if(gpuCopy->hasAttr("gpu.write"))
            {
                copyIn = false;
                copyDelete.push_back(gpuCopy);
            }
            // A memcpy carrying neither tag is left untouched and does not
            // change the state.
        }
        else if(mlir::isa<mlir::memref::StoreOp>(effect))
        {
            copyIn = true;
            keepLastCopyOut();
        }
        else if(mlir::memref::CopyOp copy = mlir::dyn_cast<mlir::memref::CopyOp>(effect))
        {
            if(copy.getTarget().getAsOpaquePointer() == memref)
            {
                copyIn = true;
            }
        }
        else if(mlir::isa<mlir::memref::LoadOp>(effect))
        {
            keepLastCopyOut();
        }
        else if(isa<memref::ExtractStridedMetadataOp, memref::DimOp, memref::RankOp>(effect))
        {
            continue;
        }
        else // Unknown operation, be conservative
        {
            copyIn = true;
            keepLastCopyOut();
        }
    }

    for(auto toDelete: copyDelete)
    {
        toDelete->erase();
    }
}

}
};

Expand Down
Loading