diff --git a/CMakeLists.txt b/CMakeLists.txt index 48290084..d02e2e3a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -107,6 +107,8 @@ add_custom_target(comet-headers) set_target_properties(comet-headers PROPERTIES FOLDER "Misc") add_custom_target(comet-doc) +set(CMAKE_INCLUDE_CURRENT_DIR ON) + # Add MLIR, LLVM and BLIS headers to the include path include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) @@ -182,3 +184,9 @@ if (STANDALONE_INSTALL) message(STATUS "Setting an $ORIGIN-based RPATH on all executables") set_rpath_all_targets(${CMAKE_CURRENT_SOURCE_DIR}) endif() + +option(DEBUG_MODE "Create a installation with debug information" off) +if (DEBUG_MODE) + message(STATUS "Building comet in debug mode") + add_compile_options(-DCOMET_DEBUG_MODE) +endif() diff --git a/frontends/comet_dsl/CMakeLists.txt b/frontends/comet_dsl/CMakeLists.txt index 17a5af17..07888ed8 100644 --- a/frontends/comet_dsl/CMakeLists.txt +++ b/frontends/comet_dsl/CMakeLists.txt @@ -23,7 +23,7 @@ set(LIBS COMETUtils COMETTensorAlgebraDialect COMETIndexTreeDialect - COMETIndexTreeToSCF + # COMETIndexTreeToSCF ) target_link_libraries(comet-opt diff --git a/frontends/comet_dsl/comet.cpp b/frontends/comet_dsl/comet.cpp index 946db8c5..006d3113 100644 --- a/frontends/comet_dsl/comet.cpp +++ b/frontends/comet_dsl/comet.cpp @@ -41,6 +41,9 @@ #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" + #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" @@ -245,7 +248,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, pm.addPass(mlir::comet::createFuncOpLoweringPass()); mlir::OpPassManager &optPM = pm.nest(); - optPM.addPass(mlir::comet::createRemoveLabeledTensorOpsPass()); + // 
optPM.addPass(mlir::comet::createRemoveLabeledTensorOpsPass()); /// Check to see if we are dumping to TA dialect. if (emitTA) @@ -287,17 +290,15 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, /// Generate the index tree IR optPM.addPass(mlir::comet::createLowerTensorAlgebraToIndexTreePass()); - if (OptKernelFusion) - { - /// Apply partial fusion on index tree dialect for some compound expressions. - optPM.addPass(mlir::comet::createIndexTreeKernelFusionPass()); - } + // Create new pass manager to optimize the index tree dialect + // mlir::OpPassManager &itOptPM = optPM.nest(); + optPM.addPass(mlir::comet::createIndexTreeDomainInferencePass()); - if (OptWorkspace) - { - /// Optimized workspace transformations, reduce iteration space for nonzero elements - optPM.addPass(mlir::comet::createIndexTreeWorkspaceTransformationsPass()); - } + // if (OptKernelFusion) + // { + // /// Apply partial fusion on index tree dialect for some compound expressions. + // optPM.addPass(mlir::comet::createIndexTreeKernelFusionPass()); + // } /// Dump index tree dialect. 
if (emitIT) @@ -319,8 +320,9 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, /// sparse input tensor declaration lowering, also generate sparse_output_tensor declaration if needed /// input and output sparse tensor declaration lowering are distant and need different information optPM.addPass(mlir::comet::createSparseTensorDeclLoweringPass()); - // optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass()); optPM.addPass(mlir::comet::createDenseTensorDeclLoweringPass()); + optPM.addPass(mlir::comet::createSparseTempOutputTensorDeclLoweringPass()); + optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass()); optPM.addPass(mlir::comet::createTensorFillLoweringPass()); /// ============================================================================= @@ -332,34 +334,24 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, optPM.addPass(mlir::comet::createLoweringTTGTPass(IsSelectBestPermTTGT, selectedPermNum, IsPrintFlops)); } - /// ============================================================================= - /// Operation based optimizations - /// ============================================================================= - if (OptMatmulTiling) - { - optPM.addPass(mlir::comet::createLinAlgMatmulTilingPass()); - } + // /// ============================================================================= + // /// Operation based optimizations + // /// ============================================================================= + // if (OptMatmulTiling) + // { + // optPM.addPass(mlir::comet::createLinAlgMatmulTilingPass()); + // } - if (OptCallToMatMulMicroKernel) - { - optPM.addPass(mlir::comet::createLinAlgMatmulMicroKernelPass()); - } + // if (OptCallToMatMulMicroKernel) + // { + // optPM.addPass(mlir::comet::createLinAlgMatmulMicroKernelPass()); + // } /// ============================================================================= /// Lowering all the operations to loops /// 
============================================================================= if (IsLoweringtoSCF || emitLoops || emitLLVM) - { - /// Workspace transformations will create new dense tensor declarations, so we need to call createDenseTensorDeclLoweringPass - optPM.addPass(mlir::comet::createDenseTensorDeclLoweringPass()); /// lowers dense input/output tensor declaration - optPM.addPass(mlir::comet::createSparseTempOutputTensorDeclLoweringPass()); /// Temporary sparse output tensor declarations introduced by compound expressions - /// should be lowered before sparse output tensor declarations - optPM.addPass(mlir::comet::createSparseOutputTensorDeclLoweringPass()); /// lowering for sparse output tensor declarations - //(sparse_output_tensor_decl and temp_sparse_output_tensor_decl) - /// The partial Fusion pass might add new tensor.fill operations - optPM.addPass(mlir::comet::createTensorFillLoweringPass()); - optPM.addPass(mlir::comet::createPCToLoopsLoweringPass()); - + { /// ============================================================================= /// Lowering of other operations such as transpose, sum, etc. 
to SCF dialect /// ============================================================================= @@ -367,20 +359,32 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, /// If it is a transpose of sparse tensor, it lowers the code to make a runtime call to specific sorting algorithm optPM.addPass(mlir::comet::createLowerTensorAlgebraToSCFPass()); - /// Finally lowering index tree to SCF dialect - optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass()); - optPM.addPass(mlir::createTensorBufferizePass()); - pm.addPass(mlir::func::createFuncBufferizePass()); /// Needed for func + /// Concretize the domains of all the index variables + optPM.addPass(mlir::comet::createIndexTreeDomainConcretizationPass()); - if (OptDenseTransposeOp) /// Optimize Dense Transpose operation - { - /// If it is a dense transpose ops, the rewrites rules replaces ta.transpose with linalg.transpose, then - /// Create a pass to optimize LinAlg Copy Op - follow in HPTT paper - /// HPTT: A High-Performance Tensor Transposition C++ Library - /// https://arxiv.org/abs/1704.04374 - optPM.addPass(mlir::comet::createOptDenseTransposePass()); + if (OptWorkspace) { + /// Optimized workspace transformations, reduce iteration space for nonzero elements + optPM.addPass(mlir::comet::createIndexTreeWorkspaceTransformationsPass()); } + optPM.addPass(mlir::comet::createIndexTreeSymbolicComputePass()); + + /// Finally lowering index tree to SCF dialect + optPM.addPass(mlir::comet::createLowerIndexTreeToSCFPass()); + optPM.addPass(mlir::comet::createConvertSymbolicDomainsPass()); + optPM.addPass(mlir::comet::createSparseTensorConversionPass()); + optPM.addPass(mlir::comet::createIndexTreeInliningPass()); + optPM.addPass(mlir::createCanonicalizerPass()); + + // if (OptDenseTransposeOp) /// Optimize Dense Transpose operation + // { + // /// If it is a dense transpose ops, the rewrites rules replaces ta.transpose with linalg.transpose, then + // /// Create a pass to optimize LinAlg Copy Op - follow in HPTT 
paper + // /// HPTT: A High-Performance Tensor Transposition C++ Library + // /// https://arxiv.org/abs/1704.04374 + // optPM.addPass(mlir::comet::createOptDenseTransposePass()); + // } + /// Dump index tree dialect. if (emitLoops) { @@ -388,19 +392,25 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, return 4; return 0; } - /// ============================================================================= } + /// ============================================================================= - /// ============================================================================= - /// Late lowering passes - /// ============================================================================= + // /// ============================================================================= + // /// Late lowering passes + // /// ============================================================================= + // pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass()); + mlir::bufferization::OneShotBufferizationOptions opts; + opts.allowUnknownOps = true; + pm.addPass(mlir::bufferization::createOneShotBufferizePass(opts)); - optPM.addPass(mlir::comet::createSTCRemoveDeadOpsPass()); - optPM.addPass(mlir::comet::createLateLoweringPass()); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createCSEPass()); + mlir::OpPassManager &late_lowering_pm = pm.nest(); + late_lowering_pm.addPass(mlir::comet::createSTCRemoveDeadOpsPass()); + late_lowering_pm.addPass(mlir::comet::createLateLoweringPass()); + + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); - /// ============================================================================= + // /// ============================================================================= if (isLoweringToLLVM || emitLLVM) { @@ -481,6 +491,7 @@ int main(int argc, char **argv) context.loadDialect(); context.loadDialect(); context.loadDialect(); + context.loadDialect(); mlir::OwningOpRef module; diff 
--git a/frontends/comet_dsl/include/Lexer.h b/frontends/comet_dsl/include/Lexer.h index 0dbb8366..658036e8 100644 --- a/frontends/comet_dsl/include/Lexer.h +++ b/frontends/comet_dsl/include/Lexer.h @@ -35,7 +35,7 @@ #include // *********** For debug purpose *********// -//#define COMET_DEBUG_MODE +// #define COMET_DEBUG_MODE #include "comet/Utils/debug.h" #undef COMET_DEBUG_MODE // *********** For debug purpose *********// diff --git a/frontends/comet_dsl/mlir/MLIRGen.cpp b/frontends/comet_dsl/mlir/MLIRGen.cpp index 174e5a19..e989b340 100644 --- a/frontends/comet_dsl/mlir/MLIRGen.cpp +++ b/frontends/comet_dsl/mlir/MLIRGen.cpp @@ -71,7 +71,7 @@ using llvm::Twine; using StringSet = std::set; // *********** For debug purpose *********// -//#define COMET_DEBUG_MODE +// #define COMET_DEBUG_MODE #include "comet/Utils/debug.h" #undef COMET_DEBUG_MODE // *********** For debug purpose *********// @@ -591,23 +591,41 @@ namespace comet_debug() << "\n"; auto lhs_tensor = lhs.getDefiningOp()->getOpResult(0).getType(); - assert(lhs_tensor.isa()); comet_pdump(lhs.getDefiningOp()); + auto lhs_labeledtensor = lhs.getDefiningOp()->getOpResult(0); comet_vdump(lhs_labeledtensor); // ta.labeled_tensor - auto lhs_el_type = lhs_tensor.cast().getElementType(); + mlir::Type lhs_el_type; + if(auto tensor_type = llvm::dyn_cast(lhs_tensor)){ + lhs_el_type = tensor_type.getElementType(); + } + else if(auto tensor_type = llvm::dyn_cast(lhs_tensor)){ + lhs_el_type = tensor_type.getElementType(); + } + else { + assert(false && "Expected a tensor input"); + } auto rhs_tensor = rhs.getDefiningOp()->getOpResult(0).getType(); comet_pdump(rhs.getDefiningOp()); - assert(rhs_tensor.isa()); auto rhs_labeledtensor = rhs.getDefiningOp()->getOpResult(0); comet_vdump(rhs_labeledtensor); - auto rhs_el_type = rhs_tensor.cast().getElementType(); + mlir::Type rhs_el_type; + if(auto tensor_type = llvm::dyn_cast(rhs_tensor)){ + rhs_el_type = tensor_type.getElementType(); + } + else if(auto tensor_type = 
llvm::dyn_cast(rhs_tensor)){ + rhs_el_type = tensor_type.getElementType(); + } + else { + assert(false && "Expected a tensor input"); + } + auto result_type = getBinOpResultType(lhs_el_type, rhs_el_type); comet_debug() << __LINE__ << " "; comet_vdump(result_type); @@ -817,8 +835,6 @@ namespace } std::vector result_dims = getDimSizes(ret_lbls_value); - auto ret_tensor_type = mlir::RankedTensorType::get(result_dims, result_type); - auto affineMapArrayAttr = builder.getAffineMapArrayAttr(affine_maps); SmallVector formats; @@ -1000,18 +1016,29 @@ namespace } comet_debug() << __LINE__ << " formats.size(): " << formats.size() << "\n"; assert(formats.size() == 2 && " less than 2 input tensors\n"); + mlir::Type ret_tensor_type; if (formats[0].compare("CSR") == 0 && formats[1].compare("CSR") == 0) { formats.push_back("CSR"); + std::vector format_array = getFormats("CSR", result_dims.size(), builder.getContext()); + ret_tensor_type = SparseTensorType::get(builder.getContext(), result_type, result_dims, format_array); } else if (formats[0].compare("Dense") == 0 && formats[1].compare("Dense") == 0) { formats.push_back("Dense"); + ret_tensor_type = mlir::RankedTensorType::get(result_dims, result_type); } else if (out_format.length() > 0) // non-empty format string provided. { comet_debug() << " Output Format: " << out_format << "\n"; formats.push_back(out_format); + if(out_format.compare("Dense") == 0) + { + ret_tensor_type = mlir::RankedTensorType::get(result_dims, result_type); + } else { + std::vector format_array = getFormats(out_format, result_dims.size(), builder.getContext()); + ret_tensor_type = SparseTensorType::get(builder.getContext(), result_type, result_dims, format_array); + } } else { @@ -1604,9 +1631,24 @@ namespace if (isDense(formats_str, ", ") == false) { /// BoolAttr is false because there is explicit sparse densor declaration. 
- /// SparseTensorDeclOp is not for temporaries in compound expressions + /// SparseTensorDeclOp is not for temporaries in compound expression + std::vector format = mlir::tensorAlgebra::getFormats(tensor_format, dims_sizes.size(), builder.getContext()); + mlir::Type element_type; + switch (vartype.elt_ty) + { + case VarType::TY_FLOAT: + element_type = builder.getF32Type(); + break; + case VarType::TY_DOUBLE: + element_type = builder.getF64Type(); + break; + case VarType::TY_INT: + element_type = builder.getIntegerType(64); + break; + } + auto sp_tensor_type = SparseTensorType::get(builder.getContext(), element_type, dims_sizes, format); value = builder.create(loc(tensordecl.loc()), - tensor_type, labels, tensor_format, false); + sp_tensor_type, labels, tensor_format, false); comet_debug() << "MLIRGen SparseTensorDeclaration creation\n"; comet_vdump(value); } @@ -1864,6 +1906,10 @@ namespace mlir::StringRef format_strref = dyn_cast(rhs_tensor.getDefiningOp()).getFormat(); mlir::StringAttr formatAttr = builder.getStringAttr(format_strref); + std::vector format = mlir::tensorAlgebra::getFormats(format_strref, result_dims.size(), builder.getContext()); + mlir::Type element_type = builder.getF64Type(); + return_type = SparseTensorType::get(builder.getContext(), element_type, result_dims, format); + /// no lhs_LabeledTensor has been created. 
The output tensor of tranpose doesn't have explicit declaration, /// BoolAttr is true to speficy SparseTensorDeclOp is for temporaries lhs_tensor = builder.create(loc(transpose.loc()), return_type, lhs_labels_val, formatAttr, builder.getBoolAttr(true)); diff --git a/include/comet/Conversion/IndexTreeToSCF/IndexTreeToSCF.h b/include/comet/Conversion/IndexTreeToSCF/IndexTreeToSCF.h index 2a8f44ba..77431b64 100644 --- a/include/comet/Conversion/IndexTreeToSCF/IndexTreeToSCF.h +++ b/include/comet/Conversion/IndexTreeToSCF/IndexTreeToSCF.h @@ -34,16 +34,13 @@ namespace mlir namespace comet { #define GEN_PASS_DECL_CONVERTINDEXTREETOSCF +#define GEN_PASS_DECL_CONVERTSYMBOLICDOMAINS #include "comet/Conversion/Passes.h.inc" - - /// Collect a set of patterns to convert IndexTree operations to SCF - /// operations within the SCF dialect. - void populateIndexTreeToSCFConversionPatterns(RewritePatternSet &patterns); - /// Lowers indexTree operations (e.g., IndexTreeComputeLHSOp, IndexTreeComputeRHSOp and IndexTreeComputeOp) /// to equivalent scf constructs including basic blocks and arithmetic /// primitives). 
std::unique_ptr createLowerIndexTreeToSCFPass(); + std::unique_ptr createConvertSymbolicDomainsPass(); } } // namespace mlir diff --git a/include/comet/Conversion/Passes.td b/include/comet/Conversion/Passes.td index 9e22c582..6fb15d69 100644 --- a/include/comet/Conversion/Passes.td +++ b/include/comet/Conversion/Passes.td @@ -43,6 +43,19 @@ def ConvertIndexTreeToSCF : Pass<"convert-it-to-scf"> { ]; } +def ConvertSymbolicDomains : Pass<"convert-symbolic-domains"> { + let summary = " " + ""; + let description = [{ + + }]; + let constructor = "comet::createConvertSymbolicDomainsPass()"; + let dependentDialects = [ + "memref::MemRefDialect", + "scf::SCFDialect" + ]; +} + //===----------------------------------------------------------------------===// // TensorAlgebraToIndexTree //===----------------------------------------------------------------------===// @@ -73,7 +86,22 @@ def ConvertTensorAlgebraToSCF : Pass<"convert-ta-to-scf"> { let constructor = "comet::createLowerTensorAlgebraToSCFPass()"; let dependentDialects = [ "memref::MemRefDialect", - "scf::SCFDialect" + "scf::SCFDialect", + "index::IndexDialect" + ]; +} + +def SparseTensorConversionPass : Pass<"convert-sparse-tensor"> { + let summary = "Lowers operations on Tensor Algebra sparse tensors to mlir tensor operations"; + let description = [{}]; + + let constructor = "comet::createSparseTensorConversionPass()"; + + let dependentDialects = [ + "memref::MemRefDialect", + "scf::SCFDialect", + "index::IndexDialect", + "tensor::TensorDialect" ]; } diff --git a/include/comet/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.h b/include/comet/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.h index af46bb1d..700a7633 100644 --- a/include/comet/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.h +++ b/include/comet/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.h @@ -25,16 +25,18 @@ #define COMET_CONVERSION_TENSORALGEBRATOSCF_H #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" namespace 
mlir { class Pass; - class RewritePatternSet; - + namespace comet { #define GEN_PASS_DECL_CONVERTTENSORALGEBRATOSCF #include "comet/Conversion/Passes.h.inc" + void populateSparseTensorConversionPatterns(MLIRContext *context, RewritePatternSet &patterns, TypeConverter &typeConverter); + std::unique_ptr createSparseTensorConversionPass(); /// Collect a set of patterns to convert remaining TensorAlgebra operations /// that are not converted to IndexTree operations to the operations with SCF diff --git a/include/comet/Dialect/IndexTree/IR/CMakeLists.txt b/include/comet/Dialect/IndexTree/IR/CMakeLists.txt index 9fb7a6d7..9095274e 100644 --- a/include/comet/Dialect/IndexTree/IR/CMakeLists.txt +++ b/include/comet/Dialect/IndexTree/IR/CMakeLists.txt @@ -1,8 +1,13 @@ set(LLVM_TARGET_DEFINITIONS IndexTreeOps.td) mlir_tablegen(IndexTreeOps.h.inc -gen-op-decls) mlir_tablegen(IndexTreeOps.cpp.inc -gen-op-defs) -mlir_tablegen(IndexTreeDialect.h.inc -gen-dialect-decls) -mlir_tablegen(IndexTreeDialect.cpp.inc -gen-dialect-defs) +mlir_tablegen(IndexTreeDialect.h.inc -gen-dialect-decls -dialect=it) +mlir_tablegen(IndexTreeDialect.cpp.inc -gen-dialect-defs -dialect=it) +mlir_tablegen(IndexTreeOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(IndexTreeOpInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(COMETIndexTreeOpsIncGen) - +set(LLVM_TARGET_DEFINITIONS IndexTreeTypes.td) +mlir_tablegen(IndexTreeTypes.h.inc -gen-typedef-decls) +mlir_tablegen(IndexTreeTypes.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(COMETIndexTreeTypesIncGen) \ No newline at end of file diff --git a/include/comet/Dialect/IndexTree/IR/IndexTreeBase.td b/include/comet/Dialect/IndexTree/IR/IndexTreeBase.td new file mode 100644 index 00000000..de9b5493 --- /dev/null +++ b/include/comet/Dialect/IndexTree/IR/IndexTreeBase.td @@ -0,0 +1,24 @@ +#ifndef INDEXTREE_BASE +#define INDEXTREE_BASE + +include "mlir/IR/OpBase.td" + +// Provide a definition of the 'it' dialect in the ODS 
framework so that we +// can define our operations. +def IndexTreeDialect : Dialect { + let name = "it"; + let cppNamespace = "::mlir::indexTree"; + + // We set this bit to generate the declarations for the dialect's type parsing + // and printing hooks. + let useDefaultTypePrinterParser = 1; +} + +class IndexTreeOpTrait : NativeOpTrait<""> { + let trait = name; + let cppNamespace = "::mlir::indexTree"; +} + +def UnknownDomain : IndexTreeOpTrait<"UnknownDomain">; + +#endif // INDEXTREE_BASE \ No newline at end of file diff --git a/include/comet/Dialect/IndexTree/IR/IndexTreeDialect.h b/include/comet/Dialect/IndexTree/IR/IndexTreeDialect.h index 771f01a1..fece8d08 100644 --- a/include/comet/Dialect/IndexTree/IR/IndexTreeDialect.h +++ b/include/comet/Dialect/IndexTree/IR/IndexTreeDialect.h @@ -30,11 +30,25 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/StringSet.h" +#include "comet/Dialect/TensorAlgebra/IR/TADialect.h" /// Include the auto-generated header file containing the declaration of the index tree /// dialect. #include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h.inc" +/// Include the auto-generated header file containing the declaration of the index tree +/// types. +#define GET_TYPEDEF_CLASSES +#include "comet/Dialect/IndexTree/IR/IndexTreeTypes.h.inc" + +// Include the trait definitions +#include "comet/Dialect/IndexTree/IR/IndexTreeTraits.h" + +// Include the op interface definitions +#include "comet/Dialect/IndexTree/IR/IndexTreeOpInterfaces.h.inc" + /// Include the auto-generated header file containing the declarations of the /// Index Tree operations and also the operations of the Shape Inference Op Interface. 
//===----------------------------------------------------------------------===// @@ -43,4 +57,13 @@ //===----------------------------------------------------------------------===// +namespace mlir +{ + namespace indexTree + { + static const llvm::StringSet<> Semiring_intersectOps{"land", "times", "pairxy"}; + } +} + + #endif // INDEXTREE_DIALECT_H_ diff --git a/include/comet/Dialect/IndexTree/IR/IndexTreeOps.td b/include/comet/Dialect/IndexTree/IR/IndexTreeOps.td index 67037b3a..d978141a 100644 --- a/include/comet/Dialect/IndexTree/IR/IndexTreeOps.td +++ b/include/comet/Dialect/IndexTree/IR/IndexTreeOps.td @@ -21,32 +21,24 @@ // // ============================================================================= // -// Defines the operations of the IT dialect. +// Defines the operations of the IndexTree dialect. // //===----------------------------------------------------------------------===// #ifndef INDEXTREE_OPS #define INDEXTREE_OPS -include "mlir/IR/OpBase.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/IR/RegionKindInterface.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" -include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" - -// Provide a definition of the 'it' dialect in the ODS framework so that we -// can define our operations. -def IndexTreeDialect : Dialect { - let name = "it"; - let cppNamespace = "::mlir::indexTree"; - - // We set this bit to generate the declarations for the dialect's type parsing - // and printing hooks. - let useDefaultTypePrinterParser = 1; - -} +include "comet/Dialect/IndexTree/IR/IndexTreeTypes.td" +include "comet/Dialect/TensorAlgebra/IR/TATypes.td" // Base class for ta dialect operations. 
This operation inherits from the base // `Op` class in OpBase.td, and provides: @@ -60,60 +52,282 @@ class IndexTree_Op traits = []> : // Index Tree Operations //===----------------------------------------------------------------------===// -def IndexTreeComputeLHSOp : IndexTree_Op<"ComputeLHS", [Pure]>{ +def IndexTreeOp : IndexTree_Op<"itree", + [SingleBlockImplicitTerminator<"indexTree::YieldOp">]> { + let summary = "Create a scope for index tree iteration"; + let description = [{}]; + + let results = (outs Variadic:$results); + let regions = (region SizedRegion<1>:$region); +} + +def YieldOp : IndexTree_Op<"yield", [Pure, ReturnLike, Terminator, + HasParent<"IndexTreeOp">]> { + let summary = "index tree yield and termination operation"; + let description = [{ + }]; + + let arguments = (ins Variadic:$results); + let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; + + let assemblyFormat = + [{ attr-dict ($results^ `:` type($results))? }]; +} + +def IndexTreeRootOp : IndexTree_Op<"RootOp", [ + HasParent<"IndexTreeOp">, Pure]> { + let summary = "Create the base of the iteration tree"; + let description = [{}]; + + let results = (outs IndexTree_TreeType:$output); +} + +def IndexTreeIndicesOp : IndexTree_Op<"IndexOp", [Pure]>{ + let summary = "Create an index variable bound to a specific computation."; + let description = [{ + }]; + + let arguments = (ins IndexTree_NodeType:$parent, Optional:$domain); + + let results = (outs IndexTree_IndexNodeType:$output); +} + +def IndexTreeMaskOp : IndexTree_Op<"MaskOp", [Pure]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins TA_AnyTensor:$tensor, TA_AnyTensor:$mask, StrAttr:$mask_type); + let results = (outs TA_AnyTensor:$masked_tensor); +} + +def IndexTreeLHSOperandOp : IndexTree_Op<"LHSOperandOp", [Pure, SameVariadicOperandSize]>{ let summary = ""; let description = [{}]; - let arguments = (ins Variadic:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats); - let results = (outs 
AnyType:$output); + let arguments = (ins + TA_AnyTensor:$tensor, + Variadic:$pos, + Variadic:$crds); + let results = (outs IndexTree_OperandType:$result); } -def IndexTreeComputeRHSOp : IndexTree_Op<"ComputeRHS", [Pure]>{ +def IndexTreeOperandOp : IndexTree_Op<"OperandOp", [Pure, SameVariadicOperandSize]>{ let summary = ""; let description = [{}]; - let arguments = (ins Variadic:$tensors, ArrayAttr:$allPerms, ArrayAttr:$allFormats); - let results = (outs AnyType:$output); + let arguments = (ins + TA_AnyTensor:$tensor, + Variadic:$pos, + Variadic:$crds); + let results = (outs IndexTree_OperandType:$result); } -def IndexTreeComputeOp : IndexTree_Op<"Compute", [Pure]>{ +def IndexTreeComputeOp : IndexTree_Op<"ComputeOp", [Pure]>{ let summary = ""; let description = [{ }]; //TODO(gkestor): rethink the use of comp_worksp_opt, should we decouple that? /// MaskType attribute: {push, pull, auto, none} - let arguments = (ins Variadic:$rhs, AnyType:$lhs, BoolAttr:$comp_worksp_opt, StrAttr:$semiring, StrAttr:$MaskType); + let arguments = (ins IndexTree_NodeType:$parent, IndexTree_OperandType:$lhs, Variadic:$rhs, StrAttr:$semiring); - let results = (outs I64:$output); + let results = (outs TA_AnyTensor); //TODO(gkestor): add verifier //let hasVerifier = 1; - } -def IndexTreeIndicesOp : IndexTree_Op<"Indices", [Pure]>{ +def IndexTreeTensorDomainOp : IndexTree_Op<"DomainOp", [Pure]>{ let summary = ""; let description = [{ }]; - let arguments = (ins Variadic:$children, ArrayAttr:$indices); - let results = (outs I64:$output); + let arguments = (ins TA_AnyTensor:$tensor, UI32Attr:$dim, I32Attr:$format, Optional>:$parent); - //TODO(gkestor): add verifier - //let hasVerifier = 1; + let results = (outs IndexTree_DomainType:$domain); } -def IndexTreeOp : IndexTree_Op<"itree", [Pure]>{ +def IndexTreeEmptyDomainOp : IndexTree_Op<"EmptyDomain", [Pure]>{ let summary = ""; let description = [{ }]; - let arguments = (ins AnyType:$children); - let results = (outs I64:$output); + let results = 
(outs IndexTree_DomainType:$domain); +} - //TODO(gkestor): add verifier - //let hasVerifier = 1; +def ConcreteDomainInterface : OpInterface<"ConcreteDomain"> { + let description = [{ + Describes an operation which implements a concrete domain. That is + the domain is fully-described by its arguments and does not rely on the + tensor definition to create the domain. + }]; + let cppNamespace = "::mlir::indexTree"; + + let methods = [ + InterfaceMethod<[{ + Method to get the dimension size of a concrete dimension. + }], + "mlir::Value", "getDimensionSize", (ins), /*methodBody=*/[{ + return $_op.getDimSize(); + }]>, + ]; +} + +def IndexTreeSparseDomainOp : IndexTree_Op<"SparseDomainOp", + [Pure, ConcreteDomainInterface]>{ + + let summary = ""; + let description = [{ + }]; + + let arguments = (ins + TA_AnyTensor:$tensor, + UI32Attr:$dim, + I32Attr:$format, + AnyTensor:$pos, + AnyTensor:$crd, + Index:$pos_size, + Index:$crd_size, + Index:$dim_size, + Optional>:$parent); + + let results = (outs IndexTree_DomainType:$domain); +} + +def IndexTreeDenseDomainOp : IndexTree_Op<"DenseDomainOp", + [Pure, SameVariadicOperandSize, ConcreteDomainInterface]>{ + let summary = ""; + let description = [{ + }]; + + let arguments = (ins AnyTypeOf<[UI32,Index]>:$dim_size, Variadic:$tensors, I32ArrayAttr:$dims); + + let results = (outs IndexTree_DomainType:$domain); +} + +def IndexTreeWorkspaceDomainOp : IndexTree_Op<"WorkspaceDomainOp", + [Pure, ConcreteDomainInterface]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins + WorkspaceTensor:$tensor, + Index:$dim_size, + UI32Attr:$dim, + Optional>:$parent + ); + let results = (outs IndexTree_DomainType:$domain); +} + +def IndexTreeDomainUnionOp : IndexTree_Op<"DomainUnionOp", + [Pure, UnknownDomain, ConcreteDomainInterface, AttrSizedOperandSegments]>{ + let summary = ""; + let description = [{ + }]; + + let arguments = (ins Variadic:$domains, Optional:$dim_size); + let results = (outs IndexTree_DomainType:$domain);
+} + +def IndexTreeDomainIntersectionOp : IndexTree_Op<"DomainIntersectionOp", + [Pure, UnknownDomain, ConcreteDomainInterface, AttrSizedOperandSegments]>{ + let summary = ""; + let description = [{ + }]; + + let arguments = (ins Variadic:$domains, Optional:$dim_size); + let results = (outs IndexTree_DomainType:$domain); +} + +def IndexTreeNestedDomainOp : IndexTree_Op<"NestedDomainOp", + [Pure, UnknownDomain, ConcreteDomainInterface]>{ + let summary = ""; + let description = [{ + }]; + + let arguments = (ins Variadic:$domains, Index:$dim_size); + let results = (outs IndexTree_DomainType:$domain); +} + +def DomainGetSize : IndexTree_Op<"DomainGetSize", [Pure]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins IndexTree_DomainType:$domain); + let results = (outs Index:$result); +} + +def IndexTreeIndexToTensorOp : IndexTree_Op<"IndexToTensorDim", [Pure]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins + TA_AnyTensor:$tensor, + IndexTree_IndexNodeType:$index, + UI32Attr:$dim, + Optional:$prev_dim + ); + + let results = (outs + Index:$crd, + Index:$pos + ); +} + +def DeclDomainOp : IndexTree_Op<"DeclDomainOp", [Pure]>{ + let summary = ""; + let description = [{ + }]; + + let arguments = (ins Index:$dim_size, Index:$num_rows, OptionalAttr:$is_dynamic); + let results = (outs IndexTree_SymbolicDomainType); +} + +def ComputeSymbolicDomainOp : IndexTree_Op<"ComputeSymbolicDomainOp", [Pure]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins IndexTree_NodeType:$parent, IndexTree_SymbolicDomainType:$domain, DefaultValuedAttr:$is_unique); + let results = (outs IndexTree_SymbolicDomainType); +} + +def ComputeSymbolicDomainRowOp : IndexTree_Op<"ComputeSymbolicDomainRowOp", [Pure]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins IndexTree_NodeType:$parent, IndexTree_SymbolicDomainType:$domain, DefaultValuedAttr:$needs_mark); + let results = (outs IndexTree_SymbolicDomainType); +} + +def 
SymbolicDomainInsertOp : IndexTree_Op<"SymbolicDomainInsertOp", [Pure]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins IndexTree_SymbolicDomainType:$domain, Index:$crd, DefaultValuedAttr:$is_unique); + let results = (outs IndexTree_SymbolicDomainType); +} + +def SymbolicDomainEndRowOp : IndexTree_Op<"SymbolicDomainEndRowOp", [Pure]>{ + let summary = ""; + let description = [{}]; + + let arguments = (ins IndexTree_SymbolicDomainType:$domain, DefaultValuedAttr:$needs_mark); + let results = (outs IndexTree_SymbolicDomainType); } +def IndexTreeSparseTensorOp : IndexTree_Op<"IndexTreeSparseTensorOp", [Pure]>{ + let summary = "Declare a sparse tensor from an index tree domain"; + let description = [{}]; + + let arguments = (ins Variadic:$domains); + let results = (outs TA_AnyTensor); +} + +def IndexTreeCleanWorkspaceOp : IndexTree_Op<"WorkspaceStartRowOp", [Pure]>{ + let summary = "Associate the loop of this index variable with a clean workspace"; + let description = [{}]; + + let arguments = (ins IndexTree_NodeType:$parent, WorkspaceTensor:$workspace); + let results = (outs WorkspaceTensor:$result); +} #endif // INDEXTREE_OPS \ No newline at end of file diff --git a/include/comet/Dialect/IndexTree/IR/IndexTreeTraits.h b/include/comet/Dialect/IndexTree/IR/IndexTreeTraits.h new file mode 100644 index 00000000..de833da3 --- /dev/null +++ b/include/comet/Dialect/IndexTree/IR/IndexTreeTraits.h @@ -0,0 +1,15 @@ +#ifndef INDEXTREE_TRAITS_H_ +#define INDEXTREE_TRAITS_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace indexTree { + template + class UnknownDomain : public ::mlir::OpTrait::TraitBase {}; + +} // indexTree +} // mlir + +#endif //INDEXTREE_TRAITS_H_ \ No newline at end of file diff --git a/include/comet/Dialect/IndexTree/IR/IndexTreeTypes.td b/include/comet/Dialect/IndexTree/IR/IndexTreeTypes.td new file mode 100644 index 00000000..c11e59a1 --- /dev/null +++ 
b/include/comet/Dialect/IndexTree/IR/IndexTreeTypes.td @@ -0,0 +1,59 @@ +#ifndef INDEXTREE_TYPES +#define INDEXTREE_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "comet/Dialect/IndexTree/IR/IndexTreeBase.td" + +class IndexTree_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def IndexTree_TreeType : IndexTree_Type<"IndexTree", "index_tree"> { + let summary = "Operand for compute expression"; + let description = [{}]; +} + +def IndexTree_IndexNodeType : IndexTree_Type<"IndexNode", "index"> { + let summary = "Iteration Tree for Index Variables"; + let description = [{ + Type for storing iteration tree of index variables. + }]; +} + +def IndexTree_NodeType : AnyTypeOf<[IndexTree_TreeType, IndexTree_IndexNodeType]>; + +def IndexTree_TensorAccessType : IndexTree_Type<"TensorAccess", "tensor_access"> { + let summary = "Tensor access variables"; + let description = [{ + Type for storing tensor access of index variables. + }]; +} + +def IndexTree_OperandType : IndexTree_Type<"Operand", "operand"> { + let summary = "Operand for compute expression"; + let description = [{}]; +} + +def IndexTree_SymbolicDomainType : IndexTree_Type<"SymbolicDomain", "symbolic_domain"> { + let summary = "Type representing a computed iteration/tensor domain"; + let description = [{}]; +} + +def IndexTree_DomainType : IndexTree_Type<"Domain", "domain"> { + let summary = "Type representing an iteration domain"; + let description = [{ + Type for storing iteration domain of an index variable. + }]; +} + +def IndexTree_AnyDomainType : AnyTypeOf<[IndexTree_SymbolicDomainType, IndexTree_DomainType]>; + +def IndexTree_ProvenanceGraphType : IndexTree_Type<"ProvenanceGraph", "prov_graph"> { + let summary = "Provenance graph for index tree transformations"; + let description = [{ + Type for storing provenance graph associated with an index.
+ }]; +} + +#endif // INDEXTREE_TYPES diff --git a/include/comet/Dialect/IndexTree/Passes.h b/include/comet/Dialect/IndexTree/Passes.h index 951c0856..582fa1da 100644 --- a/include/comet/Dialect/IndexTree/Passes.h +++ b/include/comet/Dialect/IndexTree/Passes.h @@ -33,12 +33,23 @@ namespace mlir /// Generate the code for registering conversion passes. #define GEN_PASS_DECL #include "comet/Dialect/IndexTree/Passes.h.inc" + // Create a pass for infering the domain from the use of the index variables + std::unique_ptr createIndexTreeDomainInferencePass(); + + // Create a pass for concretizing the domain from the tensor definitions + std::unique_ptr createIndexTreeDomainConcretizationPass(); + + // Create a pass for creating the symbolic pass + std::unique_ptr createIndexTreeSymbolicComputePass(); + + // Create a pass for inlining the index tree + std::unique_ptr createIndexTreeInliningPass(); /// Create a pass for applying compressed workspace transformation into IndexTreeIR std::unique_ptr createIndexTreeWorkspaceTransformationsPass(); /// Create a pass for the redundancy-aware kernel fusion on index tree dialect for some compound expressions - std::unique_ptr createIndexTreeKernelFusionPass(); + // std::unique_ptr createIndexTreeKernelFusionPass(); } } diff --git a/include/comet/Dialect/IndexTree/Passes.td b/include/comet/Dialect/IndexTree/Passes.td index 62eae9bd..583f3e68 100644 --- a/include/comet/Dialect/IndexTree/Passes.td +++ b/include/comet/Dialect/IndexTree/Passes.td @@ -30,36 +30,66 @@ include "mlir/Pass/PassBase.td" /// Kernel Fusion ///===----------------------------------------------------------------------===/// -def IndexTreeKernelFusion : Pass<"indextree-kernel-fusion"> { - let summary = "The redundancy-aware kernel fusion on index tree dialect for compound expressions"; - let description = [{ +//def IndexTreeKernelFusion : Pass<"indextree-kernel-fusion"> { +// let summary = "The redundancy-aware kernel fusion on index tree dialect for compound 
expressions"; +// let description = [{ - }]; - let constructor = "comet::createIndexTreeKernelFusionPass()"; - let dependentDialects = [ - "comet::IndexTreeDialect", - "memref::MemRefDialect", - "scf::SCFDialect" - ]; -} +// }]; +// let constructor = "comet::createIndexTreeKernelFusionPass()"; +// let dependentDialects = [ +// "comet::IndexTreeDialect", +// "memref::MemRefDialect", +// "scf::SCFDialect" +// ]; +//} ///===----------------------------------------------------------------------===/// /// Workspace Transformations ///===----------------------------------------------------------------------===/// -def IndexTreeWorkspaceTranformations: Pass<"indextree-workspace-transformations"> { +def IndexTreeWorkspaceTranformations: Pass<"indextree-workspace-transformations", "func::FuncOp"> { let summary = "Compressed workspace transformation on IndexTree dialect" "to produce sparse output"; - let description = [{ + let description = [{}]; + let constructor = "comet::createIndexTreeWorkspaceTransformationsPass()"; + let dependentDialects = ["comet::IndexTreeDialect"]; +} + +def IndexTreeDomainInference : Pass<"indextree-domain-inference", "func::FuncOp"> { + let summary = "Infer domain of index variables"; + let description = [{ + Propagate domain values from tensors, through compute operations to index variables.
+ Necessary for eventual lowering of index variables to for loops + }]; + + let constructor = "comet::createIndexTreeDomainInferencePass()"; +} - }]; - let constructor = "comet::createIndexTreeWorkspaceTransformationsPass("; - let dependentDialects = [ - "comet::IndexTreeDialect", - "memref::MemRefDialect", - "scf::SCFDialect" - ]; +def IndexTreeDomainConcretization : Pass<"indextree-domain-concretization", "func::FuncOp"> { + let summary = "Transform tensor domains into concrete descriptions of dense or sparse domains"; + let description = [{ + Transform tensor domains into concrete descriptions of dense or sparse domains + }]; + + let constructor = "comet::createIndexTreeDomainConcretizationPass()"; } +def IndexTreeSymbolicComputePass : Pass<"indextree-symbolic-compute", "func::FuncOp"> { + let summary = "Create index-tree for symbolic computations for sparse output"; + let description = [{ + Inserting values into a sparse output tensor cannot be done in parallel because it requires + sequential access to the row pointers and dynamic allocation. The symbolic pass + computes the size of allocations and row pointers without performing the computations + so the sparse tensors can be allocated and the computations can be done in parallel. 
+ }]; + + let constructor = "comet::createIndexTreeSymbolicComputePass()"; +} + +def IndexTreeInliningPass : Pass <"indextree-inlining", "func::FuncOp"> { + let summary = "Inline index tree"; + let description = [{}]; + let constructor = "comet::createIndexTreeInliningPass()"; +} #endif /// COMET_DIALECT_INDEXTREE_PASSES diff --git a/include/comet/Dialect/TensorAlgebra/IR/TATypes.h b/include/comet/Dialect/IndexTree/Patterns.h similarity index 63% rename from include/comet/Dialect/TensorAlgebra/IR/TATypes.h rename to include/comet/Dialect/IndexTree/Patterns.h index 333b2d93..26872f74 100644 --- a/include/comet/Dialect/TensorAlgebra/IR/TATypes.h +++ b/include/comet/Dialect/IndexTree/Patterns.h @@ -1,4 +1,4 @@ -//===- TATypes.h - Types definitions for the TensorAlgebra IR ----------------------===// +//===- Patterns.h - Conversion Pass Construction and Registration -----------===// // // Copyright 2022 Battelle Memorial Institute // @@ -19,39 +19,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // -// ============================================================================= -// -// This file implements the IR Dialect Types for Tensor Algebra Dialect. -// //===----------------------------------------------------------------------===// -#ifndef TENSORALGEBRA_TYPES_H_ -#define TENSORALGEBRA_TYPES_H_ +#ifndef COMET_DIALECT_INDEXTREE_PATTERNS_H +#define COMET_DIALECT_INDEXTREE_PATTERNS_H -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Types.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { - class MLIRContext; - - namespace tensorAlgebra - { - - class RangeType : public Type::TypeBase + namespace indexTree { - public: - /// Used for generic hooks in TypeBase. - using Base::Base; - - static RangeType get(MLIRContext *context) - { - /// Custom, uniq'ed construction in the MLIRContext. 
- return Base::get(context); - } - }; - - } /// namespace tensorAlgebra -} /// namespace mlir - -#endif /// TENSORALGEBRA_TYPES_H_ \ No newline at end of file + void populateDomainInferencePatterns(MLIRContext *context, RewritePatternSet &patterns); + void populateDomainConcretizationPatterns(MLIRContext *context, RewritePatternSet &patterns); + void populateIndexTreeTypeConversionPatterns(MLIRContext *context, RewritePatternSet &patterns, TypeConverter &typeConverter, ConversionTarget& target); + void populateIndexTreeInliningPatterns(MLIRContext *context, RewritePatternSet &patterns); + } +} + +#endif // COMET_DIALECT_INDEXTREE_PATTERNS_H diff --git a/include/comet/Dialect/TensorAlgebra/IR/CMakeLists.txt b/include/comet/Dialect/TensorAlgebra/IR/CMakeLists.txt index b5a0ddeb..dc504d68 100644 --- a/include/comet/Dialect/TensorAlgebra/IR/CMakeLists.txt +++ b/include/comet/Dialect/TensorAlgebra/IR/CMakeLists.txt @@ -3,4 +3,13 @@ mlir_tablegen(TAOps.h.inc -gen-op-decls) mlir_tablegen(TAOps.cpp.inc -gen-op-defs) mlir_tablegen(TADialect.h.inc -gen-dialect-decls) mlir_tablegen(TADialect.cpp.inc -gen-dialect-defs) -add_public_tablegen_target(COMETTensorAlgebraOpsIncGen) \ No newline at end of file +mlir_tablegen(TAEnums.h.inc -gen-enum-decls) +mlir_tablegen(TAEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(TAAttrs.h.inc -gen-attrdef-decls) +mlir_tablegen(TAAttrs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(COMETTensorAlgebraOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS TATypes.td) +mlir_tablegen(TATypes.h.inc -gen-typedef-decls) +mlir_tablegen(TATypes.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(COMETTensorAlgebraTypesIncGen) \ No newline at end of file diff --git a/include/comet/Dialect/TensorAlgebra/IR/TAAttrs.td b/include/comet/Dialect/TensorAlgebra/IR/TAAttrs.td new file mode 100644 index 00000000..8257d645 --- /dev/null +++ b/include/comet/Dialect/TensorAlgebra/IR/TAAttrs.td @@ -0,0 +1,10 @@ +#ifndef TA_ATTRS +#define TA_ATTRS + +include
"comet/Dialect/TensorAlgebra/IR/TAEnums.td" + +def TAFormatArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getI32ArrayAttr($0)"; +} + +#endif //TA_ATTRS \ No newline at end of file diff --git a/include/comet/Dialect/TensorAlgebra/IR/TABase.td b/include/comet/Dialect/TensorAlgebra/IR/TABase.td new file mode 100644 index 00000000..a6f58f21 --- /dev/null +++ b/include/comet/Dialect/TensorAlgebra/IR/TABase.td @@ -0,0 +1,15 @@ +#ifndef TA_BASE +#define TA_BASE + +/// Provide a definition of the 'TA' dialect in the ODS framework so that we +/// can define our operations. +def TA_Dialect : Dialect { + let name = "ta"; + let cppNamespace = "::mlir::tensorAlgebra"; + + /// We set this bit to generate the declarations for the dialect's type parsing + /// and printing hooks. + let useDefaultTypePrinterParser = 1; +} + +#endif //TA_BASE \ No newline at end of file diff --git a/include/comet/Dialect/TensorAlgebra/IR/TADialect.h b/include/comet/Dialect/TensorAlgebra/IR/TADialect.h index ab6e931a..b702da55 100644 --- a/include/comet/Dialect/TensorAlgebra/IR/TADialect.h +++ b/include/comet/Dialect/TensorAlgebra/IR/TADialect.h @@ -28,7 +28,6 @@ #ifndef TENSORALGEBRA_DIALECT_H_ #define TENSORALGEBRA_DIALECT_H_ -#include "comet/Dialect/TensorAlgebra/IR/TATypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -42,6 +41,18 @@ /// dialect. #include "comet/Dialect/TensorAlgebra/IR/TADialect.h.inc" +/// Include the auto-generated enum declerations +//===---------------------------------------------------------------------===// +#include "comet/Dialect/TensorAlgebra/IR/TAEnums.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "comet/Dialect/TensorAlgebra/IR/TATypes.h.inc" + +/// Include the auto-generated header file containing the declaration of the index tree +/// types. 
+#define GET_ATTRDEF_CLASSES +#include "comet/Dialect/TensorAlgebra/IR/TAAttrs.h.inc" + /// Include the auto-generated header file containing the declarations of the /// tensorAlgbra operations and also the operations of the Shape Inference Op Interface. //===----------------------------------------------------------------------===// @@ -52,13 +63,6 @@ namespace mlir { namespace tensorAlgebra { - std::vector getFormatsValue(std::string formats_str, int rank_size, PatternRewriter &rewriter, Location loc, IndexType indexType); - - namespace detail - { - struct SparseTensorTypeStorage; - } /// end namespace detail - void populateMultiOpFactorizationPatterns( RewritePatternSet &patterns, MLIRContext *context); @@ -69,36 +73,7 @@ namespace mlir RewritePatternSet &patterns, MLIRContext *context); void populateSTCRemoveDeadOpsPatterns( - RewritePatternSet &patterns, MLIRContext *context); - - //===----------------------------------------------------------------------===// - /// Tensor Algebra Types - //===----------------------------------------------------------------------===// - - /// This class defines the TA sparse tensor type. It represents a collection of - /// element types for data and indices of COO format. - /// All derived types in MLIR must inherit from the CRTP class - /// 'Type::TypeBase'. It takes as template parameters the concrete type - /// (SparseTensorType), the base class to use (Type), and the storage class - /// (SparseTensorTypeStorage). - class SparseTensorType : public mlir::Type::TypeBase - { - public: - /// Inherit some necessary constructors from 'TypeBase'. - using Base::Base; - - /// Create an instance of a `SparseTensorType` with the given element types. There - /// *must* be atleast one element type. - static SparseTensorType get(llvm::ArrayRef elementTypes); - - /// Returns the element types of this sparse tensor type. - llvm::ArrayRef getElementTypes(); - - /// Returns the number of element type held by this sparse tensor. 
- size_t getNumElementTypes() { return getElementTypes().size(); } - }; - + RewritePatternSet &patterns, MLIRContext *context); } /// end namespace tensorAlgebra } /// end namespace mlir diff --git a/include/comet/Dialect/TensorAlgebra/IR/TAEnums.td b/include/comet/Dialect/TensorAlgebra/IR/TAEnums.td new file mode 100644 index 00000000..26f67e1e --- /dev/null +++ b/include/comet/Dialect/TensorAlgebra/IR/TAEnums.td @@ -0,0 +1,19 @@ +#ifndef TA_ENUMS +#define TA_ENUMS + +include "mlir/IR/EnumAttr.td" + +def Unknown: I32EnumAttrCase<"UNK", 0, "unk">; +def Dense: I32EnumAttrCase<"D", 1, "d">; +def Compressed: I32EnumAttrCase<"CU", 2, "cu">; +def CompressedNonunique: I32EnumAttrCase<"CN", 3, "cn">; +def Singular: I32EnumAttrCase<"S", 4, "s">; + +def TAFormatEnum: I32EnumAttr<"TensorFormatEnum", "Valid format specifiers", + [Unknown, Dense, Compressed, CompressedNonunique, Singular]> { + let cppNamespace = "::mlir::tensorAlgebra"; + let stringToSymbolFnName = "ConvertToEnum"; + let symbolToStringFnName = "ConvertToString"; +} + +#endif //TA_ENUMS \ No newline at end of file diff --git a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td index 835e699a..34007c1e 100644 --- a/include/comet/Dialect/TensorAlgebra/IR/TAOps.td +++ b/include/comet/Dialect/TensorAlgebra/IR/TAOps.td @@ -29,35 +29,16 @@ #define TA_OPS -include "mlir/IR/OpBase.td" +include "mlir/IR/OpBase.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" -/// Provide a definition of the 'TA' dialect in the ODS framework so that we -/// can define our operations. -def TA_Dialect : Dialect { - let name = "ta"; - let cppNamespace = "::mlir::tensorAlgebra"; - - /// We set this bit to generate the declarations for the dialect's type parsing - /// and printing hooks. 
- let useDefaultTypePrinterParser = 1; - -} - -/// An implementation of RangeType. -def TA_RangeType : - DialectType()">, - "RangeType">; - - -/// Whether a type is a RangeType. -def TAIsRangeTypePred : CPred<"$_self.isa()">; -def Range : Type; +include "comet/Dialect/TensorAlgebra/IR/TABase.td" +include "comet/Dialect/TensorAlgebra/IR/TATypes.td" +include "comet/Dialect/TensorAlgebra/IR/TAAttrs.td" /// Base class for ta dialect operations. This operation inherits from the base /// `Op` class in OpBase.td, and provides: @@ -67,15 +48,6 @@ def Range : Type; class TA_Op traits = []> : Op; -/// Provide a definition for the TA SparseTensorType for use in ODS. -/// This allows for using SparseTensorType in a similar way to Tensor or MemRef. -def SparseTensor : - Type()">, "TA sparse tensor type">; - -/// Provide a definition of the types that are used within the TA dialect. -def TA_AnyTensor : AnyTypeOf<[TensorOf<[AnyType]>, SparseTensor]>; - - //===----------------------------------------------------------------------===// /// Tensor Algebra Operations //===----------------------------------------------------------------------===// @@ -177,7 +149,7 @@ def SparseTensorDeclOp : TA_Op<"sparse_tensor_decl", [Pure]> { let arguments = (ins Variadic:$labels, StrAttr:$format, BoolAttr:$temporal_tensor); /// The constant operation returns a single value of TensorType. - let results = (outs AnyTensor); + let results = (outs TA_AnyTensor); let extraClassDeclaration = [{ unsigned int getParameterCount() { @@ -222,7 +194,7 @@ def SparseOutputTensorDeclOp : TA_Op<"sparse_output_tensor_decl", [Pure]> { let arguments = (ins Variadic:$labels, StrAttr:$format); /// The constant operation returns a single value of TensorType. - let results = (outs AnyTensor); + let results = (outs TA_AnyTensor); /// Invoke a static verify method to verify this constant operation. 
@@ -252,7 +224,7 @@ def TempSparseOutputTensorDeclOp : TA_Op<"temp_sparse_output_tensor_decl", [Pure let arguments = (ins Variadic:$labels, StrAttr:$format); /// The constant operation returns a single value of TensorType. - let results = (outs AnyTensor); + let results = (outs TA_AnyTensor); //TODO(gkestor): add verifier //let hasVerifier = 1; @@ -315,7 +287,7 @@ def SparseTensorConstructOp : TA_Op<"sptensor_construct", [Pure]>{ //Aval_size (size of value array) //dim1_size, dim2_size (size of each dimension in sparse tensor) //TODO(gkestor): might be better to have a struct with all the data elements - let arguments = (ins Variadic:$indices, I32Attr:$tensor_rank); + let arguments = (ins Variadic:$indices, I32Attr:$tensor_rank, TAFormatArrayAttr:$dimension_formats); let results = (outs TA_AnyTensor:$output); let assemblyFormat = [{ @@ -354,7 +326,7 @@ def SparseTensorConstructOp : TA_Op<"sptensor_construct", [Pure]>{ } -def TensorFillOp : TA_Op<"fill", [Pure]>{ +def TensorFillOp : TA_Op<"fill">{ let summary = ""; let description = [{ @@ -407,7 +379,7 @@ def TensorElewsMultOp : TA_Op<"elews_mul", [Pure]>{ } -def TensorSetOp : TA_Op<"set_op", [Pure]>{ +def TensorSetOp : TA_Op<"set_op">{ let summary = ""; let description = [{ @@ -744,7 +716,7 @@ def GenericCallOp : TA_Op<"generic_call", ]; } -def TensorFillFromFileOp : TA_Op<"fill_from_file", [Pure]>{ +def TensorFillFromFileOp : TA_Op<"fill_from_file">{ let summary = ""; let description = [{ }]; @@ -782,8 +754,70 @@ def TensorCopyOp : TA_Op<"copy", [Pure]>{ //let hasVerifier = 1; } -#endif /// TA_OPS +def TensorInsertOp : TA_Op<"TAInsertOp", [Pure, SameVariadicOperandSize]>{ + let summary = "Insert intro a sparse tensor. 
Sparse Tensor equivalent of tensor.insert"; + let description = [{}]; + let arguments = ( + ins TA_AnyTensor:$tensor, + Variadic:$pos, + Variadic:$crds, + F64:$value + ); + let results = (outs TA_AnyTensor); +} +def TensorExtractOp : TA_Op<"TAExtractOp", [Pure, SameVariadicOperandSize]>{ + let summary = "Extract from a sparse tensor. Sparse Tensor equivalent of tensor.extract"; + let description = [{}]; + let arguments = ( + ins TA_AnyTensor:$tensor, + Index:$pos + ); + + let results = (outs AnyFloat); +} + +def SpTensorGetCrd : TA_Op<"SpTensorGetCrd", [Pure]>{ + let arguments = (ins TA_AnyTensor:$tensor, Index:$idx, OptionalAttr:$dim); + let results = (outs Index:$crd); +} + +def SpTensorInsertCrd : TA_Op<"SpTensorInsertCrd", [Pure]>{ + let arguments = (ins TA_AnyTensor:$tensor, I32Attr:$dim, Index:$idx, Index:$crd); + let results = (outs TA_AnyTensor:$result); +} + +def SpTensorGetDimSize : TA_Op<"SpTensorGetDimSize", [Pure]>{ + let arguments = (ins TA_AnyTensor:$tensor, I32Attr:$dim); + let results = (outs Index:$result); +} + +def SpTensorGetNNZ : TA_Op<"SpTensorGetNNZ", [Pure]>{ + let arguments = (ins TA_AnyTensor:$tensor, OptionalAttr:$dim); + let results = (outs Index:$result); +} + +def TensorFindPos : TA_Op<"TensorFindPos", [Pure]>{ + let arguments = (ins TA_AnyTensor:$tensor, Index:$crd, I32Attr:$dim, DefaultValuedAttr:$is_linear); + let results = (outs Index:$result); +} + +def AllocWorkspaceOp : TA_Op<"AllocWorkspace", [Pure, SameVariadicOperandSize]>{ + let summary = "Create a dense workspace to be used as an intermediate output tensor"; + let description = [{}]; + + let arguments = (ins SparseTensor:$tensor, I32ArrayAttr:$dims); + let results = (outs WorkspaceTensor:$result); +} + +def WorkspaceClearOp : TA_Op<"WorkspaceClear", [Pure]>{ + let summary = "Clear the workspace for use in the next iteration"; + let description = [{}]; + + let arguments = (ins WorkspaceTensor:$tensor); + let results = (outs WorkspaceTensor:$result); +} +#endif /// TA_OPS \ 
No newline at end of file diff --git a/include/comet/Dialect/TensorAlgebra/IR/TATypes.td b/include/comet/Dialect/TensorAlgebra/IR/TATypes.td new file mode 100644 index 00000000..5c916b17 --- /dev/null +++ b/include/comet/Dialect/TensorAlgebra/IR/TATypes.td @@ -0,0 +1,48 @@ +#ifndef TA_TYPES +#define TA_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "comet/Dialect/TensorAlgebra/IR/TABase.td" + +class TensorAlgebra_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def Range : TensorAlgebra_Type<"Range","range"> { + let summary = "Type representing the range of one dimension of a tensor"; + let description = [{}]; +} + +def SparseTensor : TensorAlgebra_Type<"SparseTensor", "sparse_tensor", [DeclareTypeInterfaceMethods]> { + let summary = "Sparse tensor to use in tensor algebra dialect"; + let description = [{}]; + + let parameters = (ins + "::mlir::Type":$element_type, + ArrayRefParameter<"int64_t", "Dimensions of tensor">:$dims, + ArrayRefParameter<"int32_t", "Format">:$format + ); + let assemblyFormat = "`<` $element_type `,` $dims `,` $format `>`"; + + // TODO: Implement custom builder from "common" format strings into format strings +} + +def WorkspaceTensor : TensorAlgebra_Type<"Workspace", "workspace", [DeclareTypeInterfaceMethods]> { + let summary = "Temporary tensor generated from a workspace transform"; + let description = [{ + Dense, temporary tensor generated from a workspace transformation. + Needed to represent dense row as well as mark array. + }]; + let parameters = ( + ins "::mlir::Type":$element_type, + ArrayRefParameter<"int64_t", "Dimensions of workspace">:$dims + ); + let assemblyFormat = "`<` $element_type `,` $dims `>`"; +} + +/// Provide a definition of the types that are used within the TA dialect.
+def TA_AnyTensor : AnyTypeOf<[TensorOf<[AnyType]>, SparseTensor, WorkspaceTensor]>; + +#endif //TA_TYPES \ No newline at end of file diff --git a/include/comet/Dialect/Utils/Utils.h b/include/comet/Dialect/Utils/Utils.h index a3cf160c..1830903e 100644 --- a/include/comet/Dialect/Utils/Utils.h +++ b/include/comet/Dialect/Utils/Utils.h @@ -105,7 +105,7 @@ namespace mlir void print_vector_value(std::vector vec); std::string dump2str(Value t); - std::vector stringSplit(std::string s, std::string delimiter); + std::vector stringSplit(llvm::StringRef s, llvm::StringRef delimiter); std::vector getReverseIdentityPermutation(size_t size); std::vector getIdentityPermutation(size_t size); @@ -121,7 +121,7 @@ namespace mlir std::vector getFreeIndices(std::vector rhs_perm, std::vector lhs_perm); std::vector getSumIndices(std::vector rhs_perm, std::vector rhs_perm_free); std::vector getIndexIterateOrder(std::vector rhs1_perm, std::vector rhs2_perm); - std::vector> getAllFormats(ArrayAttr opFormatsArrayAttr, std::vector> allPerms); + std::vector> getAllFormats(ArrayAttr opFormatsArrayAttr, std::vector> allPerms); bool checkIsElementwise(std::vector> allPerms); bool checkIsMixedMode(std::vector> formats); bool checkIsDense(std::vector format); @@ -129,10 +129,11 @@ namespace mlir bool isDense(std::string s, std::string delim); bool isMergedIndex(std::vector format_vec, int cur_idx, int sumIndex); - std::vector getFormatsValue(std::string formats_str, int rank_size, + std::vector getFormatsValue(llvm::StringRef formats_str, int rank_size, PatternRewriter &rewriter, Location loc, IndexType indexType); - std::vector getFormatsValueInt(std::string formats_str, int rank_size, + std::vector getFormatsValueInt(llvm::StringRef formats_str, int rank_size, PatternRewriter &rewriter, Location loc, IntegerType intType); + std::vector getFormats(llvm::StringRef formats_str, int rank_size, MLIRContext* ctx); double loopCostHeuristic(const std::vector &loopOrder, size_t dim_, std::vector 
&sourceOrder, std::vector &destOrder); diff --git a/integration_test/compound_exps/CSR_mult_spTranspose_CSR.ta b/integration_test/compound_exps/CSR_mult_spTranspose_CSR.ta index a7c9ec95..e5aa6f99 100644 --- a/integration_test/compound_exps/CSR_mult_spTranspose_CSR.ta +++ b/integration_test/compound_exps/CSR_mult_spTranspose_CSR.ta @@ -26,7 +26,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/compound_exps/Dense_eltwise_sTranspose_CSR.ta b/integration_test/compound_exps/Dense_eltwise_sTranspose_CSR.ta index 7d722a00..217cf615 100644 --- a/integration_test/compound_exps/Dense_eltwise_sTranspose_CSR.ta +++ b/integration_test/compound_exps/Dense_eltwise_sTranspose_CSR.ta @@ -27,7 +27,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/compound_exps/dTranspose_eltwise_CSR.ta b/integration_test/compound_exps/dTranspose_eltwise_CSR.ta index 2b1fb8e4..45048cf5 100644 --- a/integration_test/compound_exps/dTranspose_eltwise_CSR.ta +++ b/integration_test/compound_exps/dTranspose_eltwise_CSR.ta @@ -28,7 +28,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/compound_exps/spTranspose_CSR_eltwise_CSR.ta b/integration_test/compound_exps/spTranspose_CSR_eltwise_CSR.ta index a33e59dc..a05eec8e 100644 --- a/integration_test/compound_exps/spTranspose_CSR_eltwise_CSR.ta +++ b/integration_test/compound_exps/spTranspose_CSR_eltwise_CSR.ta @@ -26,7 +26,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: 
data = diff --git a/integration_test/compound_exps/spTranspose_CSR_eltwise_Dense.ta b/integration_test/compound_exps/spTranspose_CSR_eltwise_Dense.ta index b9535eb6..58f9f4f4 100644 --- a/integration_test/compound_exps/spTranspose_CSR_eltwise_Dense.ta +++ b/integration_test/compound_exps/spTranspose_CSR_eltwise_Dense.ta @@ -26,7 +26,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/compound_exps/spTranspose_CSR_mult_CSR.ta b/integration_test/compound_exps/spTranspose_CSR_mult_CSR.ta index fbde8a3c..dc5053f7 100644 --- a/integration_test/compound_exps/spTranspose_CSR_mult_CSR.ta +++ b/integration_test/compound_exps/spTranspose_CSR_mult_CSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/ops/eltwise_mult_COOxDense_oCOO.ta b/integration_test/ops/eltwise_mult_COOxDense_oCOO.ta index cb6626fe..05e2bc48 100644 --- a/integration_test/ops/eltwise_mult_COOxDense_oCOO.ta +++ b/integration_test/ops/eltwise_mult_COOxDense_oCOO.ta @@ -27,7 +27,7 @@ def main() { # CHECK-NEXT: data = # CHECK-NEXT: 0,0,1,1,2,3,3,4,4, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,3,1,4,2,0,3,1,4, # CHECK-NEXT: data = diff --git a/integration_test/ops/eltwise_mult_CSRxCSR_oCSR.ta b/integration_test/ops/eltwise_mult_CSRxCSR_oCSR.ta index dce50fc9..eca2d6bd 100644 --- a/integration_test/ops/eltwise_mult_CSRxCSR_oCSR.ta +++ b/integration_test/ops/eltwise_mult_CSRxCSR_oCSR.ta @@ -28,7 +28,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/ops/eltwise_mult_CSRxDense_oCSR.ta 
b/integration_test/ops/eltwise_mult_CSRxDense_oCSR.ta index 81ce6148..39fb1570 100644 --- a/integration_test/ops/eltwise_mult_CSRxDense_oCSR.ta +++ b/integration_test/ops/eltwise_mult_CSRxDense_oCSR.ta @@ -26,7 +26,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/ops/mult_spgemm_CSRxCSR_oCSR.ta b/integration_test/ops/mult_spgemm_CSRxCSR_oCSR.ta index 785fbff1..ce9f26e2 100644 --- a/integration_test/ops/mult_spgemm_CSRxCSR_oCSR.ta +++ b/integration_test/ops/mult_spgemm_CSRxCSR_oCSR.ta @@ -30,7 +30,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/ops/transpose_COO_tensor.ta b/integration_test/ops/transpose_COO_tensor.ta index edb5e9c6..7fd1cd5a 100644 --- a/integration_test/ops/transpose_COO_tensor.ta +++ b/integration_test/ops/transpose_COO_tensor.ta @@ -27,11 +27,11 @@ def main() { # CHECK-NEXT: data = # CHECK-NEXT: 1,3,6, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 2,1,3, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 3,2,5, # CHECK-NEXT: data = diff --git a/integration_test/opts/spgemm_w_compressed_workspace.ta b/integration_test/opts/spgemm_w_compressed_workspace.ta index 6c6daba8..a9ee52ad 100644 --- a/integration_test/opts/spgemm_w_compressed_workspace.ta +++ b/integration_test/opts/spgemm_w_compressed_workspace.ta @@ -30,7 +30,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/semiring/eltwise_monoidTimes_COOxDense_oCOO.ta b/integration_test/semiring/eltwise_monoidTimes_COOxDense_oCOO.ta index 
7eb0fda2..322e12c9 100644 --- a/integration_test/semiring/eltwise_monoidTimes_COOxDense_oCOO.ta +++ b/integration_test/semiring/eltwise_monoidTimes_COOxDense_oCOO.ta @@ -27,7 +27,7 @@ def main() { # CHECK-NEXT: data = # CHECK-NEXT: 0,0,1,1,2,3,3,4,4, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,3,1,4,2,0,3,1,4, # CHECK-NEXT: data = diff --git a/integration_test/semiring/mm_SemiringAnyPair_CSRxCSR_oCSR.ta b/integration_test/semiring/mm_SemiringAnyPair_CSRxCSR_oCSR.ta index b4fe8c63..6bc08203 100644 --- a/integration_test/semiring/mm_SemiringAnyPair_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringAnyPair_CSRxCSR_oCSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/semiring/mm_SemiringMinFirst_CSRxCSR_oCSR.ta b/integration_test/semiring/mm_SemiringMinFirst_CSRxCSR_oCSR.ta index 64acee8c..ed581294 100644 --- a/integration_test/semiring/mm_SemiringMinFirst_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringMinFirst_CSRxCSR_oCSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/semiring/mm_SemiringMinPlus_CSRxCSR_oCSR.ta b/integration_test/semiring/mm_SemiringMinPlus_CSRxCSR_oCSR.ta index e1c5314d..21ad6ce7 100644 --- a/integration_test/semiring/mm_SemiringMinPlus_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringMinPlus_CSRxCSR_oCSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/semiring/mm_SemiringMinSecond_CSRxCSR_oCSR.ta 
b/integration_test/semiring/mm_SemiringMinSecond_CSRxCSR_oCSR.ta index 6b6008db..314b0d32 100644 --- a/integration_test/semiring/mm_SemiringMinSecond_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringMinSecond_CSRxCSR_oCSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/semiring/mm_SemiringPlusFirst_CSRxCSR_oCSR.ta b/integration_test/semiring/mm_SemiringPlusFirst_CSRxCSR_oCSR.ta index 1d49d593..a2967a19 100644 --- a/integration_test/semiring/mm_SemiringPlusFirst_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringPlusFirst_CSRxCSR_oCSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/semiring/mm_SemiringPlusPair_CSRxCSR_oCSR.ta b/integration_test/semiring/mm_SemiringPlusPair_CSRxCSR_oCSR.ta index 1b590eec..1f47ef50 100644 --- a/integration_test/semiring/mm_SemiringPlusPair_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringPlusPair_CSRxCSR_oCSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/integration_test/semiring/mm_SemiringPlusSecond_CSRxCSR_oCSR.ta b/integration_test/semiring/mm_SemiringPlusSecond_CSRxCSR_oCSR.ta index fab31d69..9cf0fe0c 100644 --- a/integration_test/semiring/mm_SemiringPlusSecond_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringPlusSecond_CSRxCSR_oCSR.ta @@ -29,7 +29,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git 
a/integration_test/semiring/mm_SemiringPlusTimes_CSRxCSR_oCSR.ta b/integration_test/semiring/mm_SemiringPlusTimes_CSRxCSR_oCSR.ta index 46ae6de4..447f4e75 100644 --- a/integration_test/semiring/mm_SemiringPlusTimes_CSRxCSR_oCSR.ta +++ b/integration_test/semiring/mm_SemiringPlusTimes_CSRxCSR_oCSR.ta @@ -30,7 +30,7 @@ def main() { # CHECK: data = # CHECK-NEXT: 5, # CHECK-NEXT: data = -# CHECK-NEXT: 0, +# CHECK-NEXT: -1, # CHECK-NEXT: data = # CHECK-NEXT: 0,2,4,5,7,9, # CHECK-NEXT: data = diff --git a/lib/Conversion/IndexTreeToSCF/CMakeLists.txt b/lib/Conversion/IndexTreeToSCF/CMakeLists.txt index 240500da..511844d2 100644 --- a/lib/Conversion/IndexTreeToSCF/CMakeLists.txt +++ b/lib/Conversion/IndexTreeToSCF/CMakeLists.txt @@ -1,5 +1,7 @@ add_mlir_conversion_library(COMETIndexTreeToSCF IndexTreeToSCF.cpp + SymbolicDomainConversion.cpp + IndexTreeConversion.cpp ADDITIONAL_HEADER_DIRS ${COMET_MAIN_INCLUDE_DIR}/comet/Conversion/IndexTreeToSCF @@ -12,6 +14,7 @@ add_mlir_conversion_library(COMETIndexTreeToSCF LINK_LIBS PUBLIC MLIRArithDialect + MLIRIndexDialect MLIRIR MLIRMemRefDialect MLIRSCFDialect diff --git a/lib/Conversion/IndexTreeToSCF/IndexTreeConversion.cpp b/lib/Conversion/IndexTreeToSCF/IndexTreeConversion.cpp new file mode 100644 index 00000000..9ec5e9ec --- /dev/null +++ b/lib/Conversion/IndexTreeToSCF/IndexTreeConversion.cpp @@ -0,0 +1,172 @@ +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" +#include "comet/Dialect/IndexTree/Passes.h" +#include "comet/Dialect/IndexTree/Patterns.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Index/IR/IndexAttrs.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" +#include 
"mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using llvm::SmallVector; + +namespace mlir { + namespace comet{ + #define GEN_PASS_DEF_INDEXTREE_INLINING + #include "comet/Conversion/Passes.h.inc" + } +} + + +namespace { +class ConvertIndexTreeYieldOpTypes : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(indexTree::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector unpacked; + for (Value v : adaptor.getOperands()) { + if (auto cast = + dyn_cast_or_null(v.getDefiningOp())) { + if (cast.getInputs().size() != 1) { + unpacked.append(cast.getInputs().begin(), cast.getInputs().end()); + continue; + } + } + // 1 : 1 type conversion. + unpacked.push_back(v); + } + rewriter.replaceOpWithNewOp(op, unpacked); + return success(); + } +}; + + +class ConvertIndexTreeTypes : public OpConversionPattern{ + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(indexTree::IndexTreeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector dstTypes; + SmallVector offsets; + offsets.push_back(0); + // Do the type conversion and record the offsets. + for (Type type : op.getResultTypes()) { + if (failed(typeConverter->convertTypes(type, dstTypes))) + return rewriter.notifyMatchFailure(op, "could not convert result type"); + offsets.push_back(dstTypes.size()); + } + + // Calls the actual converter implementation to convert the operation. + auto newOp = rewriter.create(op.getLoc(), dstTypes); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), newOp.getRegion().end()); + + // Packs the return value. 
+ SmallVector packedRets; + for (unsigned i = 1, e = offsets.size(); i < e; i++) { + unsigned start = offsets[i - 1], end = offsets[i]; + unsigned len = end - start; + ValueRange mappedValue = newOp->getResults().slice(start, len); + if (len != 1) { + // 1 : N type conversion. + Type origType = op.getResultTypes()[i - 1]; + Value mat = typeConverter->materializeSourceConversion( + rewriter, op.getLoc(), origType, mappedValue); + if (!mat) { + return rewriter.notifyMatchFailure( + op, "Failed to materialize 1:N type conversion"); + } + packedRets.push_back(mat); + } else { + // 1 : 1 type conversion. + packedRets.push_back(mappedValue.front()); + } + } + + rewriter.replaceOp(op, packedRets); + return success(); + } +}; + +class InlineIndexTreeOp : public OpConversionPattern{ + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(indexTree::IndexTreeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Block& body = op.getRegion().front(); + Operation* terminator = body.getTerminator(); + rewriter.replaceOp(op, terminator->getOperands()); + rewriter.mergeBlockBefore(&body, op); + return success(); + } +}; + +class InlineIndexTreeYieldOp : public OpConversionPattern{ + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(indexTree::YieldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; +} //namespace + +void mlir::indexTree::populateIndexTreeInliningPatterns(MLIRContext *context, RewritePatternSet &patterns) { + patterns.add(context); +} + +void mlir::indexTree::populateIndexTreeTypeConversionPatterns(MLIRContext *context, RewritePatternSet &patterns, TypeConverter &typeConverter, ConversionTarget& target) { + target.addDynamicallyLegalOp([&](Operation *op) { + return typeConverter.isLegal(op->getResultTypes()); + }); + 
target.addDynamicallyLegalOp([&](indexTree::YieldOp op) { + return typeConverter.isLegal(op->getOperandTypes()); + }); + + patterns.add(typeConverter, context); +} + +struct IndexTreeInliningPass + : public PassWrapper> +{ + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IndexTreeInliningPass) + + void runOnOperation() override + { + // Convert the rest of the index tree dialect to SCF + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + + mlir::RewritePatternSet patterns(&getContext()); + mlir::indexTree::populateIndexTreeInliningPatterns(&getContext(), patterns); + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +/// Lower sparse tensor algebra operation to loops +std::unique_ptr mlir::comet::createIndexTreeInliningPass() +{ + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp b/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp index 60fec149..7577cb3f 100644 --- a/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp +++ b/lib/Conversion/IndexTreeToSCF/IndexTreeToSCF.cpp @@ -32,15 +32,22 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Pass/Pass.h" #include "mlir/IR/Dominance.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ScopedPrinter.h" #include "llvm/ADT/StringSet.h" +#include "llvm/ADT/iterator_range.h" +#include 
"llvm/ADT/SetVector.h" #include #include #include @@ -61,13 +68,13 @@ using namespace mlir::tensorAlgebra; using llvm::SmallVector; using llvm::StringRef; +using llvm::SmallDenseMap; #define DEBUG_TYPE "lowering-it-to-scf" // *********** For debug purpose *********// //#define COMET_DEBUG_MODE #include "comet/Utils/debug.h" -#undef COMET_DEBUG_MODE // *********** For debug purpose *********// namespace comet @@ -205,4315 +212,907 @@ namespace } }; - class OpsTree - { - /// private: - public: - std::vector forOps; /// The (nested) for loops - std::vector accessIdx; /// The coordinate of accessing that dimension - std::vector symbolicForOps; /// For-loops in symbolic phase (if necessary) - std::vector symbolicAccessIdx; /// The accessing index for that for-loop in symbolic phase (if necessary) - /// std::vector cmptOps; /// The computation ops (no used?) - std::vector children; - OpsTree *parent; - int id; /// the index in the ws_op array. The order is the DFS order. - - std::vector symbolicForOps_debug; - std::vector symbolicAccessIdx_debug; - - public: - OpsTree() {} - - OpsTree(std::vector &forOps, std::vector &accessIdx, - OpsTree *parent, int id) : forOps(forOps), accessIdx(accessIdx), parent(parent), id(id) - { - } - - OpsTree(std::vector &forOps, std::vector &accessIdx, - OpsTree *parent) : forOps(forOps), accessIdx(accessIdx), parent(parent) - { - } - - ~OpsTree() {} - - void addChild(OpsTree *tree) - { /// const T& node - this->children.push_back(tree); - } - - std::vector &getForOps() - { - return this->forOps; - } - - OpsTree *getParent() - { - return this->parent; - } - - void setForOps(std::vector &forOps) - { - this->forOps = forOps; - } - - std::vector &getChildren() - { - return this->children; - } - }; - - /// ----------------- /// - /// struct to pass symbolic phase information to the numeric phase - /// ----------------- /// - struct SymbolicInfo - { - bool are_inputs_sparse = false; /// If both inputs are sparse. 
It is true for SpGEMM and sparse elementwise operations. - /// All other members are only used when are_inputs_sparse is true. - - bool has_symbolic_phase = false; /// If current generated code should have a symbolic phase. - /// Currently, if are_inputs_parse == true; then has_symbolic_phase = true; - - Value mtxC_num_rows = nullptr; - Value mtxC_num_cols = nullptr; - - Value mtxC_rowptr = nullptr; /// Output C's rowptr array when do C = A * B and they are all sparse - /// %alloc_100 = memref.alloc(%43) : memref - Value mtxC_col = nullptr; /// Output C's col array when do C = A * B and they are all sparse - /// %alloc_104 = memref.alloc(%44) : memref - Value mtxC_val = nullptr; /// Output C's val array when do C = A * B and they are all sparse - /// %alloc_108 = memref.alloc(%44) : memref - Value mtxC_val_size = nullptr; /// Output C' correct number of non-zeros or C_val_size (ready after symbolic phase) - - Value mtxC_rowptr_size = nullptr; /// rowptr array's size, which is number of columns plus one (num_col + 1). - - Value row_offset = nullptr; /// In Numeric Phase, row_offset is the insertion location in the C_col and C_val. - - Value mtxC = nullptr; /// The sparse tensor - /// It is %55 below. - }; - - /// ----------------- /// - /// Auxiliary structures for the numeric phase - /// ----------------- /// - struct NumericInfo - { - Value ws_bitmap = nullptr; /// workspace's bitmap to tell if a column ID is visited. - Value ws_bitmap_valueAccessIdx = nullptr; /// value access index for the workspace bitmap. - - Value mask_array = nullptr; /// the intermediate dense vector for a row of the mask. - }; - - /// ----------------- /// - /// Remove an operantion's user who is a memref.store - /// This is very ad-hoc, just to avoid segmentation fault for old very large C.val array and C.col array. 
- /// ----------------- /// - void removeMemrefStoreUser(Value &opd) - { - { - comet_vdump(opd); - } - std::vector users; - for (Operation *user : opd.getUsers()) - { - if (isa(user)) - { - users.push_back(user); - { - comet_pdump(user); - } - } - } - for (Operation *user : users) - { - user->erase(); - } - } - - /// ----------------- /// - /// Find all users of the old_Value, and replace those users' corresponding operand to new_Value. For example, - /// "ta.print"(%old_Value) => "ta.print"(%new_Value) - /// ----------------- /// - void replaceOldValueToNewValue(Value &old_Value, - Value &new_Value) - { - { - comet_vdump(old_Value); - comet_vdump(new_Value); - } - - /// Traverse each user of new_Value - std::vector users; - for (Operation *user : old_Value.getUsers()) - { - users.push_back(user); - } - DominanceInfo domInfo(new_Value.getDefiningOp()); /// To check dominance - for (Operation *user : users) - { - { - comet_debug() << "before replace operand.\n"; - comet_pdump(user); - } - /// Check if new_Value dominates the user - if (!domInfo.dominates(new_Value, user)) - { - continue; - } - uint64_t op_i = 0; - for (Value op : user->getOperands()) - { - /// Find the mtxC in the user's operands - if (op.getDefiningOp() == old_Value.getDefiningOp()) - { - /// Replace the old sparse tensor to the new one - user->setOperand(op_i, new_Value); - { - comet_debug() << "after replace operand.\n"; - comet_pdump(user); - } - } - ++op_i; - } - } - } - - /// ----------------- /// - /// Add declaration of the function comet_index_func; - /// ----------------- /// - void declareSortFunc(ModuleOp &module, - MLIRContext *ctx, - Location loc) - { - IndexType indexType = IndexType::get(ctx); - - /// Declare comet_sort_index() - auto sort_index_func = FunctionType::get(ctx, - {UnrankedMemRefType::get(indexType, 0), indexType, indexType} /* inputs */, {} /* return */); - std::string func_name = "comet_sort_index"; - if (!hasFuncDeclaration(module, func_name /* func name */)) - { - 
func::FuncOp func_declare = func::FuncOp::create(loc, - func_name, - sort_index_func, - ArrayRef{}); - func_declare.setPrivate(); - module.push_back(func_declare); - } - } - - /// Get mask_rowptr, mask_col, and mask_val arrays. - /// ----------------- /// - /// mask_tensor = %50 - /// mask_rowptr = %alloc_99 - /// mask_col = %alloc_104 - /// mask_val = %alloc_109 - /// ----------------- /// - /// %45 = bufferization.to_tensor %alloc_99 : memref - /// %46 = bufferization.to_tensor %alloc_104 : memref - /// %49 = bufferization.to_tensor %alloc_109 : memref - /// %50 = ta.sptensor_construct(%41, %42, %43, %44, %45, %46, %47, %48, %49, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38) {tensor_rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, index, index, index, index, index, index, index, index, index, index, index) -> (!ta.sptensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, index, index, index, index, index, index, index, index, index, index, index>) - /// ----------------- /// - void getMaskSparseTensorInfo(MaskingInfo &maskingInfo /* contents updated after call*/) - { - Value &mask_tensor = maskingInfo.mask_tensor; - - /// A2pos - Value mask_rowtpr_buff = mask_tensor.getDefiningOp()->getOperand(CSR_A2POS); /// 2 - maskingInfo.mask_rowptr = mask_rowtpr_buff.getDefiningOp()->getOperand(0); - - /// A2crd - Value mask_col_buff = mask_tensor.getDefiningOp()->getOperand(CSR_A2CRD); /// 3 - maskingInfo.mask_col = mask_col_buff.getDefiningOp()->getOperand(0); - - /// Aval - Value mask_val_buff = mask_tensor.getDefiningOp()->getOperand(CSR_AVAL); /// 4 - maskingInfo.mask_val = mask_val_buff.getDefiningOp()->getOperand(0); - - { - comet_vdump(mask_tensor); - comet_vdump(maskingInfo.mask_rowptr); - comet_vdump(maskingInfo.mask_col); - comet_vdump(maskingInfo.mask_val); - } - } - - unsigned int findIndexInVector_OpsTree(std::vector vec, OpsTree *e) - { - /// Check if element e exists in vector - auto it = 
std::find(vec.begin(), vec.end(), e); - - /// It accepts a range and an element to search in the given range. If element is found then it returns an iterator to the first element in the given range that’s equal to given element, else it returns an end of the list. - unsigned int ret = 0; - if (it != vec.end()) - { - /// Get index of element from iterator - ret = std::distance(vec.begin(), it); - } - else - { - ret = vec.size(); - } - return ret; - } - - Value findCorrespondingAlloc(Value &iOp) - { - comet_debug() << "findCorrespondingAlloc for loop upper bound\n"; - comet_vdump(iOp); - auto init_alloc = iOp.getDefiningOp()->getOperand(0); - comet_vdump(init_alloc); - - while (true) - { - if (isa(init_alloc.getDefiningOp())) - { - if (init_alloc.getType().dyn_cast().getDimSize(0) != ShapedType::kDynamic) - { - return init_alloc; - } - } - if (init_alloc.getDefiningOp()->getNumOperands() > 0) - { - init_alloc = init_alloc.getDefiningOp()->getOperand(0); - } - else - { - /// Alloc related to another sparse tensor construct such as coming from sparse transpose - comet_debug() << "Return alloc op - comes from sptensor_construct\n"; - comet_vdump(init_alloc); - return init_alloc; - } - } - } - - /// Get allocs for a tensor (sparse or dense) - std::vector getAllocs(Value &tensor) - { - comet_vdump(tensor); - std::vector allocs; - if (tensor.getType().isa()) - { /// Dense tensor - comet_debug() << " getAllocs() - it is dense\n"; - if (isa(tensor.getDefiningOp())) - { - Operation *tensorload = cast(tensor.getDefiningOp()); - auto alloc_op = cast(tensorload->getOperand(0).getDefiningOp()); - comet_vdump(alloc_op); - allocs.push_back(alloc_op); - } - else - { - for (unsigned int i = 0; i < tensor.getDefiningOp()->getNumOperands(); i++) - { - if (isa(tensor.getDefiningOp()->getOperand(i).getDefiningOp())) - { - Operation *tensorload = cast(tensor.getDefiningOp()->getOperand(i).getDefiningOp()); - auto alloc_op = cast(tensorload->getOperand(0).getDefiningOp()); - 
comet_vdump(alloc_op); - allocs.push_back(alloc_op); - } - } - } - } - else if (tensor.getType().isa()) - { /// nSparse tensor - comet_debug() << " getAllocs() - it is sparse\n"; - auto defop = tensor.getDefiningOp(); - - for (unsigned int n = 0; n < defop.getTotalDimArrayCount(); n++) - { - comet_vdump(defop.getIndices()[n]); - Operation *tensorload = defop.getIndices()[n].getDefiningOp(); - auto alloc_op = cast(tensorload->getOperand(0).getDefiningOp()); - allocs.push_back(alloc_op); - comet_vdump(alloc_op); - } - } - else if (dyn_cast(tensor.getDefiningOp())) - { /// ConstantOp - allocs.push_back(tensor); - } - return allocs; - } - - std::vector> getAllAllocs(std::vector &tensors) - { - std::vector> allAllocs(tensors.size()); - for (unsigned int i = 0; i < tensors.size(); i++) - { - allAllocs[i] = getAllocs(tensors[i]); - } - return allAllocs; - } - - /// while until parent == null - void getAncestorsOps(OpsTree *opstree, std::vector &ret) - { - - while (opstree->parent != nullptr) - { - ret.push_back(opstree->parent); - opstree = opstree->parent; - } - } - - /// In genForOps, set Insertion Point for numeric loops. - void setInsertionPointInNumericLoops(OpBuilder &builder, - std::vector &ancestorsOps, - OpsTree *opstree) - { - /// If parent is for loop, insert into the body, How to get end of body? 
- if (ancestorsOps.size() > 0) - { - /// ancestorsOps[0] stores the closest parent - scf::ForOp parent_forop = nullptr; - comet_debug() << "\n"; - std::vector parent_forops = ancestorsOps[0]->forOps; - comet_debug() << " parent_forops.size(): " << parent_forops.size() << " \n"; - - parent_forop = parent_forops[parent_forops.size() - 1]; - - comet_debug() << " reset the insertion point\n"; - comet_vdump(parent_forop); - - unsigned int order = findIndexInVector_OpsTree(ancestorsOps[0]->getChildren(), opstree); - comet_debug() << " order: " << order << "\n"; - if (order == ancestorsOps[0]->getChildren().size()) - { - llvm::errs() << __FILE__ << ":" << __LINE__ << "ERROR: Not belong to parent's children\n"; - } - else - { - /// Get the children of the parent_forop - comet_debug() << " number of children: " << parent_forops.size() << "\n"; - if (order == 0) - { - /// builder.setInsertionPointToStart(parent_forop.getBody()); - comet_debug() << "Insertion point order == 0\n"; - builder.setInsertionPoint(parent_forop.getBody()->getTerminator()); - } - else - { - comet_debug() << "\n"; - std::vector brother_forops = ancestorsOps[0]->getChildren()[order - 1]->forOps; - if (brother_forops.size() > 0) - { - comet_debug() << " brother_forops.size(): " << brother_forops.size() << "\n"; - if (opstree->forOps.size() == 0) - { - comet_debug() << "\n"; - comet_vdump(brother_forops[0]); - comet_debug() << "Insertion point (brother_forops.size() > 0 && opstree->forOps.size() == 0)\n"; - builder.setInsertionPointAfter(brother_forops[0]); - } - else - { /// current opstree contains loops, insert in the body of the loops - comet_debug() << " -------- current opstree contain loops --- impossible\n"; - comet_debug() << "Insertion point (brother_forops.size() > 0 && opstree->forOps.size() != 0)\n"; - builder.setInsertionPoint(opstree->forOps[opstree->forOps.size() - 1].getBody()->getTerminator()); - } - } - } - } - comet_debug() << " reset the insertion point\n"; - } - } - - /// 
----------------- /// - /// In genForOps, generate for-loop for a indexOp node if the index is corresponding to Format "D" - /// ----------------- /// - void genForOpFormat_D(OpBuilder &builder, - Location &loc, - Value &tensor, - unsigned int id, - unsigned int i, - std::vector> &allAllocs, - scf::ForOp &forLoop /* output */, - Value &accessIndex /* output */) - { - /// Value upperBound; - /// Value lowerBound; - /// Check which tensor is sparse, which is dense; - /// Since this function only handles mixed sparse/dense, then "D" only occurs in one tensor - /// Both the dense and sparse tensor contain the dim size; But they are different. Use one. - int64_t maxSize = 0; - comet_debug() << " "; - comet_vdump(tensor); - if (tensor.getType().isa()) - { /// Dense tensor - Value upperBound; - auto tensorTy = tensor.getType().cast(); - maxSize = tensorTy.getDimSize(id); - - /// Check if dynamic size - /// Check upperBoundsize - if (maxSize == ShapedType::kDynamic) - { - /// Find defOp allocOp, check the parameter - comet_debug() << " Dynamic size "; - comet_pdump(tensor.getDefiningOp()); /// tensor_load - comet_vdump(tensor.getDefiningOp()->getOperand(0)); /// alloc - /// Check the order of the current dynamic size - auto rhs1_alloc = tensor.getDefiningOp()->getOperand(0); - std::vector dyn_dims_vec; - for (unsigned i = 0; i < tensorTy.getRank(); i++) - { - if (tensorTy.isDynamicDim(i)) - { - dyn_dims_vec.push_back(i); - } - } /// ? x ? x 20 x ? 
- auto rhs1_loc_dyn = findIndexInVector(dyn_dims_vec, id); - comet_vdump(rhs1_alloc.getDefiningOp()->getOperand(rhs1_loc_dyn)); - - upperBound = rhs1_alloc.getDefiningOp()->getOperand(rhs1_loc_dyn); - } - else - { - upperBound = builder.create(loc, maxSize); - } - - Value lowerBound = builder.create(loc, 0); - auto step = builder.create(loc, 1); - auto loop = builder.create(loc, lowerBound, upperBound, step); - - comet_debug() << " D Loop\n"; - comet_vdump(loop); - - /// opstree->forOps.push_back(loop); - /// opstree->accessIdx.push_back(loop.getInductionVar()); - forLoop = loop; - accessIndex = loop.getInductionVar(); - } - else if (tensor.getType().isa()) - { - comet_debug() << " \n"; - comet_pdump(tensor.getDefiningOp()); - if (indexTree::IndexTreeComputeRHSOp rhsop = dyn_cast( - tensor.getDefiningOp())) - { - comet_debug() << " \n"; - } - } - else if (tensor.getType().cast()) - { - comet_debug() << "cur_idx is in tensor " << i << "\n"; - - Value lowerBound = builder.create(loc, 0); - auto index_0 = builder.create(loc, 0); - std::vector upper_indices = {index_0}; - Value upperBound = builder.create(loc, allAllocs[i][4 * id], upper_indices); - comet_vdump(allAllocs[i][4 * id]); - auto step = builder.create(loc, 1); - auto loop = builder.create(loc, lowerBound, upperBound, step); - - comet_debug() << " D Loop\n"; - comet_vdump(loop); - - forLoop = loop; - accessIndex = loop.getInductionVar(); - } - /// } - } - - /// ----------------- /// - /// In genForOps, generate for-loop for a indexOp node if the index is corresponding to Format "CU" - /// ----------------- /// - void genForOpFormat_CU(OpBuilder &builder, - Location &loc, - OpsTree *opstree, - Value &tensor, - unsigned int id, - unsigned int i, - std::vector> &allAllocs, - scf::ForOp &parent_forop, - Value &parent_accessIdx, - scf::ForOp &forLoop /* output */, - Value &accessIndex /* output */) - { - /// Generate for(int m = pos[0]; m < pos[1]; m++){int i = crd[m];} - /// if i = 0, index is [0,1] - /// if 
parent loop and child loop is accessing the same sparse tensor (CSF), index is [m, m+1], m is the nearest loop induction variable - /// Otherwise, the m comes from load operation of the input sparse tensor such as - /// j = crd[i]; - /// for (int m = pos[j]; m < pos[j+1]; m++) - - comet_debug() << " format is CU id: " << id << "\n"; - comet_debug() << " Tensor: \n"; - comet_vdump(tensor); - Value index_lower; - Value index_upper; - if (tensor.getType().cast()) - { - comet_debug() << " Tensor type is sparse\n"; - if (id == 0) - { /// The first index in the tensor - index_lower = builder.create(loc, 0); - comet_vdump(index_lower); - } - else - { - if (opstree->parent != nullptr) - { - comet_debug() << " opstree->parent is not NULL\n"; - comet_debug() << " parent forop\n"; - comet_vdump(parent_forop); - auto parent_UpperBound = parent_forop.getUpperBound(); - comet_debug() << " parent upperBound:\n"; - comet_vdump(parent_UpperBound); - - /// check if parent's and child's upper bounds come from the same sparse tensor - auto alloc_parent_bounds = findCorrespondingAlloc(parent_UpperBound); - comet_debug() << " parent upperBound alloc\n"; - comet_vdump(alloc_parent_bounds); - - comet_debug() << " child upperBound:\n"; - comet_vdump(allAllocs[i][4 * id]); - auto alloc_child_bounds = findCorrespondingAlloc(allAllocs[i][4 * id]); - comet_debug() << " child upperBound alloc\n"; - comet_vdump(alloc_child_bounds); - - if (alloc_child_bounds == alloc_parent_bounds) /// m is the nearest loop induction variable - { - comet_debug() << " THESAME: Parent and Child has the same alloc\n"; - index_lower = parent_forop.getInductionVar(); - } - else - { /// m comes from the load - comet_debug() << " DIFFERENT:Parent and Child has the different alloc\n"; - comet_vdump(alloc_parent_bounds); - comet_vdump(alloc_child_bounds); - index_lower = parent_accessIdx; - } - } - else - llvm::errs() << "ERROR: Unexpected condition\n"; - } - - comet_debug() << " index_lower:"; - 
comet_vdump(index_lower); - Value const_index_1 = builder.create(loc, 1); - comet_vdump(const_index_1); - index_upper = builder.create(loc, index_lower, const_index_1); - comet_debug() << " AddIOps (index_upper):"; - comet_vdump(index_upper); - - std::vector lower_indices = {index_lower}; - Value lowerBound = builder.create(loc, allAllocs[i][4 * id], lower_indices); /// 2 * id - - std::vector upper_indices = {index_upper}; - Value upperBound = builder.create(loc, allAllocs[i][4 * id], upper_indices); /// 2 * id - auto step = builder.create(loc, 1); - auto loop = builder.create(loc, lowerBound, upperBound, step); - - comet_debug() << " CU Loop\n"; - comet_vdump(loop); - - builder.setInsertionPoint(loop.getBody()->getTerminator()); - - std::vector crd_indices = {loop.getInductionVar()}; - auto get_index = builder.create(loc, allAllocs[i][4 * id + 1], crd_indices); - - comet_debug() << "CU loop generated\n"; - comet_vdump(loop); - forLoop = loop; - accessIndex = get_index; - } - } - - /// ----------------- /// - /// In genForOps, generate for-loop for a indexOp node if the index is corresponding to Format "CN" - /// ----------------- /// - void genForOpFormat_CN(OpBuilder &builder, - Location &loc, - Value &tensor, - unsigned int id, - unsigned int i, - std::vector> &allAllocs, - scf::ForOp &forLoop /* output */, - Value &accessIndex /* output */) - { - /// Generate for(int m = pos[0]; m < pos[1]; m++){int i = crd[m];} - if (tensor.getType().cast()) - { - auto index_0 = builder.create(loc, 0); - std::vector lower_indices = {index_0}; - Value lowerBound = builder.create(loc, allAllocs[i][4 * id], lower_indices); - - auto index_1 = builder.create(loc, 1); - std::vector upper_indices = {index_1}; - Value upperBound = builder.create(loc, allAllocs[i][4 * id], upper_indices); - auto step = builder.create(loc, 1); - auto loop = builder.create(loc, lowerBound, upperBound, step); - - comet_debug() << " CN Loop\n"; - comet_vdump(loop); - - 
builder.setInsertionPoint(loop.getBody()->getTerminator()); - - std::vector crd_indices = {loop.getInductionVar()}; - auto get_index = builder.create(loc, allAllocs[i][4 * id + 1], crd_indices); - - forLoop = loop; - accessIndex = get_index; - } - } - - /// ----------------- /// - /// In genForOps, generate for-loop for a indexOp node if the index is corresponding to Format "S" - /// ----------------- /// - void genForOpFormat_S(OpBuilder &builder, - Location &loc, - OpsTree *opstree, - Value &tensor, - unsigned int id, - unsigned int i, - std::vector> &allAllocs, - std::vector &opstree_forops, - scf::ForOp &parent_forop, - scf::ForOp &forLoop /* output */, - Value &accessIndex /* output */) - { - /// Currently supported formats, Singleton is not the format of first dimension - /// and it doesn't produce a loop - /// Generate: int j = A2crd[m]; - - if (tensor.getType().cast()) - { - comet_debug() << "cur_idx is in tensor " << i << "\n"; - /// Accesing the last level loop info - scf::ForOp last_forop; - if (opstree_forops.size() > 0) - { /// current node contain at least 1 level loop - last_forop = opstree_forops.back(); - } - else - { - if (opstree->parent != nullptr) - last_forop = parent_forop; - } - - std::vector crd_indices = {last_forop.getInductionVar()}; - auto get_index = builder.create(loc, allAllocs[i][4 * id + 1], crd_indices); - - /// Adding one iteration loop to provide consistency with the corresponding index tree. - /// Index tree includes an index node for the dimension but "S" format for this dimension - /// doesn't produce a loop. - Value lowerBound = builder.create(loc, 0); - Value upperBound = builder.create(loc, 1); - auto step = builder.create(loc, 1); - auto loop = builder.create(loc, lowerBound, upperBound, step); - comet_debug() << " S Loop\n"; - comet_vdump(loop); - forLoop = loop; - accessIndex = get_index; - } - else - { - llvm::errs() << "Not supported tensor type\n"; - } - } - - /// In genForOps, set Insertion Point for symbolic loops. 
- void setInsertionPointInSymbolicLoops(OpBuilder &builder, - std::vector &ancestorsOps, - OpsTree *opstree) - { - /// If parent is for loop, insert into the body, How to get end of body? - if (ancestorsOps.size() > 0) - { - /// ancestorsOps[0] stores the closest parent - scf::ForOp parent_forop = nullptr; - comet_debug() << "\n"; - std::vector parent_forops = ancestorsOps[0]->symbolicForOps; - comet_debug() << " parent_forops.size(): " << parent_forops.size() << " \n"; - - parent_forop = parent_forops.back(); - - comet_debug() << "symbolic: reset the insertion point\n"; - comet_vdump(parent_forop); - - unsigned int order = findIndexInVector_OpsTree(ancestorsOps[0]->getChildren(), opstree); - comet_debug() << " order: " << order << "\n"; - if (order == ancestorsOps[0]->getChildren().size()) - { - llvm::errs() << __LINE__ << "Not belong to parent's children\n"; - } - else - { - /// Get the children of the parent_forop - comet_debug() << " number of children: " << parent_forops.size() << "\n"; - if (order == 0) - { - /// builder.setInsertionPointToStart(parent_forop.getBody()); - comet_debug() << "Insertion point order == 0\n"; - builder.setInsertionPoint(parent_forop.getBody()->getTerminator()); - } - else - { - comet_debug() << "\n"; - std::vector brother_forops = ancestorsOps[0]->getChildren()[order - 1]->symbolicForOps; - if (brother_forops.size() > 0) - { - comet_debug() << " brother_forops.size(): " << brother_forops.size() << "\n"; - if (opstree->symbolicForOps.size() == 0) - { - comet_debug() << "\n"; - comet_vdump(brother_forops[0]); - comet_debug() << "Insertion point (brother_forops.size() > 0 && opstree->symbolicForOps.size() == 0)\n"; - builder.setInsertionPointAfter(brother_forops[0]); - } - else - { /// current opstree contains loops, insert in the body of the loops - comet_debug() << " -------- current opstree contain loops --- impossible\n"; - comet_debug() << "Insertion point (brother_forops.size() > 0 && opstree->symbolicForOps.size() != 0)\n"; - 
builder.setInsertionPoint(opstree->symbolicForOps.back().getBody()->getTerminator()); - } - } - else - { - comet_debug() << "brothers have no for-loops. Insert at the end of parent's for-loop body.\n"; - /// builder.setInsertionPointToEnd(parent_forop.getBody()); /// This doesn't work because it inserts even after the scf.yield, which is wrong. - builder.setInsertionPoint(parent_forop.getBody()->getTerminator()); - } - } - } - comet_debug() << " reset the insertion point\n"; - } - } - - /// In genCmptOps, generate code for a compute node with workspace transformation. - /// For example, A = 0.0 . A could be scalar or vector. - void genWorkspaceCmptOpInitialAssignment(OpBuilder &builder, - Location &loc, - int lhs_loc, - ConstantOp &cstop, - std::vector &nested_forops, - std::vector> &tensors_lhs_Allocs, - std::vector> &main_tensors_all_Allocs, - bool use_dynamic_init, - SymbolicInfo &symbolicInfo) - { - - /// Generate Store 1.0, A[...] this op - /// this case: allPerms[0] is empty, allFormats[0] is empty - comet_vdump(cstop); - comet_debug() << " cstop.getValue(): " << cstop.getValue() << "\n"; - comet_vdump(main_tensors_all_Allocs[lhs_loc].back()); - comet_debug() << " tensors_lhs_Allocs.size(): " << tensors_lhs_Allocs.size() << "\n"; - { - comet_vdump(nested_forops[0]); - } - Value local_accessIdx = nested_forops[0].getInductionVar(); - insertInitialize(loc, - cstop, - main_tensors_all_Allocs[lhs_loc].back(), - local_accessIdx, - builder, - use_dynamic_init, - symbolicInfo.mtxC_rowptr /* dynamic_init */); - } - - /// In genCmptOps, generate code for a compute node that copy a sparse input row into a dense vector. 
- void genWorkspaceCmptOpScatterInputToWorkspace(OpBuilder &builder, - Location &loc, - int main_tensor_nums, - std::vector> &main_tensors_all_Allocs, - std::vector> &allValueAccessIdx) - { - - std::vector allLoads(main_tensor_nums); - for (auto m = 0; m < main_tensor_nums; m++) - { - Value s = builder.create(loc, - main_tensors_all_Allocs[m][main_tensors_all_Allocs[m].size() - 1], - allValueAccessIdx[m]); - allLoads[m] = s; - comet_debug() << " "; - comet_vdump(s); - } - comet_debug() << " allLoads.size(): " << allLoads.size() << "\n"; - - builder.create(loc, allLoads[0], - main_tensors_all_Allocs[1][main_tensors_all_Allocs[1].size() - 1], - allValueAccessIdx[1]); - } - - /// Generate scf.for op for indices - /// The index is the "idx"th index of "tensor" - void genForOps(std::vector &tensors, - std::vector &ids, - std::vector &formats, - indexTree::IndexTreeOp rootOp, - OpBuilder &builder, - OpsTree *opstree, - SymbolicInfo &symbolicInfo) - { - comet_debug() << " genForOps indexTreeOp\n"; - comet_vdump(rootOp); - Location loc = rootOp.getLoc(); - /// The insertion location should be "the end of the body of parent loop" - std::vector ancestorsOps; - getAncestorsOps(opstree, ancestorsOps); - comet_debug() << " genForOps ancestorsOps.size(): " << ancestorsOps.size() << "\n"; - for (unsigned int i = 0; i < ancestorsOps.size(); i++) - { - comet_debug() << " ancestorsOps[" << i << "]->forOps.size(): " << ancestorsOps[i]->forOps.size() - << ", ancestorsOps->id: " - << ancestorsOps[i]->id << "\n"; - } - comet_debug() << "Tensor size: " << tensors.size() << "\n"; - std::vector> allAllocs = getAllAllocs(tensors); - - comet_debug() << "Tensors:\n"; - for (unsigned int i = 0; i < tensors.size(); i++) - { - comet_vdump(tensors[i]); - } - - /// ----------------- /// - /// Set insertion point - /// ----------------- /// - setInsertionPointInNumericLoops(builder, - ancestorsOps, - opstree); - - for (unsigned int i = 0; i < tensors.size(); i++) - { - if (i > 0) - { - /// 
insertion point: the body of the previous i's loop body - comet_debug() << " -------- current opstree contain loops\n"; - builder.setInsertionPoint(opstree->forOps.back().getBody()->getTerminator()); - } - - Value &tensor = tensors[i]; - std::string format = formats[i]; - unsigned int id = ids[i]; - - comet_debug() << " current index format: " << format << "\n"; - if (format.compare(0, 1, "D") == 0) - { - /// Symbolic Phase - if (symbolicInfo.has_symbolic_phase) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertions point - setInsertionPointInSymbolicLoops(builder, - ancestorsOps, - opstree); - - scf::ForOp forLoop; - Value accessIndex; - genForOpFormat_D(builder, - loc, - tensor, - id, - i, - allAllocs, - forLoop /* output */, - accessIndex /* output */); - opstree->symbolicForOps.push_back(forLoop); - opstree->symbolicAccessIdx.push_back(accessIndex); - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - /// Check which tensor is sparse, which is dense; - /// Since this function only handles mixed sparse/dense, then "D" only occurs in one tensor - /// Both the dense and sparse tensor contain the dim size; But they are different. Use one. 
- scf::ForOp forLoop; - Value accessIndex; - genForOpFormat_D(builder, - loc, - tensor, - id, - i, - allAllocs, - forLoop /* output */, - accessIndex /* output */); - opstree->forOps.push_back(forLoop); - opstree->accessIdx.push_back(accessIndex); - } - /// mix sparse dense tensor contraction, only one sparse tensor - else if (format.compare(0, 2, "CU") == 0) - { - /// Symbolic Phase - if (symbolicInfo.has_symbolic_phase) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertions point - setInsertionPointInSymbolicLoops(builder, - ancestorsOps, - opstree); - - scf::ForOp forLoop; - Value accessIndex; - scf::ForOp parent_forop; - Value parent_accessIdx; - if (nullptr != opstree->parent) - { - parent_forop = opstree->parent->symbolicForOps.back(); - parent_accessIdx = opstree->parent->symbolicAccessIdx.back(); - } - genForOpFormat_CU(builder, - loc, - opstree, - tensor, - id, - i, - allAllocs, - parent_forop, - parent_accessIdx, - forLoop /* output */, - accessIndex /* output */); - opstree->symbolicForOps.push_back(forLoop); - opstree->symbolicAccessIdx.push_back(accessIndex); - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - /// Generate for(int m = pos[0]; m < pos[1]; m++){int i = crd[m];} - /// if i = 0, index is [0,1] - /// if parent loop and child loop is accessing the same sparse tensor (CSF), index is [m, m+1], m is the nearest loop induction variable - /// Otherwise, the m comes from load operation of the input sparse tensor such as - /// j = crd[i]; - /// for (int m = pos[j]; m < pos[j+1]; m++) - - scf::ForOp forLoop; - Value accessIndex; - scf::ForOp parent_forop; - Value parent_accessIdx; - if (nullptr != opstree->parent) - { - parent_forop = opstree->parent->forOps.back(); - parent_accessIdx = opstree->parent->accessIdx.back(); - } - genForOpFormat_CU(builder, - loc, - opstree, - tensor, - id, - i, - allAllocs, - parent_forop, - 
parent_accessIdx, - forLoop /* output */, - accessIndex /* output */); - opstree->forOps.push_back(forLoop); - opstree->accessIdx.push_back(accessIndex); - } - else if (format.compare(0, 2, "CN") == 0) - { - /// Generate for(int m = pos[0]; m < pos[1]; m++){int i = crd[m];} - scf::ForOp forLoop; - Value accessIndex; - genForOpFormat_CN(builder, - loc, - tensor, - id, - i, - allAllocs, - forLoop /* output */, - accessIndex /* output */); - opstree->forOps.push_back(forLoop); - opstree->accessIdx.push_back(accessIndex); - } - else if (format.compare(0, 1, "S") == 0) - { - /// Currently supported formats, Singleton is not the format of first dimension - /// and it doesn't produce a loop - /// Generate: int j = A2crd[m]; - scf::ForOp forLoop; - Value accessIndex; - std::vector &opstree_forops = opstree->forOps; - scf::ForOp parent_forop; - if (nullptr != opstree->parent) - { - parent_forop = opstree->parent->forOps.back(); - } - genForOpFormat_S(builder, - loc, - opstree, - tensor, - id, - i, - allAllocs, - opstree_forops, - parent_forop, - forLoop /* output */, - accessIndex /* output */); - opstree->forOps.push_back(forLoop); - opstree->accessIdx.push_back(accessIndex); - } - else - { - llvm::errs() << "Not supported format: " << format << "\n"; - } - - comet_debug() << "finish generate loops for current index format: " << format << "\n"; - } - } - - Value getSemiringSecondVal(OpBuilder &builder, Location &loc, - llvm::StringRef &semiringSecond, Value &Input0, Value &Input1, - bool compressedWorkspace) - { - - Value elementWiseResult; - if (semiringSecond == "times") - { - elementWiseResult = builder.create(loc, Input0, Input1); - } - else if (semiringSecond == "first") - { - elementWiseResult = Input0; - } - else if (semiringSecond == "second") - { - elementWiseResult = Input1; - } - else if (semiringSecond == "atan2") - { - elementWiseResult = builder.create(loc, Input0, Input1); - } - else if (semiringSecond == "div") - { - elementWiseResult = builder.create(loc, 
Input0, Input1); - } - else if (semiringSecond == "eq") - { - elementWiseResult = builder.create(loc, CmpFPredicate::OEQ, Input0, Input1); - } - else if (semiringSecond == "ge") - { - elementWiseResult = builder.create(loc, CmpFPredicate::OGE, Input0, Input1); - } - else if (semiringSecond == "gt") - { - elementWiseResult = builder.create(loc, CmpFPredicate::OGT, Input0, Input1); - } - else if (semiringSecond == "le") - { - elementWiseResult = builder.create(loc, CmpFPredicate::OLE, Input0, Input1); - } - else if (semiringSecond == "lt") - { - elementWiseResult = builder.create(loc, CmpFPredicate::OLT, Input0, Input1); - } - else if (semiringSecond == "land") - { - /// land requires integer type input - llvm::errs() << "Not supported semiring operator (only works for int datatypes): " - << "land" - << "\n"; - /// we should not proceed forward from this point to avoid faulty behavior. - exit(1); - } - else if (semiringSecond == "lor") - { - /// lor requires integer type input - llvm::errs() << "Not supported semiring operator (only works for int datatypes): " - << "lor" - << "\n"; - /// we should not proceed forward from this point to avoid faulty behavior. 
- exit(1); - } - else if (semiringSecond == "lxor") - { - /// lxor requires integer type input - llvm::errs() << "Not supported semiring operator: " - << "lxor" - << "\n"; - } - else if (semiringSecond == "minxy") - { - Value cmp = builder.create(loc, CmpFPredicate::OLT, Input0, Input1); - elementWiseResult = builder.create(loc, cmp, Input0, Input1); - } - else if (semiringSecond == "max") - { - Value cmp = builder.create(loc, CmpFPredicate::OGT, Input0, Input1); - elementWiseResult = builder.create(loc, cmp, Input0, Input1); - } - else if (semiringSecond == "ne") - { - elementWiseResult = builder.create(loc, CmpFPredicate::ONE, Input0, Input1); - } - else if (semiringSecond == "minus") - { - elementWiseResult = builder.create(loc, Input0, Input1); - } - else if (semiringSecond == "plusxy") - { - elementWiseResult = builder.create(loc, Input0, Input1); - } - else if (semiringSecond == "pairxy") - { - elementWiseResult = builder.create(loc, builder.getF64Type(), builder.getF64FloatAttr(1)); - } - else if (semiringSecond == "pow") - { - elementWiseResult = builder.create(loc, Input0, Input1); - } - else - { - llvm::errs() << "Not supported semiring operator: " << semiringSecond << "\n"; - /// we should not proceed forward from this point to avoid faulty behavior. 
- } - - return elementWiseResult; - } - - Value getSemiringFirstVal(OpBuilder &builder, Location &loc, - llvm::StringRef &semiringFirst, Value &Input0, Value &Input1, - bool compressedWorkspace) - { - - Value reduceResult; - if (semiringFirst == "times") - { - reduceResult = builder.create(loc, Input0, Input1); - } - else if (semiringFirst == "plusxy") - { - reduceResult = builder.create(loc, Input0, Input1); - } - else if (semiringFirst == "minxy") - { - if (!compressedWorkspace) - { - llvm::errs() << "Not supported semiring operator " - "(please use compressed workspace optimization or opt-comp-workspace " - "where this operation is known to work): " - << "min" - << "\n"; - /// we should not proceed forward from this point to avoid in-correct results from generated code. - } - Value cmp = builder.create(loc, CmpFPredicate::OLT, Input0, Input1); - reduceResult = builder.create(loc, cmp, Input0, Input1); - } - else if (semiringFirst == "max") - { - if (!compressedWorkspace) - { - llvm::errs() << "Not supported semiring operator " - "(please use compressed workspace optimization or opt-comp-workspace " - "where this operation is known to work): " - << "max" - << "\n"; - /// we should not proceed forward from this point to avoid in-correct results from generated code. - } - Value cmp = builder.create(loc, CmpFPredicate::OGT, Input0, Input1); - reduceResult = builder.create(loc, cmp, Input0, Input1); - } - else if (semiringFirst == "land") - { - /// land requires integer type input - llvm::errs() << "Not supported semiring operator (only works for int datatypes): " - << "land" - << "\n"; - /// we should not proceed forward from this point to avoid faulty behavior. - } - else if (semiringFirst == "lor") - { - /// lor requires integer type input - llvm::errs() << "Not supported semiring operator (only works for int datatypes): " - << "lor" - << "\n"; - /// we should not proceed forward from this point to avoid faulty behavior. 
- } - else if (semiringFirst == "any") - { - reduceResult = Input1; - } - else if (semiringFirst == "noop") - { - reduceResult = Input1; - } - else - { - llvm::errs() << "Not supported semiring operator: " << semiringFirst << "\n"; - /// we should not proceed forward from this point to avoid faulty behavior. - } - - return reduceResult; - } - - /// Generate numeric semiring kernel if statement condition - void genCmptOpKernelIfStatementCondition(OpBuilder &builder, - Location &loc, - NumericInfo &numericInfo, - MaskingInfo &maskingInfo, - scf::IfOp &if_notAlreadySet /* output */) - { - Value &is_visited_alloc = numericInfo.ws_bitmap; - Value &valueAccessIdx = numericInfo.ws_bitmap_valueAccessIdx; - - Value const_i1_false = builder.create(loc, builder.getI1Type(), builder.getBoolAttr(0)); - Value const_i1_true = builder.create(loc, builder.getI1Type(), builder.getBoolAttr(1)); - if (PUSH_BASED_MASKING == maskingInfo.mask_type) - { - /// if (mask_array[j] == true) { /// C[i,k] is allowed by the mask and has not been seen yet - /// if (ws_bitmap[j] != true) { - Value &mask_array = numericInfo.mask_array; - Value ele_mask_array = builder.create(loc, mask_array, ValueRange{valueAccessIdx}); - Value compare_true = builder.create(loc, CmpIPredicate::eq, ele_mask_array, const_i1_true); - auto if_mask_set = builder.create(loc, compare_true, false /* no else region */); - builder.setInsertionPointToStart(&if_mask_set.getThenRegion().front()); - Value ele_bitmap = builder.create(loc, is_visited_alloc, ValueRange{valueAccessIdx}); - Value compare_false = builder.create(loc, CmpIPredicate::eq, ele_bitmap, const_i1_false); - if_notAlreadySet = builder.create(loc, compare_false, /*WithElseRigion*/ true); - { - comet_vdump(ele_mask_array); - comet_vdump(if_mask_set); - comet_vdump(if_notAlreadySet); - } - } - else if (NO_MASKING == maskingInfo.mask_type) - { - /// if (ws_bitmap[j] != true) { - /// Workspace tensors are on the lhs - Value checkAlreadySet = builder.create(loc, 
is_visited_alloc, ValueRange{valueAccessIdx}); - Value notAlreadySet = builder.create(loc, CmpIPredicate::eq, checkAlreadySet, const_i1_false); - if_notAlreadySet = builder.create(loc, notAlreadySet, /*WithElseRegion*/ true); - { - comet_vdump(checkAlreadySet); - comet_vdump(notAlreadySet); - comet_vdump(if_notAlreadySet); - } - } - else - { - llvm::errs() << "Error: mask_type " << maskingInfo.mask_type << " is not supported.\n"; - } - } - - /// Generate numeric semiring kernel if statement then region - void genCmptOpKernelIfStatementThenRegion(OpBuilder &builder, - Location &loc, - int lhs_loc, - int main_tensor_nums, - scf::IfOp &if_notAlreadySet, - bool compressedWorkspace, - llvm::StringRef &semiringSecond, - std::vector> &main_tensors_all_Allocs, - std::vector> &tensors_lhs_Allocs, - std::vector> &allValueAccessIdx, - SymbolicInfo &symbolicInfo, - NumericInfo &numericInfo) - { - Value &ws_bitmap = numericInfo.ws_bitmap; - Value &ws_bitmap_valueAccessIdx = numericInfo.ws_bitmap_valueAccessIdx; - Value &W_id_list_size = tensors_lhs_Allocs[3][0]; - Value &mtxC_col = symbolicInfo.mtxC_col; - Value &W_data = main_tensors_all_Allocs[lhs_loc].back(); - Value &W_data_valueAccessIdx = ws_bitmap_valueAccessIdx; - - builder.setInsertionPointToStart(&if_notAlreadySet.getThenRegion().front()); - - /// Wj = Aik * Bkj /// computation wj, outer has k, so +=/= need if/else - /// W_already_set[j] = 1 - /// W_index_list[W_index_list_size] = j - /// W_index_list_size++ - - std::vector allLoadsIf(main_tensor_nums); - for (int m = 0; m < main_tensor_nums; m++) - { - Value s = builder.create(loc, main_tensors_all_Allocs[m][main_tensors_all_Allocs[m].size() - 1], allValueAccessIdx[m]); - allLoadsIf[m] = s; - comet_debug() << " "; - comet_vdump(s); - } - comet_debug() << " allLoadsIf.size(): " << allLoadsIf.size() << "\n"; - - comet_debug() << "calculate elementWise operation only\n"; - /// val = A[j_idx] * B[j_idx]; - /// W_data[j_idx] = val; - Value elementWiseResult = 
getSemiringSecondVal(builder, loc, semiringSecond, allLoadsIf[0], allLoadsIf[1], compressedWorkspace); -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto store_sum = builder.create(loc, - elementWiseResult, - W_data, - W_data_valueAccessIdx); - comet_vdump(elementWiseResult); - comet_vdump(store_sum); -#else - builder.create(loc, - elementWiseResult, - W_data, - W_data_valueAccessIdx); -#endif - Value const_index_0 = builder.create(loc, 0); - Value const_index_1 = builder.create(loc, 1); - Value const_i1_true = builder.create(loc, builder.getI1Type(), builder.getBoolAttr(1)); - - /// ws_bitmap[j_idx] = true; - builder.create(loc, const_i1_true, ws_bitmap, ws_bitmap_valueAccessIdx); - - Value W_id_list_size_old = builder.create(loc, W_id_list_size, ValueRange{const_index_0}); - - assert(allValueAccessIdx[lhs_loc].size() == 1 && " more than one access id for auxiliary array\n"); - - /// C.col[W_id_list_size] = j_idx; - builder.create(loc, - ws_bitmap_valueAccessIdx, - mtxC_col, - ValueRange{W_id_list_size_old}); - - /// W_id_list_size += 1 - Value W_id_list_size_new = builder.create(loc, W_id_list_size_old, const_index_1); - comet_debug() << " AddIOps (W_index_list_size_new)"; - comet_vdump(W_id_list_size_new); - - builder.create(loc, W_id_list_size_new, W_id_list_size, ValueRange{const_index_0}); - { - comet_vdump(if_notAlreadySet); - } - } - - /// Generate numeric semiring kernel if statement else region - void genCmptOpKernelIfStatementElseRegion(OpBuilder &builder, - Location &loc, - int lhs_loc, - int main_tensor_nums, - scf::IfOp &if_notAlreadySet, - bool compressedWorkspace, - llvm::StringRef &semiringFirst, - llvm::StringRef &semiringSecond, - std::vector> &main_tensors_all_Allocs, - std::vector> &allValueAccessIdx) - { - - Value &W_data = main_tensors_all_Allocs[lhs_loc].back(); - Value &W_data_valueAccessIdx = allValueAccessIdx[lhs_loc][0]; - - builder.setInsertionPointToStart(&if_notAlreadySet.getElseRegion().front()); - - std::vector 
allLoadsElse(main_tensor_nums); - for (auto m = 0; m < main_tensor_nums; m++) - { - Value s = builder.create(loc, main_tensors_all_Allocs[m][main_tensors_all_Allocs[m].size() - 1], allValueAccessIdx[m]); - allLoadsElse[m] = s; - comet_vdump(s); - } - comet_debug() << " allLoadsElse.size(): " << allLoadsElse.size() << "\n"; - - comet_debug() << "calculate elementWise operation and reduction\n"; - Value elementWiseResult = getSemiringSecondVal(builder, loc, semiringSecond, allLoadsElse[0], allLoadsElse[1], compressedWorkspace); - Value reduceResult = getSemiringFirstVal(builder, loc, semiringFirst, allLoadsElse[lhs_loc], elementWiseResult, compressedWorkspace); - builder.create(loc, reduceResult, W_data, W_data_valueAccessIdx); - { - comet_vdump(if_notAlreadySet); - } - } - - /// Generate the numeric bitmap - /// It should be deprecated in the future, as the bitmap would be lowered from the Index Tree dialect. - void genNumericBitmap(OpBuilder &builder, - Location &loc, - scf::ForOp &symbolic_outermost_forLoop, - SymbolicInfo &symbolicInfo, - Value &bitmap_alloc) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Jump Insertion Point to the front of the 2nd outermost for-loop - builder.setInsertionPoint(symbolic_outermost_forLoop); - - Value const_index_0 = builder.create(loc, 0); - Value const_index_1 = builder.create(loc, 1); - Value &mtxC_dim2_size = symbolicInfo.mtxC_num_cols; - - MemRefType memTy_dynamic_1i = MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); - bitmap_alloc = builder.create(loc, - memTy_dynamic_1i, - ValueRange{mtxC_dim2_size}, - builder.getI64IntegerAttr(8) /* alignment bytes */); - Value const_i1_0 = builder.create(loc, builder.getI1Type(), builder.getBoolAttr(0)); - scf::ForOp init_forLoop = builder.create(loc, - const_index_0 /* lowerBound */, - mtxC_dim2_size /* upperBound */, - const_index_1 /* step */); - builder.setInsertionPointToStart(init_forLoop.getBody()); - Value i_idx 
= init_forLoop.getInductionVar(); - builder.create(loc, - const_i1_0, - bitmap_alloc, - ValueRange{i_idx}); - { - comet_vdump(bitmap_alloc); - comet_vdump(init_forLoop); - } - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - - /// Generate numeric mask-array before the numeric outermost for-loop. - /// Please don't confuse with mark-array. - void genNumericMaskArray(OpBuilder &builder, - Location &loc, - scf::ForOp &numeric_outermost_forLoop, - SymbolicInfo &symbolicInfo, - NumericInfo &numericInfo /* output */) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion Point before the numeric outermost for-loop - builder.setInsertionPoint(numeric_outermost_forLoop); - - /// Generate the mask-array - Value &dim2_size = symbolicInfo.mtxC_num_cols; - MemRefType memTy_dynamic_i1 = MemRefType::get({ShapedType::kDynamic}, builder.getI1Type()); - Value mask_array_alloc = builder.create(loc, - memTy_dynamic_i1, - ValueRange{dim2_size}, - builder.getI64IntegerAttr(8) /* alignment bytes */); - - /// Initialize the mask-array - Value const_index_0 = builder.create(loc, 0); - Value const_index_1 = builder.create(loc, 1); - scf::ForOp mask_array_init_loop = builder.create(loc, - const_index_0 /* lowerBound */, - dim2_size /* upperBound */, - const_index_1 /* step */); - builder.setInsertionPointToStart(mask_array_init_loop.getBody()); - Value j_idx = mask_array_init_loop.getInductionVar(); - Value const_i1_false = builder.create(loc, - builder.getI1Type(), - builder.getBoolAttr(false)); - builder.create(loc, - const_i1_false, - mask_array_alloc, - ValueRange{j_idx}); - - numericInfo.mask_array = mask_array_alloc; - - { - comet_vdump(mask_array_alloc); - comet_vdump(mask_array_init_loop); - } - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - - /// Generate setting the mask-array at the begining of the numeric outermost 
for-loop, - /// and resetting at the end of the outermost for-loop. - /// ----------------- /// - /// %j_loc_start = memref.load %mask_rowptr[%i_idx] : memref /// alloc_16 = mask.rowptr - /// %j_loc_bound = memref.load %mask_rowptr[%i_idx_plus_one] : memref - /// scf.for %arg1 = %j_loc_start to %j_loc_bound step %c1 { - /// %val = memref.load %mask_val[%arg1] : memref - /// %54 = arith.cmpf une, %val, %cst : f64 - /// scf.if %54 { - /// %j_idx = memref.load %mask_col[%arg1] : memref - /// memref.store %true, %mask_array[%j_idx] : memref - /// } - /// } - /// ----------------- /// - /// Reset mask_array at the end of numeric outermost for-loop - /// ----------------- /// - /// scf.for %arg1 = %j_loc_start to %j_loc_bound step %c1 { - /// %j_idx = memref.load %mask_col[%arg1] : memref - /// memref.store %false, %array_mask[%j_idx] : memref - /// } - void genNumericSetAndResetMaskArray(OpBuilder &builder, - Location &loc, - scf::ForOp &numeric_outermost_forLoop, - Value &outermost_forLoop_valueAccessIdx, - NumericInfo &numericInfo, - MaskingInfo &maskingInfo) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion Point before the numeric semiring for-loop - builder.setInsertionPointToStart(numeric_outermost_forLoop.getBody()); - - /// Generate the setting for-loop entry - Value &mask_array = numericInfo.mask_array; - Value &mask_rowptr = maskingInfo.mask_rowptr; - Value &mask_col = maskingInfo.mask_col; - Value &mask_val = maskingInfo.mask_val; - Value const_index_1 = builder.create(loc, 1); - Value &i_idx = outermost_forLoop_valueAccessIdx; - Value i_idx_plus_one = builder.create(loc, i_idx, const_index_1); - Value j_loc_start = builder.create(loc, mask_rowptr, ValueRange{i_idx}); - Value j_loc_bound = builder.create(loc, mask_rowptr, ValueRange{i_idx_plus_one}); - scf::ForOp init_for_loop = builder.create(loc, - j_loc_start /* lower_bound */, - j_loc_bound /* upper_bound*/, - const_index_1 /* step 
*/); - - /// Generate the setting for-loop body - builder.setInsertionPointToStart(init_for_loop.getBody()); - Value const_f64_0 = builder.create(loc, builder.getF64Type(), builder.getF64FloatAttr(0)); - Value j_loc = init_for_loop.getInductionVar(); - Value val = builder.create(loc, mask_val, ValueRange{j_loc}); - Value not_zero = builder.create(loc, CmpFPredicate::UNE, val, const_f64_0); - auto if_not_zero = builder.create(loc, not_zero, false /*NoElseRegion*/); - builder.setInsertionPointToStart(&if_not_zero.getThenRegion().front()); - Value j_idx = builder.create(loc, mask_col, ValueRange{j_loc}); - Value const_i1_1 = builder.create(loc, builder.getI1Type(), builder.getBoolAttr(true)); - builder.create(loc, - const_i1_1, - mask_array, - ValueRange{j_idx}); - { - comet_vdump(val); - comet_vdump(if_not_zero); - comet_vdump(init_for_loop); - } - - /// Generate the resetting for-loop entry after the semiring for-loop - builder.setInsertionPoint(numeric_outermost_forLoop.getBody()->getTerminator()); - scf::ForOp reset_for_loop = builder.create(loc, - j_loc_start /* lower_bound */, - j_loc_bound /* upper_bound*/, - const_index_1 /* step */); - - /// Generate the resetting for-loop body - builder.setInsertionPointToStart(reset_for_loop.getBody()); - j_loc = reset_for_loop.getInductionVar(); - j_idx = builder.create(loc, mask_col, ValueRange{j_loc}); - Value const_i1_0 = builder.create(loc, builder.getI1Type(), builder.getBoolAttr(false)); - builder.create(loc, - const_i1_0, - mask_array, - ValueRange{j_idx}); - { - comet_vdump(reset_for_loop); - comet_vdump(numeric_outermost_forLoop); - } - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - - void formSemiringLoopBody(indexTree::IndexTreeComputeOp &cur_op, - bool comp_worksp_opt, - llvm::StringRef &semiringFirst, - llvm::StringRef &semiringSecond, - OpBuilder &builder, Location &loc, int lhs_loc, - std::vector> &main_tensors_all_Allocs, - std::vector> &tensors_lhs_Allocs, - 
std::vector> &tensors_rhs_Allocs, - std::vector> &allValueAccessIdx, - std::vector> &allAccessIdx, - std::vector &forLoops /* numeric for-loop statements, from innermost to outermost*/, - std::vector &numeric_nested_forLoop_AccessIdx, - std::vector &symbolic_nested_forops /* symbolic for-loops from innermost to outermost */, - std::vector> &rhsPerms, - SymbolicInfo &symbolicInfo, - NumericInfo &numericInfo, - MaskingInfo &maskingInfo) - { - std::vector> rhsFormats; - getRHSFormatsOfComputeOp(cur_op.getOperation()->getResult(0), rhsFormats); - std::vector> lhsFormats; - getLHSFormatsOfComputeOp(cur_op.getOperation()->getResult(0), lhsFormats); - bool isMixedMode = checkIsMixedMode(rhsFormats); - bool isElementwise = checkIsElementwise(rhsPerms); - comet_debug() << " isElementwise:" << isElementwise << " isMixedMode: " << isMixedMode << "\n"; - auto ctx = builder.getContext(); - IndexType indexType = IndexType::get(ctx); - - if ((semiringFirst.size() == 0) || (semiringSecond.size() == 0)) - llvm::errs() << "Error during semiring parsing!" - << "\n"; - - if (main_tensors_all_Allocs.size() != allValueAccessIdx.size()) - llvm::errs() << "DEBUG ONLY: issue with main_tensor_nums size" - << "\n"; - - auto f64Type = builder.getF64Type(); - auto const_f64_0 = builder.create(loc, f64Type, builder.getF64FloatAttr(0)); - - int main_tensor_nums = main_tensors_all_Allocs.size(); - bool compressedWorkspace = false; - - if (comp_worksp_opt) /// always lhs is dense after workspace transformations - { - compressedWorkspace = true; - - /// Generate the numeric bitmap - if (numericInfo.ws_bitmap == nullptr) - { - Value bitmap_alloc; - genNumericBitmap(builder, - loc, - symbolic_nested_forops.back(), - symbolicInfo, - bitmap_alloc); - /// TODO(zpeng): numericInfo.ws_bitmap should be lowered from Index Tree dialect. 
- numericInfo.ws_bitmap = bitmap_alloc; - numericInfo.ws_bitmap_valueAccessIdx = allValueAccessIdx[lhs_loc][0]; - } - - /// Generate the mask-array (please not confuse with mark-array) - if (PUSH_BASED_MASKING == maskingInfo.mask_type) - { - /// Generate numeric mask-array before the numeric outermost for-loop. - /// Please don't confuse with mark-array. - genNumericMaskArray(builder, - loc, - forLoops.back() /* numeric_outermost_forLoop= */, - symbolicInfo, - numericInfo /* output */); - - /// Generate setting the mask-array before the numeric semiring for-loop and resetting after the semiring for-loop. - genNumericSetAndResetMaskArray(builder, - loc, - forLoops.back() /* numeric_outermost_forLoop */, - numeric_nested_forLoop_AccessIdx.back() /* outermost_forLoop_valueAccessIdx */, - numericInfo, - maskingInfo); - } - - /// Value &is_visited_alloc = tensors_lhs_Allocs[1][0]; - /// Value &is_visited_alloc_valAccessIdx = allValueAccessIdx[lhs_loc][0]; - scf::IfOp if_notAlreadySet; - genCmptOpKernelIfStatementCondition(builder, - loc, - numericInfo, - maskingInfo, - if_notAlreadySet /* output */); - - /// if-then region corresponding to if_notAlreadySet instruction. - /// if (&if_notAlreadySet. getThenRegion()) - if (!if_notAlreadySet.getThenRegion().empty()) - { - genCmptOpKernelIfStatementThenRegion(builder, - loc, - lhs_loc, - main_tensor_nums, - if_notAlreadySet, - compressedWorkspace, - semiringSecond, - main_tensors_all_Allocs, - tensors_lhs_Allocs, - allValueAccessIdx, - symbolicInfo, - numericInfo); - } - - /// if-else region corresponding to if_notAlreadySet instruction. 
- /// if (&if_notAlreadySet.getElseRegion()) - if (!if_notAlreadySet.getElseRegion().empty()) - { - genCmptOpKernelIfStatementElseRegion(builder, - loc, - lhs_loc, - main_tensor_nums, - if_notAlreadySet, - compressedWorkspace, - semiringFirst, - semiringSecond, - main_tensors_all_Allocs, - allValueAccessIdx); - } - } - else - { /// general dense or mixed mode computation, no need workspace transformations - std::vector allLoads(main_tensor_nums); - for (auto m = 0; m < main_tensor_nums; m++) - { - Value load_op = builder.create(loc, - main_tensors_all_Allocs[m][main_tensors_all_Allocs[m].size() - 1], allValueAccessIdx[m]); - allLoads[m] = load_op; - comet_debug() << " "; - comet_vdump(load_op); - } - comet_debug() << " allLoads.size(): " << allLoads.size() << "\n"; - - /// if computeOp is elementwise mixed mode operation, the output is sparse - if (isMixedMode && isElementwise && !checkIsDense(lhsFormats[0])) - { - - int dense_inputtensor_id = 0; - for (unsigned int i = 0; i < rhsFormats.size(); i++) - { - if (checkIsDense(rhsFormats[i])) - { - dense_inputtensor_id = i; - break; - } - } - - int sparse_inputtensor_id = dense_inputtensor_id ? 
0 : 1; - std::string sparse_format = getTensorFormat(rhsFormats, sparse_inputtensor_id); - - auto last_insertionPoint = builder.saveInsertionPoint(); - - /// Need to initialize some memory accesses outside the nested loop - /// Reset the insertion point: the body of the innermost loop - comet_debug() << "LoopSize: " << forLoops.size() << " Loop:\n"; - comet_vdump(forLoops[forLoops.size() - 1]); - builder.setInsertionPoint(forLoops[forLoops.size() - 1]); - - Value const_index_0 = builder.create(loc, 0); - MemRefType memTy_alloc_Cnnz = MemRefType::get({1}, indexType); - Value alloc_Cnnz = builder.create(loc, memTy_alloc_Cnnz); - comet_debug() << " AllocOp for Cnnz: "; - comet_vdump(alloc_Cnnz); - - std::vector alloc_Cnnz_insert_loc = {const_index_0}; -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto store_Cnnz = builder.create(loc, const_index_0, alloc_Cnnz, alloc_Cnnz_insert_loc); - comet_debug() << " StoreOp: "; - comet_vdump(store_Cnnz); -#else - builder.create(loc, const_index_0, alloc_Cnnz, alloc_Cnnz_insert_loc); -#endif - - /// The following code block is needed to update Update C2pos in the case of output tensor is in DCSR - Value Cnnz_index_old; - Value alloc_Cnnz_row; - if (sparse_format.compare("DCSR") == 0) - { - alloc_Cnnz_row = builder.create(loc, memTy_alloc_Cnnz); -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto store_Cnnz_row = builder.create(loc, const_index_0, alloc_Cnnz_row, - alloc_Cnnz_insert_loc); - comet_debug() << " StoreOp DCSR: "; - comet_vdump(store_Cnnz_row); -#else - builder.create(loc, const_index_0, alloc_Cnnz_row, alloc_Cnnz_insert_loc); -#endif - /// Get Cnnz_old - Cnnz_index_old = builder.create(loc, alloc_Cnnz, alloc_Cnnz_insert_loc); - } - - builder.restoreInsertionPoint(last_insertionPoint); - - comet_debug() << " dense_inputtensor_id: " << dense_inputtensor_id << "\n"; - comet_debug() << " sparse_inputtensor_id: " << sparse_inputtensor_id << "\n"; - Value denseInput_is_nonzero = builder.create(loc, CmpFPredicate::ONE, 
allLoads[dense_inputtensor_id], - const_f64_0); - auto if_nonzero = builder.create(loc, denseInput_is_nonzero, /*WithElseRegion*/ false); - comet_debug() << " If branch:\n"; - comet_vdump(if_nonzero); - - if (!if_nonzero.getThenRegion().empty()) - { - - builder.setInsertionPointToStart(&if_nonzero.getThenRegion().front()); - - comet_debug() << "calculate product and sum in \n"; - Value elementWiseResult = getSemiringSecondVal(builder, loc, semiringSecond, allLoads[0], allLoads[1], - compressedWorkspace); - - /// Get Cnnz - Value Cnnz_index = builder.create(loc, alloc_Cnnz, alloc_Cnnz_insert_loc); - -/// Store product to Cval -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - comet_debug() << "Store product to Cval\n"; - auto store_Cval = builder.create(loc, elementWiseResult, main_tensors_all_Allocs[2][main_tensors_all_Allocs[2].size() - 1], Cnnz_index); - comet_debug() << " StoreOp: "; - comet_vdump(store_Cval); - - /// Update C1crd, C2crd - comet_debug() << "Getting A1crd\n"; - comet_debug() << "allValueAccessIdx[" << sparse_inputtensor_id << "].size(): " - << allAccessIdx[sparse_inputtensor_id].size() << "\n"; - comet_vdump(allAccessIdx[sparse_inputtensor_id][0]); - - for (unsigned int i = 0; i < allAccessIdx.size(); i++) - { - comet_debug() << "allAccessIdx[" << i << "].size(): " << allAccessIdx[i].size() << "\n"; - for (auto n : allAccessIdx[i]) - { - comet_vdump(n); - } - } -#else - builder.create(loc, elementWiseResult, main_tensors_all_Allocs[2][main_tensors_all_Allocs[2].size() - 1], Cnnz_index); -#endif - - comet_debug() << "Store C1crd\n"; - /// Branch out COO... CSR... DCSR... 
- if (sparse_format.compare("COO") == 0) - { - comet_debug() << "COO format for Elementwise MulOp, update all coordinates\n"; - for (unsigned d = 0; d < rhsPerms[sparse_inputtensor_id].size(); d++) - { - Value crd = allAccessIdx[sparse_inputtensor_id][d]; -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto store_coo_crd = builder.create(loc, crd, main_tensors_all_Allocs[2][4 * d + 1], - Cnnz_index); - comet_debug() << " COO StoreOp: "; - comet_vdump(store_coo_crd); -#else - builder.create(loc, crd, main_tensors_all_Allocs[2][4 * d + 1], Cnnz_index); -#endif - } - } - else if (sparse_format.compare("CSR") == 0 || sparse_format.compare("DCSR") == 0) - { - for (unsigned int d = forLoops.size() - 1; d < rhsPerms[sparse_inputtensor_id].size(); d++) - { - Value crd = allAccessIdx[sparse_inputtensor_id][d]; -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto store_csr_crd = builder.create(loc, crd, main_tensors_all_Allocs[2][4 * d + 1], - Cnnz_index); - comet_debug() << " CSR or DCSR StoreOp: "; - comet_vdump(store_csr_crd); -#else - builder.create(loc, crd, main_tensors_all_Allocs[2][4 * d + 1], Cnnz_index); -#endif - } - } - - /// Update Cnnz - comet_debug() << "Update Cnnz\n"; - Value const_index_1 = builder.create(loc, 1); - Value new_Cnnz_index = builder.create(loc, Cnnz_index, const_index_1); -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - comet_debug() << "AddIOps (new_Cnnz_index): "; - comet_vdump(new_Cnnz_index); - auto store_updated_cnnz = builder.create(loc, new_Cnnz_index, alloc_Cnnz, - alloc_Cnnz_insert_loc); - comet_debug() << " Update Cnnz (store new value) StoreOp: "; - comet_vdump(store_updated_cnnz); -#else - builder.create(loc, new_Cnnz_index, alloc_Cnnz, alloc_Cnnz_insert_loc); -#endif - } - - /// Need to identify dense tensor upperbound to be able to update Cpos and Csize arrays - std::vector denseAllocs = tensors_rhs_Allocs[dense_inputtensor_id]; - assert(denseAllocs.size() == 1); - - comet_debug() << " DenseAllocs: "; - auto inputType = 
denseAllocs[0].getType(); - std::vector denseDimsSize; - for (unsigned rank = 0; rank < inputType.cast().getRank(); rank++) - { - auto dimSize = inputType.cast().getDimSize(rank); - Value upperBound; - if (dimSize == ShapedType::kDynamic) - { - comet_debug() << " This dimension is a dynamic size:\n"; - unsigned dynamicDimPos = inputType.dyn_cast().getDynamicDimIndex(rank); - comet_debug() << " DynamicDimPos: " << dynamicDimPos << "\n"; - upperBound = denseAllocs[0].getDefiningOp()->getOperand(dynamicDimPos); - comet_vdump(upperBound); - } - else - { - comet_debug() << " This dimension is a static size\n"; - upperBound = builder.create(loc, dimSize); - comet_vdump(upperBound); - } - denseDimsSize.push_back(upperBound); - } - - /// To update Cpos - if (sparse_format.compare("CSR") == 0) - { - builder.setInsertionPointAfter(forLoops[0]); - Value const_index_1 = builder.create(loc, 1); - Value arg0_next = builder.create(loc, forLoops[1].getInductionVar(), const_index_1); - comet_debug() << "AddIOp (arg0_next): "; - comet_vdump(arg0_next); - - Value Cnnz_index_final = builder.create(loc, alloc_Cnnz, alloc_Cnnz_insert_loc); - builder.create(loc, Cnnz_index_final, main_tensors_all_Allocs[2][4], arg0_next); /// 2 - - builder.setInsertionPointAfter(forLoops[1]); - /// Update C2pos[0] - comet_debug() << "Update C2pos[0]\n"; - std::vector insert_loc_0 = {const_index_0}; - builder.create(loc, const_index_0, main_tensors_all_Allocs[2][4], insert_loc_0); /// 2 - - /// Update C1pos[0] - comet_debug() << "Update C1pos[0]\n"; - Value dim0_index = denseDimsSize[0]; - builder.create(loc, dim0_index, main_tensors_all_Allocs[2][0], insert_loc_0); - } - else - { - if (sparse_format.compare("DCSR") == 0) - { - /// Update C2pos - comet_debug() << "Update DCSR C2pos\n"; - builder.setInsertionPointAfter(forLoops[0]); - auto Cnnz_index_new = builder.create(loc, alloc_Cnnz, alloc_Cnnz_insert_loc); - auto has_nnz_row = builder.create(loc, CmpIPredicate::ne, Cnnz_index_new, Cnnz_index_old); - 
auto has_nnz_row_ifOp = builder.create(loc, has_nnz_row, /*WithElseRegion*/ false); - comet_debug() << " If branch:\n"; - comet_vdump(has_nnz_row_ifOp); - - if (!has_nnz_row_ifOp.getThenRegion().empty()) - { - builder.setInsertionPointToStart(&has_nnz_row_ifOp.getThenRegion().front()); - - Value const_index_1 = builder.create(loc, 1); - Value arg0_next = builder.create(loc, forLoops[1].getInductionVar(), const_index_1); - comet_debug() << "AddIOp (arg0_next): "; - comet_vdump(arg0_next); - - Value Cnnz_index_final = builder.create(loc, alloc_Cnnz, alloc_Cnnz_insert_loc); - builder.create(loc, Cnnz_index_final, main_tensors_all_Allocs[2][4], arg0_next); /// C2pos //2 - Value Cnnz_row_index = builder.create(loc, alloc_Cnnz_row, alloc_Cnnz_insert_loc); - Value idx_i = allAccessIdx[sparse_inputtensor_id][0]; - builder.create(loc, /*i*/ idx_i, main_tensors_all_Allocs[2][1], Cnnz_row_index); /// C1crd - Value Cnnz_row_index_new = builder.create(loc, Cnnz_row_index, const_index_1); - comet_debug() << "AddIOp (Cnnz_row_index_new): "; - comet_vdump(Cnnz_row_index_new); - builder.create(loc, Cnnz_row_index_new, alloc_Cnnz_row, - alloc_Cnnz_insert_loc); /// Update Cnnz_row - } - - builder.setInsertionPointAfter(forLoops[1]); - Value const_index_1 = builder.create(loc, 1); - std::vector insert_loc_1 = {const_index_1}; - - /// Update C2pos[0] - std::vector insert_loc_0 = {const_index_0}; - builder.create(loc, const_index_0, main_tensors_all_Allocs[2][4], insert_loc_0); /// 2 - - /// Update C1pos[0], C1pos[1] - Value Cnnz_row_index = builder.create(loc, alloc_Cnnz_row, alloc_Cnnz_insert_loc); - builder.create(loc, const_index_0, main_tensors_all_Allocs[2][0], insert_loc_0); - builder.create(loc, Cnnz_row_index, main_tensors_all_Allocs[2][0], insert_loc_1); - } - else - { - if (sparse_format.compare("COO") == 0) - { - /// Finally, Update C1pos - comet_debug() << "Update C1pos\n"; - builder.setInsertionPointAfter(forLoops[0]); - Value Cnnz_index_final = builder.create(loc, 
alloc_Cnnz, alloc_Cnnz_insert_loc); - Value const_index_1 = builder.create(loc, 1); - builder.create(loc, const_index_0, main_tensors_all_Allocs[2][0], const_index_0); - builder.create(loc, Cnnz_index_final, main_tensors_all_Allocs[2][0], const_index_1); - } - else - llvm::errs() << "/// Coordinate values are not updated for output sparse tensor in " << sparse_format - << " format\n"; - } - } - - } /// end if (isMixedMode && isElementwise) - else - { - /// calculate elementWise operation and reduction for general dense or mix mode computation (which has dense output) - comet_debug() - << "calculate elementWise operation and reduction for general dense or mix mode computation (which has dense output)\n"; - Value elementWiseResult = getSemiringSecondVal(builder, loc, semiringSecond, allLoads[0], allLoads[1], - compressedWorkspace); - Value reduceResult = getSemiringFirstVal(builder, loc, semiringFirst, allLoads[2], elementWiseResult, - compressedWorkspace); - builder.create(loc, reduceResult, - main_tensors_all_Allocs[2][main_tensors_all_Allocs[2].size() - 1], - allValueAccessIdx[2]); - } - } - } - - /// ----------------- /// - /// Generate Cij = Wj node, gathering the results in the workspace to the sparse output C.val. - /// Called by genCmptOps(). 
- /// ----------------- /// - /// sort(C.col, C.rowptr[i_idx], C.rowptr[i_idx + 1]); - /// for (int j_loc = C.rowptr[i_idx]; j_loc < C.rowptr[i_idx + 1]; ++j_loc) { - /// int j_idx = C.col[j_loc]; - /// C.val[j_idx] = W_data[j_idx]; - /// is_visited[j_idx] = false; - /// } - /// ----------------- /// - /// %rowptr_bound = memref.load %rowptr[%c0] : memref<1xindex> - /// %C_col_ptr = memref.cast %C_col : memref to memref<*xindex> - /// func.call @comet_sort_index(%C_col_ptr, %rowptr_start, %rowptr_bound) : (memref<*xindex>, index, index) -> () - /// - /// scf.for %ptr = %rowptr_start to %rowptr_bound step %c1 { - /// %c_col_id = memref.load %C_col[%ptr] : memref /// c_col_id = C_col[ptr] - /// %data = memref.load %ws_data[%c_col_id] : memref /// data = ws_data[c_col_id] - /// memref.store %data, %C_val[%ptr] : memref /// C_val[ptr] = data - /// memref.store %false, %ws_bitmap[%c_col_id] : memref /// ws_bitmap[c_col_id] = false - /// } - void genWorkspaceCmptOpGatherFromWorkspaceToOutput(OpBuilder &builder, - Location &loc, - std::vector> &tensors_rhs_Allocs, - std::vector &nested_forops, - std::vector &nested_AccessIdx, - SymbolicInfo &symbolicInfo, - NumericInfo &numericInfo) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - assert(nested_forops.size() >= 2 && nested_AccessIdx.size() >= 2 && "Error: should be at least 2 levels of for-loop.\n"); - scf::ForOp &curr_for_loop = nested_forops[0]; - scf::ForOp parent_for_loop = nested_forops[1]; - - /// Set the insertion point before the innermost for-loop - builder.setInsertionPoint(curr_for_loop); - - /// Get and set the boundary of current for-loop - /// %rowptr_start = memref::LoadOp %C_rowptr[%i_idx] : memref - /// %id_idx_plus_one = arith.addi %rowptr_start, %c1 : index - /// %rowptr_bound = memref::LoadOp %C_rowptr[%i_idx_plus_one] : memref - Value const_index_1 = builder.create(loc, 1); - Value i_idx = nested_AccessIdx[1]; - Value i_idx_plus_one = 
builder.create(loc, i_idx, const_index_1); - Value &mtxC_rowptr = symbolicInfo.mtxC_rowptr; - Value rowptr_start = builder.create(loc, mtxC_rowptr, ValueRange{i_idx}); - Value rowptr_bound = builder.create(loc, mtxC_rowptr, ValueRange{i_idx_plus_one}); - { - comet_vdump(parent_for_loop); - comet_vdump(i_idx); - comet_vdump(rowptr_start); - comet_vdump(rowptr_bound); - } - - /// Generate calling comet_sort_index - /// %C_col_ptr = memref.cast %C_col : memref to memref<*xindex> - /// func.call @comet_sort_index(%C_col_ptr, %rowptr_start, %rowptr_bound) : (memref<*xindex>, index, index) -> () - std::string func_name = "comet_sort_index"; - Value &mtxC_col = symbolicInfo.mtxC_col; - IndexType indexType = IndexType::get(builder.getContext()); - Value C_col_cast = builder.create(loc, - UnrankedMemRefType::get(indexType, 0), - mtxC_col); - builder.create(loc, - func_name, - SmallVector{}, - ValueRange{C_col_cast, rowptr_start, rowptr_bound}); - - /// Change current for-loop boundaries - curr_for_loop.setLowerBound(rowptr_start); - curr_for_loop.setUpperBound(rowptr_bound); - - /// Generate current for-loop body - Value &mtxC_val = symbolicInfo.mtxC_val; - Value &ws_data = tensors_rhs_Allocs[0][0]; - Value &ws_bitmap = numericInfo.ws_bitmap; - Value rowptr = curr_for_loop.getInductionVar(); - builder.setInsertionPointToStart(curr_for_loop.getBody()); - Value c_col_id = builder.create(loc, mtxC_col, ValueRange{rowptr}); - Value data = builder.create(loc, ws_data, ValueRange{c_col_id}); - builder.create(loc, - data, - mtxC_val, - ValueRange{rowptr}); - Value const_i1_0 = builder.create(loc, builder.getI1Type(), builder.getBoolAttr(false)); - builder.create(loc, - const_i1_0, - ws_bitmap, - ValueRange{c_col_id}); - { - comet_vdump(c_col_id); - comet_vdump(data); - comet_vdump(curr_for_loop); - } - - /// Free up ws_data and ws_bitmap after - builder.setInsertionPointAfter(parent_for_loop); - builder.create(loc, ws_data); - builder.create(loc, ws_bitmap); - - /// Restore the 
insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - - /// In genCmptOps, get current compute node's numeric nested for-loop and access indices. - void getNumericNestedForOpsAndAccessIdx(std::vector &ancestorsWps, - std::vector &ancestorsOps, - std::vector &nested_forops /* output */, - std::vector &nested_AccessIdx /* output */, - std::vector &nested_forops_indices /* output */) - { - - for (unsigned int i = 0; i < ancestorsOps.size(); i++) - { - comet_debug() << " ancestorsOps[" << i << "]->forOps.size(): " << ancestorsOps[i]->forOps.size() - << ", ancestorsOps->id: " - << ancestorsOps[i]->id << "\n"; - if (!ancestorsOps[i]->forOps.empty()) - { /// for loops OpsTree node - for (int j = ancestorsOps[i]->forOps.size() - 1; j >= 0; j--) - { - comet_debug() << " j: " << j << "\n"; - nested_forops.push_back(ancestorsOps[i]->forOps[j]); - comet_debug() << "AccessIdx: " << ancestorsOps[i]->accessIdx[j] << "\n"; - nested_AccessIdx.push_back(ancestorsOps[i]->accessIdx[j]); - } - } - } - comet_debug() << " nested_forops.size(): " << nested_forops.size() << "\n"; - for (unsigned int i = 0; i < ancestorsWps.size(); i++) - { - comet_debug() << " "; - comet_vdump(ancestorsWps[i]); - - if (indexTree::IndexTreeIndicesOp cur_op = dyn_cast( - ancestorsWps[i].getDefiningOp())) - { - /// Get indices - ArrayAttr op_indices = cur_op.getIndices(); - - if (op_indices.size() > 0) - { /// for loops OpsTree node - for (int j = op_indices.size() - 1; j >= 0; j--) - { - /// Get the indices; - int64_t idx = op_indices[j].cast().getInt(); - nested_forops_indices.push_back(idx); - } - } - } - } - } - - /// In genCmptOps, get current compute node's RHS, LHS, tensors, formats, perms, etc. 
- void getNumericTensors(indexTree::IndexTreeComputeOp &cur_op, - std::vector &tensors_rhs /* output */, - std::vector> &tensors_lhs_Allocs /* output */, - std::vector> &tensors_rhs_Allocs /* output */, - std::vector> &allFormats /*output*/, - std::vector> &allPerms /* output */, - std::vector> &allPerms_rhs /* output */, - std::vector &main_tensors_all /* output */, - std::vector &main_tensors_rhs /* output */) - { - comet_vdump(cur_op); - for (auto n : cur_op.getRhs()) - { - comet_debug() << " "; - comet_vdump(n); - for (unsigned i = 0; i < n.getDefiningOp()->getNumOperands(); i++) - { - comet_debug() << " "; - comet_vdump(n.getDefiningOp()->getOperand(i)); - tensors_rhs.push_back(n.getDefiningOp()->getOperand(i)); - } - } - - std::vector tensors_lhs; /// inner - for (unsigned i = 0; i < cur_op.getLhs().getDefiningOp()->getNumOperands(); i++) - { - comet_debug() << " "; - comet_vdump(cur_op.getLhs().getDefiningOp()->getOperand(i)); - tensors_lhs.push_back(cur_op.getLhs().getDefiningOp()->getOperand(i)); - } - - /// Currently, only one case, the rhs is constant. 
Wj = 0.0; - tensors_lhs_Allocs = getAllAllocs(tensors_lhs); /// output - comet_debug() << " tensors_lhs_Allocs.size(): " << tensors_lhs_Allocs.size() << "\n"; - tensors_rhs_Allocs = getAllAllocs(tensors_rhs); /// output - comet_debug() << " tensors_rhs_Allocs.size(): " << tensors_rhs_Allocs.size() << "\n"; - -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - comet_debug() << " tensors_rhs_Allocs: \n"; - for (auto m : tensors_rhs_Allocs) - { - comet_debug() << " "; - for (auto n : m) - { - comet_vdump(n); - } - comet_debug() << "\n"; - } -#endif - - getPermsOfComputeOp(cur_op.getOperation()->getResult(0), allPerms); - - comet_debug() << " allPerms: \n"; - for (auto m : allPerms) - { - comet_debug() << " "; /// print_vector(m); - for (auto n : m) - { - comet_debug() << n << " "; - } - comet_debug() << "\n"; - } - - getFormatsOfComputeOp(cur_op.getOperation()->getResult(0), allFormats); - comet_debug() << " allFormats: \n"; - for (auto m : allFormats) - { - comet_debug() << " "; - for (auto n : m) - { - comet_debug() << n << " "; - } - comet_debug() << "\n"; - } - - comet_debug() << " "; - comet_vdump(cur_op); - - assert(allPerms.size() == allFormats.size() && "allPerms.size() != allFormats.size()\n"); - for (unsigned int m = 0; m < allPerms.size(); m++) - { - assert(allPerms[m].size() == allFormats[m].size() && "allPerms[m].size() != allFormats[m].size()\n"); - } - comet_debug() << " allPerms.size(): " << allPerms.size() << "\n"; - /// tensor_nums means the actual tensors except the auxiliary tensors - /// Suppose for LHSOp, there are "n" real tensors, then allPerms[m].size() - - getRHSPermsOfComputeOp(cur_op.getOperation()->getResult(0), allPerms_rhs); - comet_debug() << " allPerms_rhs.size(): " << allPerms_rhs.size() << "\n"; - std::vector> allPerms_lhs; /// inner - getLHSPermsOfComputeOp(cur_op.getOperation()->getResult(0), allPerms_lhs); - - comet_debug() << " allPerms_lhs.size(): " << allPerms_lhs.size() << "\n"; - std::vector main_tensors_lhs; /// inner - if 
(tensors_rhs.size() == allPerms_rhs.size()) - { /// all are "main" tensors - main_tensors_rhs.insert(main_tensors_rhs.end(), tensors_rhs.begin(), tensors_rhs.end()); - } - else - { /// the rhs contains the auxiliary tensors - assert(allPerms_rhs.size() == 1 && - " rhs contains auxiliary tensors and main tensors at the same time, not support currently\n"); /// only 1 main tensor on rhs - main_tensors_rhs.push_back(tensors_rhs[0]); - } - comet_debug() << " main_tensors_rhs.size(): " << main_tensors_rhs.size() << "\n"; - if (tensors_lhs.size() == allPerms_lhs.size()) - { /// all are "main" tensors - main_tensors_lhs.insert(main_tensors_lhs.end(), tensors_lhs.begin(), tensors_lhs.end()); - } - else - { /// the lhs contains the auxiliary tensors - assert(allPerms_lhs.size() == 1 && - " lhs contains auxiliary tensors and main tensors at the same time, not support currently\n"); /// only 1 main tensor on lhs - main_tensors_lhs.push_back(tensors_lhs[0]); - } - comet_debug() << " main_tensors_lhs.size(): " << main_tensors_lhs.size() << "\n"; - - main_tensors_all = main_tensors_rhs; - main_tensors_all.insert(main_tensors_all.end(), main_tensors_lhs.begin(), main_tensors_lhs.end()); - comet_debug() << " main_tensors_all.size(): " << main_tensors_all.size() << "\n"; - } - - /// In genCmptOps, get for-loops' value access indices. - /// A value access index is not necessarily the for-loop's induction variable. - /// For example, To access sparse matrix C.val, we need to get rowptr = C.col[idx], then rowptr is the access index. - /// This function is used both by Numeric Phase and Symbolic Phase. 
- void getForLoopsValueAccessIdx(OpBuilder &builder, - Location &loc, - int main_tensor_nums, - std::vector> &allPerms, - std::vector> &allFormats, - std::vector &main_tensors_all, - std::vector &nested_forops, - std::vector &nested_AccessIdx, - std::vector &nested_forops_indices, - std::vector> &main_tensors_all_Allocs, - std::vector> &allAccessIdx /* output */, - std::vector> &allValueAccessIdx /* output */) + /// ----------------- /// + /// Add declaration of the function comet_index_func; + /// ----------------- /// + void declareSortFunc(ModuleOp &module, + MLIRContext *ctx, + Location loc) { + IndexType indexType = IndexType::get(ctx); - std::vector> allLoopsArg(main_tensor_nums); /// inner - /// std::vector> allAccessIdx(main_tensor_nums); /// output - for (unsigned int i = 0; i < main_tensors_all.size(); i++) - { - for (unsigned int j = 0; j < allPerms[i].size(); j++) - { - unsigned int index_loc = findIndexInVector(nested_forops_indices, allPerms[i][j]); - comet_debug() << " index_loc " << index_loc << "\n"; - comet_debug() << " Perm: " << allPerms[i][j] << "\n"; - comet_debug() << " Format: " << allFormats[i][j] << "\n"; - assert(index_loc < nested_forops.size() && - "index_loc < nested_forops.size(), i.e. 
the index not exist in nested for loop\n"); - allLoopsArg[i].push_back(nested_forops[index_loc].getInductionVar()); - allAccessIdx[i].push_back(nested_AccessIdx[index_loc]); - } - /// Consider for the case w_index_list_size - /// if allPerms[i].size() == 0 - } - - /// std::vector> allValueAccessIdx(main_tensor_nums); /// output - for (int i = 0; i < main_tensor_nums; i++) - { /// If constantOp, do not consider it - comet_debug() << " "; - comet_vdump(main_tensors_all[i]); - if (main_tensors_all[i].getType().isa()) - { /// sparse tensor - - /// Find the last sparse index m, then loop_arg * all dense loop args - unsigned lastSparseIndexLoc = allPerms[i].size(); - for (int d = (int)allPerms[i].size() - 1; d >= 0; d--) - { - if (allFormats[i][d].compare(0, 1, "D") != 0 && - allFormats[i][d].compare(0, 1, "S") != 0) - { /// sparse dimension and has a loop, i.e. "CU" or "CN" - lastSparseIndexLoc = d; - break; - } - } - /// Calculate for ModeGeneric style format: [CN, S, D (, ... ) ] - auto valueAccessIdx_part = allLoopsArg[i][lastSparseIndexLoc]; - if (lastSparseIndexLoc < allPerms[i].size() - 1) - { /// There is dense index after the sparse index - unsigned int last_d = lastSparseIndexLoc + 1; - for (unsigned int d = lastSparseIndexLoc + 1; d < allPerms[i].size(); d++) - { /// i=0 - if (allFormats[i][d].compare(0, 1, "D") == 0) - { - /// Get dense dim size - auto index_0 = builder.create(loc, 0); - std::vector upper_indices = {index_0}; - auto upperBound = builder.create(loc, main_tensors_all_Allocs[i][4 * d], upper_indices); - comet_vdump(upperBound); - valueAccessIdx_part = builder.create(loc, upperBound, valueAccessIdx_part); - last_d = d; - } - } - if (allFormats[i][last_d].compare(0, 1, "D") == 0) - { - comet_debug() << " "; - comet_vdump(allLoopsArg[i][allLoopsArg[i].size() - 1]); - comet_vdump(valueAccessIdx_part); - valueAccessIdx_part = builder.create(loc, allLoopsArg[i][allLoopsArg[i].size() - 1], - valueAccessIdx_part); - comet_debug() << " AddIOps 
(valueAccessIdx_part): "; - comet_vdump(valueAccessIdx_part); - } - } - - allValueAccessIdx[i].push_back(valueAccessIdx_part); - } - else if (main_tensors_all[i].getType().isa()) - { /// dense tensor - allValueAccessIdx[i] = allAccessIdx[i]; - } - } - -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - for (unsigned int i = 0; i < allValueAccessIdx.size(); i++) + /// Declare comet_sort_index() + auto sort_index_func = FunctionType::get(ctx, + {UnrankedMemRefType::get(indexType, 0), indexType, indexType} /* inputs */, {} /* return */); + std::string func_name = "comet_sort_index"; + if (!hasFuncDeclaration(module, func_name /* func name */)) { - comet_debug() << "allValueAccessIdx[" << i << "].size(): " << allValueAccessIdx[i].size() - << ", main_tensors_all_Allocs[" << i << "].size()-1: " << main_tensors_all_Allocs[i].size() - 1 - << "\n"; + func::FuncOp func_declare = func::FuncOp::create(loc, + func_name, + sort_index_func, + ArrayRef{}); + func_declare.setPrivate(); + module.push_back(func_declare); } -#endif } - /// In genCmptOps, get current compute node's symbolic nested for-loop and access indices. 
- void getSymbolicNestedForOpsAndAccessIdx(std::vector &ancestorsWps, - std::vector &ancestorsOps, - std::vector &nested_forops /* output */, - std::vector &nested_AccessIdx /* output */, - std::vector &nested_forops_indices /* output */) + Value getSemiringSecondVal(OpBuilder &builder, Location &loc, + llvm::StringRef &semiringSecond, Value &Input0, Value &Input1) { - for (unsigned int i = 0; i < ancestorsOps.size(); i++) - { - comet_debug() << " ancestorsOps[" << i << "]->forOps.size(): " << ancestorsOps[i]->symbolicForOps.size() - << ", ancestorsOps->id: " - << ancestorsOps[i]->id << "\n"; - if (!ancestorsOps[i]->symbolicForOps.empty()) - { /// for loops OpsTree node - for (int j = ancestorsOps[i]->symbolicForOps.size() - 1; j >= 0; j--) - { - comet_debug() << " j: " << j << "\n"; - nested_forops.push_back(ancestorsOps[i]->symbolicForOps[j]); - comet_debug() << "AccessIdx: " << ancestorsOps[i]->symbolicAccessIdx[j] << "\n"; - nested_AccessIdx.push_back(ancestorsOps[i]->symbolicAccessIdx[j]); - } - } - } - comet_debug() << " nested_forops.size(): " << nested_forops.size() << "\n"; - /// std::vector nested_forops_indices; - for (unsigned int i = 0; i < ancestorsWps.size(); i++) + Value elementWiseResult; + if (semiringSecond == "times") { - comet_debug() << " "; - comet_vdump(ancestorsWps[i]); - - if (indexTree::IndexTreeIndicesOp cur_op = dyn_cast( - ancestorsWps[i].getDefiningOp())) - { - /// Get indices - ArrayAttr op_indices = cur_op.getIndices(); - - if (op_indices.size() > 0) - { /// for loops OpsTree node - for (int j = op_indices.size() - 1; j >= 0; j--) - { - /// Get the indices; - int64_t idx = op_indices[j].cast().getInt(); - nested_forops_indices.push_back(idx); - } - } - } + elementWiseResult = builder.create(loc, Input0, Input1); } - } - - /// In genCmptOps, generate code for a compute node that does general A = 0.0 but without workspace transformation. 
- void genCmptOpGeneralInitialAssignment(OpBuilder &builder, - Location &loc, - int lhs_loc, - ConstantOp &cstop, - std::vector &nested_forops, - std::vector> &main_tensors_all_Allocs, - std::vector> &allValueAccessIdx) - { - /// Generate Store 1.0, A[...] this op - /// this case: allPerms[0] is empty, allFormats[0] is empty - comet_debug() << " cstop.getValue(): " << cstop.getValue() << "\n"; - comet_debug() << " "; - comet_vdump(main_tensors_all_Allocs[lhs_loc][main_tensors_all_Allocs[lhs_loc].size() - 1]); - - if (allValueAccessIdx[lhs_loc].size() > 0) + else if (semiringSecond == "first") { - builder.create(loc, cstop, - main_tensors_all_Allocs[lhs_loc][main_tensors_all_Allocs[lhs_loc].size() - - 1], - allValueAccessIdx[lhs_loc]); + elementWiseResult = Input0; } - else + else if (semiringSecond == "second") { - Value local_accessIdx = nested_forops[0].getInductionVar(); - insertInitialize(loc, - cstop, - main_tensors_all_Allocs[lhs_loc][main_tensors_all_Allocs[lhs_loc].size() - 1], - local_accessIdx, - builder, - false /* use_dynamic_init */, - nullptr /* dynamic_init */); + elementWiseResult = Input1; } - } - - /// In genCmptOps, get LHS nnz value and data array before gathering results from the workspace. 
- void getLHSBeforeGatherFromWorkspace(OpBuilder &builder, - Location &loc, - int lhs_loc, - Value lhs, - std::vector> &main_tensors_all_Allocs, - unsigned int &lhs_2crd_size_loc /* output */, - unsigned int &lhs_2pos_size_loc /* output */, - Value &lhs_nnz /* output */, - Value &lhs_nnz_alloc /* output */, - Value &lhs_val /* output */) - { - /// Get tensor ranks - auto sp_op = cast(lhs.getDefiningOp()); - int lhs_ranks = sp_op.getTensorRank(); - - //[0...2d,2d+1...4d+1,4d+2...5d+1] - unsigned int lhs_val_size_loc = 8 * lhs_ranks + 1; /// 17 (2d) /// 15 - lhs_2crd_size_loc = 7 * lhs_ranks; /// 14 (2d) /// 12 /// output - lhs_2pos_size_loc = 7 * lhs_ranks - 1; /// 13 (2d) /// 11 /// output - - /// [0...2d, 2d+1...4d+1, 4d+2...5d+1] - comet_pdump(lhs.getDefiningOp()); - comet_pdump(lhs.getDefiningOp()->getParentOp()); - comet_vdump(lhs.getDefiningOp()->getOperand(lhs_val_size_loc)); - - Value lhs_nnz_operand = lhs.getDefiningOp()->getOperand(lhs_val_size_loc); - Value lhs_nnz_op; - comet_vdump(lhs_nnz_operand); - if (isa(lhs_nnz_operand.getDefiningOp())) + else if (semiringSecond == "atan2") { - lhs_nnz_op = lhs_nnz_operand.getDefiningOp()->getOperand(0); + elementWiseResult = builder.create(loc, Input0, Input1); } - else + else if (semiringSecond == "div") { - lhs_nnz_op = lhs_nnz_operand; + elementWiseResult = builder.create(loc, Input0, Input1); } - comet_vdump(lhs_nnz_op); - auto lhs_nnz_load = cast(lhs_nnz_op.getDefiningOp()); /// index - lhs_nnz_alloc = cast(lhs_nnz_load.getMemRef().getDefiningOp()); /// index /// output - - Value cst_0_index = builder.create(loc, 0); - lhs_nnz = builder.create(loc, lhs_nnz_alloc, ValueRange{cst_0_index}); /// output - - lhs_val = main_tensors_all_Allocs[lhs_loc].back(); /// output - comet_vdump(lhs_val); - } - - /// In genCmptOps, generate code for Cij = Wj when they are both dense. 
- void genCmptOpGatherFromDenseToDense(OpBuilder &builder, - Location &loc, - int rhs_loc, - int lhs_loc, - std::vector> &main_tensors_all_Allocs, - std::vector> &allValueAccessIdx) - { - /// %1 = load b[...] - /// store %1, a[...] - comet_debug() << " main_tensors_all_Allocs[" << rhs_loc << "].size(): " - << main_tensors_all_Allocs[rhs_loc].size() << ", allValueAccessIdx[" << rhs_loc << "].size(): " - << allValueAccessIdx[rhs_loc].size() << "\n"; - - Value rhs_value = builder.create(loc, main_tensors_all_Allocs[rhs_loc].back(), allValueAccessIdx[rhs_loc]); - comet_vdump(rhs_value); - - comet_vdump(main_tensors_all_Allocs[lhs_loc].back()); -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto s1 = builder.create(loc, rhs_value, main_tensors_all_Allocs[lhs_loc].back(), allValueAccessIdx[lhs_loc]); - comet_vdump(s1); -#else - builder.create(loc, rhs_value, main_tensors_all_Allocs[lhs_loc].back(), allValueAccessIdx[lhs_loc]); -#endif - } - - /// Used by genCmptOps, for Cij = Wj without Workspace Transformation - void genCmptOpGatherFromDenseToOutput(OpBuilder &builder, - Location &loc, - int rhs_loc, - int lhs_loc, - unsigned int lhs_2crd_size_loc, - unsigned int lhs_2pos_size_loc, - Value lhs, - Value lhs_nnz, - Value lhs_nnz_alloc, - Value lhs_val, - std::vector> &allFormats, - std::vector> &main_tensors_all_Allocs, - std::vector> &allAccessIdx, - std::vector> &allValueAccessIdx, - std::vector &nested_forops) - { - - /// %1 = load b[...] 
- /// if(%1 != 0) { - /// Cnnz = load Cop.operand(4d+1) - /// store %1, cval[Cnnz] - /// store Cnnz+1, Cop.operand(4d+1) - /// } - comet_debug() << " main_tensors_all_Allocs[" << rhs_loc << "].size(): " - << main_tensors_all_Allocs[rhs_loc].size() << ", allValueAccessIdx[" << rhs_loc - << "].size(): " << allValueAccessIdx[rhs_loc].size() << "\n"; - Value rhs_value = builder.create(loc, main_tensors_all_Allocs[rhs_loc][main_tensors_all_Allocs[rhs_loc].size() - 1], allValueAccessIdx[rhs_loc]); - comet_debug() << " "; - comet_vdump(rhs_value); - auto f64Type = builder.getF64Type(); - Value const_f64_0 = builder.create(loc, f64Type, builder.getF64FloatAttr(0)); - Value isNonzero = builder.create(loc, CmpFPredicate::ONE, rhs_value, const_f64_0); - comet_debug() << " "; - comet_vdump(isNonzero); - auto if_nonzero = builder.create(loc, isNonzero, /*WithElseRegion*/ false); - comet_debug() << " If branch:\n"; - comet_vdump(if_nonzero); - - if (!if_nonzero.getThenRegion().empty()) + else if (semiringSecond == "eq") { - auto last_insertionPoint = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(&if_nonzero.getThenRegion().front()); - - builder.create(loc, rhs_value, lhs_val, ValueRange{lhs_nnz}); - - /// update pos/crd arrays - /// Fill C2crd in CSR format, parent loop's accessIdx - /// Check format j in the output - if (allFormats[lhs_loc][allFormats[lhs_loc].size() - 1].compare(0, 2, "CU") == 0) - { - Value crd_index = allAccessIdx[allAccessIdx.size() - 1][allAccessIdx[allAccessIdx.size() - 1].size() - - 1]; - comet_debug() << " "; - comet_vdump(crd_index); - Value lhs_2crd = main_tensors_all_Allocs[lhs_loc][main_tensors_all_Allocs[lhs_loc].size() - 4]; //-2 - comet_debug() << " "; - comet_vdump(lhs_2crd); - - builder.create(loc, crd_index, lhs_2crd, ValueRange{lhs_nnz}); - } - - comet_debug() << "\n"; - Value cst_1_index = builder.create(loc, 1); - comet_debug() << " "; - comet_vdump(lhs_nnz); - Value lhs_nnz_new = builder.create(loc, lhs_nnz, 
cst_1_index); - comet_debug() << " AddIOps: (lhs_nnz_new)"; - comet_vdump(lhs_nnz_new); - comet_debug() << " "; - comet_vdump(lhs_nnz_alloc); - - Value cst_0_index = builder.create(loc, 0); - builder.create(loc, lhs_nnz_new, lhs_nnz_alloc, ValueRange{cst_0_index}); - - comet_debug() << "\n"; - Value lhs_2crd = lhs.getDefiningOp()->getOperand(lhs_2crd_size_loc); - Value lhs_2crd_op; - comet_vdump(lhs_2crd); - if (isa(lhs_2crd.getDefiningOp())) - { - lhs_2crd_op = lhs_2crd.getDefiningOp()->getOperand(0); - } - else - { - lhs_2crd_op = lhs_2crd; - } - comet_debug() << " "; - comet_vdump(lhs_2crd_op); - auto c2crd_size_load = cast(lhs_2crd_op.getDefiningOp()); /// index - Value c2crd_size_alloc = cast(c2crd_size_load.getMemRef().getDefiningOp()); /// index - comet_debug() << " "; - comet_vdump(c2crd_size_alloc); - - builder.create(loc, lhs_nnz_new, c2crd_size_alloc, ValueRange{cst_0_index}); - - comet_debug() << " \n"; - builder.restoreInsertionPoint(last_insertionPoint); + elementWiseResult = builder.create(loc, CmpFPredicate::OEQ, Input0, Input1); } - - comet_debug() << " \n"; - auto prev_forop = nested_forops[nested_forops.size() - 1 - 1]; - builder.setInsertionPointAfter(prev_forop); - - comet_debug() << " "; - comet_vdump(lhs.getDefiningOp()->getOperand(lhs_2pos_size_loc)); - Value lhs_2pos_0 = lhs.getDefiningOp()->getOperand(lhs_2pos_size_loc); - Value lhs_2pos_op; - comet_vdump(lhs_2pos_0); - if (isa(lhs_2pos_0.getDefiningOp())) + else if (semiringSecond == "ge") { - lhs_2pos_op = lhs_2pos_0.getDefiningOp()->getOperand(0); + elementWiseResult = builder.create(loc, CmpFPredicate::OGE, Input0, Input1); } - else + else if (semiringSecond == "gt") { - lhs_2pos_op = lhs_2pos_0; + elementWiseResult = builder.create(loc, CmpFPredicate::OGT, Input0, Input1); } - comet_debug() << " "; - comet_vdump(lhs_2pos_op); - auto c2pos_size_load = cast(lhs_2pos_op.getDefiningOp()); /// index - Value c2pos_size_alloc = cast(c2pos_size_load.getMemRef().getDefiningOp()); /// index - 
Value cst_0_index = builder.create(loc, 0); - Value c2pos_size_value = builder.create(loc, c2pos_size_alloc, ValueRange{cst_0_index}); - - Value lhs_2crd = lhs.getDefiningOp()->getOperand(lhs_2crd_size_loc); - Value lhs_2crd_op; - comet_vdump(lhs_2crd); - if (isa(lhs_2crd.getDefiningOp())) + else if (semiringSecond == "le") { - lhs_2crd_op = lhs_2crd.getDefiningOp()->getOperand(0); + elementWiseResult = builder.create(loc, CmpFPredicate::OLE, Input0, Input1); } - else + else if (semiringSecond == "lt") { - lhs_2crd_op = lhs_2crd; + elementWiseResult = builder.create(loc, CmpFPredicate::OLT, Input0, Input1); } - comet_debug() << " "; - comet_vdump(lhs_2crd_op); - auto c2crd_size_load = cast(lhs_2crd_op.getDefiningOp()); /// index - Value c2crd_size_alloc = cast(c2crd_size_load.getMemRef().getDefiningOp()); /// index - Value c2crd_size_nnz = builder.create(loc, c2crd_size_alloc, ValueRange{cst_0_index}); - - /// store crd_size into pos - Value lhs_2pos = main_tensors_all_Allocs[lhs_loc][main_tensors_all_Allocs[lhs_loc].size() - 5]; /// -3 - comet_debug() << " "; - comet_vdump(lhs_2pos); - - builder.create(loc, c2crd_size_nnz, lhs_2pos, ValueRange{c2pos_size_value}); - - Value cst_1_index = builder.create(loc, 1); - comet_debug() << " "; - comet_vdump(c2pos_size_value); - Value c2pos_size_value_new = builder.create(loc, c2pos_size_value, cst_1_index); - comet_debug() << " AddIOps (c2pos_size_value_new): "; - comet_vdump(c2pos_size_value_new); - - builder.create(loc, c2pos_size_value_new, c2pos_size_alloc, ValueRange{cst_0_index}); - } - - /// From the W_id_list_size, get the output C and C.rowptr, C.col, and C.val. 
- /// ----------------- /// - /// %55 = "it.ComputeLHS"(%53) {allFormats = [[]], allPerms = [[]]} : (tensor<1xindex>) -> tensor<*xf64> - /// %56 = "it.Compute"(%54, %55) {MaskType = "none", comp_worksp_opt = true, semiring = "noop_times"} : (tensor<*xindex>, tensor<*xf64>) -> i64 - /// %70 = "it.ComputeRHS"(%50, %51, %52, %53) {allFormats = [["D"]], allPerms = [[1]]} : (tensor, tensor, tensor, tensor<1xindex>) -> tensor<*xf64> - /// %93 = "it.Compute"(%70, %92) {MaskType = "none", comp_worksp_opt = true, semiring = "noop_times"} : (tensor<*xf64>, tensor<*xf64>) -> i64 - /// %92 = "it.ComputeLHS"(%91) {allFormats = [["D", "CU"]], allPerms = [[0, 1]]} : (!ta.sptensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, index, index, index, index, index, index, index, index, index, index, index>) -> tensor<*xf64> - /// %91 = ta.sptensor_construct(%73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %11, %12) {tensor_rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, index, index, index, index, index, index, index, index, index, index, index) -> (!ta.sptensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, index, index, index, index, index, index, index, index, index, index, index>) - /// %77 = bufferization.to_tensor %alloc_156 : memref - /// %alloc_156 = memref.alloc(%71) : memref - void getOutputMtxCRowptrAndDims(indexTree::IndexTreeComputeOp &cur_op, - Value &W_id_list_size, - SymbolicInfo &symbolicInfo /* output */) - { - Value mtxC = nullptr; - for (Operation *u_rhs : W_id_list_size.getUsers()) + else if (semiringSecond == "land") { - if (indexTree::IndexTreeComputeRHSOp rhs_op = dyn_cast(u_rhs)) - { - /// rhs_op is %70 - for (Operation *u_cmpt : u_rhs->getUsers()) - { - if (indexTree::IndexTreeComputeOp cmpt_op = dyn_cast(u_cmpt)) - { - /// cmpt_op is %93 - /// then %93's Operand[1] is %92 - /// %92's Operand[0] is %91 which is the sparse tensor - Value 
lhs_op = cmpt_op.getOperand(1); /// lhs_op is %92 - mtxC = lhs_op.getDefiningOp()->getOperand(0); /// mtxC is %91 - break; - } - } - } + /// land requires integer type input + llvm::errs() << "Not supported semiring operator (only works for int datatypes): " + << "land" + << "\n"; + /// we should not proceed forward from this point to avoid faulty behavior. + exit(1); } - - assert(mtxC && "Error: cannot find mtxC as the output."); - /// %77 is mtxC.getDefiningOp()->getOperand(A2POS) - /// %alloc_156 is C_rowptr - /// %71 is mtxC_rowptr_size - Value C_rowptr = mtxC.getDefiningOp()->getOperand(CSR_A2POS).getDefiningOp()->getOperand(0); /// A2POS is rowptr's location - Value C_rowptr_size = C_rowptr.getDefiningOp()->getOperand(0); - Value C_num_rows = mtxC.getDefiningOp()->getOperand(CSR_DIM1_SIZE); - Value C_num_cols = mtxC.getDefiningOp()->getOperand(CSR_DIM2_SIZE); - symbolicInfo.mtxC = mtxC; - symbolicInfo.mtxC_rowptr = C_rowptr; - symbolicInfo.mtxC_rowptr_size = C_rowptr_size; - symbolicInfo.mtxC_num_rows = C_num_rows; - symbolicInfo.mtxC_num_cols = C_num_cols; + else if (semiringSecond == "lor") { - comet_vdump(mtxC); - comet_vdump(C_rowptr); - comet_vdump(C_rowptr_size); - comet_vdump(C_num_rows); - comet_vdump(C_num_cols); + /// lor requires integer type input + llvm::errs() << "Not supported semiring operator (only works for int datatypes): " + << "lor" + << "\n"; + /// we should not proceed forward from this point to avoid faulty behavior. + exit(1); } - } - - /// Generate mark before the outer-most symbolic for-loop, - /// and update mark for every idx at the beginning of the outer-most symbolic for-loop. 
- void genSymbolicMarkAndUpdate(OpBuilder &builder, - Location &loc, - /// std::vector &symbolic_nested_forops, /* from innermost to outermost */ - scf::ForOp &outermost_forLoop, /// the outermost for-loop - Value &mark_alloc /* output */, - Value &mark_new_val /* output */) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion point before the outer-most symbolic for-loop - builder.setInsertionPoint(outermost_forLoop); - - /// Generate the variable mark - /// %mark = memref.alloc() : memref<1xindex> - /// memref.store %c0, %mark[%c0] : memref<1xindex> - MemRefType memTy_1xindex = MemRefType::get({1}, builder.getIndexType()); - mark_alloc = builder.create(loc, memTy_1xindex); - Value const_index_0 = builder.create(loc, 0); - builder.create(loc, - const_index_0, - mark_alloc, - ValueRange{const_index_0}); + else if (semiringSecond == "lxor") { - comet_vdump(mark_alloc); + /// lxor requires integer type input + llvm::errs() << "Not supported semiring operator: " + << "lxor" + << "\n"; } - - /// Generate updating mark += 2 - /// %c2 = arith.constant 2 : index - /// %old_val = memref.load %mark[%c0] : memref<1xindex> - /// %new_mark = arith.addi %old_val, %c2 : index - /// memref.store %new_mark, %mark[%c0] : memref<1xindex> - builder.setInsertionPointToStart(outermost_forLoop.getBody()); - Value const_index_2 = builder.create(loc, 2); - Value old_mark_val = builder.create(loc, mark_alloc, ValueRange{const_index_0}); - mark_new_val = builder.create(loc, old_mark_val, const_index_2); - builder.create(loc, - mark_new_val, - mark_alloc, - ValueRange{const_index_0}); + else if (semiringSecond == "minxy") { - comet_vdump(outermost_forLoop); + Value cmp = builder.create(loc, CmpFPredicate::OLT, Input0, Input1); + elementWiseResult = builder.create(loc, cmp, Input0, Input1); } - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - - /// Generate symbolic if statement 
condition in the CmptOp - /// -------No masking---------- /// - /// if (mark_array[j_idx] != mark) { - /// mark_array[j_idx] = mark; /// C[i_idx, j_idx] has been visited - /// W_id_list_size += 1; - /// } - /// -------Push masking---------- /// - /// if (mark_array[j_idx] == mark) { - /// mark_array[j_idx] = mark + 1; /// C[i_idx, j_idx] has been visited - /// W_id_list_size += 1; - /// } - void genSymbolicIfStatementCondition(OpBuilder &builder, - Location &loc, - scf::ForOp &semiringLoop, /// symbolic_nested_forops[0] - Value &mark_array_alloc, /// tensors_lhs_Allocs[1][0] - Value &valueAccessIdx, /// allValueAccessIdx[lhs_loc][0] - Value &mark_new_val, - scf::IfOp &if_statement /* output */, - MaskingInfo &maskingInfo) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion point at the end of the inner-most symbolic for-loop - builder.setInsertionPoint(semiringLoop.getBody()->getTerminator()); - + else if (semiringSecond == "max") { - comet_vdump(semiringLoop); + Value cmp = builder.create(loc, CmpFPredicate::OGT, Input0, Input1); + elementWiseResult = builder.create(loc, cmp, Input0, Input1); } - /// Generate If statement condition - Value ele_mark_val = builder.create(loc, mark_array_alloc, ValueRange{valueAccessIdx}); - - if (PUSH_BASED_MASKING == maskingInfo.mask_type) + else if (semiringSecond == "ne") { - Value equal_mask = builder.create(loc, - CmpIPredicate::eq, - ele_mark_val, - mark_new_val); - if_statement = builder.create(loc, equal_mask, false /* No Else Region */); + elementWiseResult = builder.create(loc, CmpFPredicate::ONE, Input0, Input1); } - else if (NO_MASKING == maskingInfo.mask_type) + else if (semiringSecond == "minus") { - - Value not_equal_mark = builder.create(loc, - CmpIPredicate::ne, - ele_mark_val, - mark_new_val); - if_statement = builder.create(loc, not_equal_mark, false /* No Else Region */); + elementWiseResult = builder.create(loc, Input0, Input1); } - else + 
else if (semiringSecond == "plusxy") { - llvm::errs() << "Error: mask_type " << maskingInfo.mask_type << " is not supported.\n"; + elementWiseResult = builder.create(loc, Input0, Input1); } + else if (semiringSecond == "pairxy") { - comet_vdump(ele_mark_val); - comet_vdump(if_statement); - comet_vdump(semiringLoop); + elementWiseResult = builder.create(loc, builder.getF64Type(), builder.getF64FloatAttr(1)); } - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - - /// Generate symbolic if statement then region in the CmptOp - /// -------No masking---------- /// - /// if (mark_array[j_idx] != mark) { - /// mark_array[j_idx] = mark; /// C[i_idx, j_idx] has been visited - /// W_id_list_size += 1; - /// } - /// -------Push masking---------- /// - /// if (mark_array[j_idx] == mark) { - /// mark_array[j_idx] = mark + 1; /// C[i_idx, j_idx] has been visited - /// W_id_list_size += 1; - /// } - void genSymbolicIfStatementThenRegion(OpBuilder &builder, - Location &loc, - scf::IfOp &if_statement, - Value &mark_array_alloc, /// tensors_lhs_Allocs[1][0] - Value &valueAccessIdx, /// allValueAccessIdx[lhs_loc][0] - Value &W_id_list_size, - Value &mark_new_val, - MaskingInfo &maskingInfo) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion point to the beginning of the if statement then region - builder.setInsertionPointToStart(&if_statement.getThenRegion().front()); - - if (PUSH_BASED_MASKING == maskingInfo.mask_type) + else if (semiringSecond == "pow") { - /// mark_array[j_idx] = mark + 1; - Value const_index_1 = builder.create(loc, 1); - Value mark_value_plus_one = builder.create(loc, mark_new_val, const_index_1); - builder.create(loc, - mark_value_plus_one, - mark_array_alloc, - ValueRange{valueAccessIdx}); + elementWiseResult = builder.create(loc, Input0, Input1); } - else if (NO_MASKING == maskingInfo.mask_type) + else { - /// mark_array[j_idx] = mark - 
builder.create(loc, - mark_new_val, - mark_array_alloc, - ValueRange{valueAccessIdx}); + llvm::errs() << "Not supported semiring operator: " << semiringSecond << "\n"; + /// we should not proceed forward from this point to avoid faulty behavior. } - /// W_id_list_size += 1; - Value const_index_0 = builder.create(loc, 0); - Value const_index_1 = builder.create(loc, 1); - Value old_val = builder.create(loc, W_id_list_size, ValueRange{const_index_0}); - Value new_val = builder.create(loc, old_val, const_index_1); - builder.create(loc, - new_val, - W_id_list_size, - ValueRange{const_index_0}); - - { - comet_vdump(if_statement); - } - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); + return elementWiseResult; } - /// Updating output - /// C.rowptr[idx] = W_id_list_size; - void genSymbolicUpdateCRowptr(OpBuilder &builder, - Location &loc, - scf::ForOp &outermost_forLoop, - Value &mtxC_rowptr, - Value &valueAccessIdx, - Value &W_id_list_size) + Value getSemiringFirstVal(OpBuilder &builder, Location &loc, + llvm::StringRef &semiringFirst, Value &Input0, Value &Input1) { + + Value reduceResult; + if (semiringFirst == "times") { - comet_vdump(mtxC_rowptr); - comet_vdump(valueAccessIdx); - comet_vdump(W_id_list_size); + reduceResult = builder.create(loc, Input0, Input1); } - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion point at the end of the outermost for-loop body - builder.setInsertionPoint(outermost_forLoop.getBody()->getTerminator()); - - Value const_index_0 = builder.create(loc, 0); - Value rowptr_val = builder.create(loc, W_id_list_size, ValueRange{const_index_0}); - builder.create(loc, - rowptr_val, - mtxC_rowptr, - ValueRange{valueAccessIdx}); - + else if (semiringFirst == "plusxy") { - comet_vdump(outermost_forLoop); + reduceResult = builder.create(loc, Input0, Input1); } - /// Restore the insertion point - 
builder.restoreInsertionPoint(last_insertion_point); - } - - /// Generate the reduce of the output C.rowptr after the outermost for-loop - /// C.rowptr[M] = 0; - /// int C_val_size = 0; - /// for (int i_idx = 0; i_idx < M + 1; ++i_idx) { - /// int curr = C.rowptr[i_idx]; - /// C.rowptr[i_idx] = C_val_size; - /// C_val_size += curr; - /// } - /// C.col = new int[C_val_size] - /// C.val = new f64[C_val_size] - void genSymbolicReduceOutputCRowptrCColCVal(OpBuilder &builder, - Location &loc, - scf::ForOp &outermost_forLoop, - SymbolicInfo &symbolicInfo /* output */) - { - Value const_index_0 = builder.create(loc, 0); - Value const_index_1 = builder.create(loc, 1); - - /// C.rowptr[M] = 0 - Value &mtxC_rowptr = symbolicInfo.mtxC_rowptr; - Value &num_rows = symbolicInfo.mtxC_num_rows; - builder.create(loc, - const_index_0, - mtxC_rowptr, - ValueRange{num_rows}); - - /// C_val_size = 0; - MemRefType memTy_1xindex = MemRefType::get({1}, builder.getIndexType()); - Value C_val_size = builder.create(loc, memTy_1xindex); - builder.create(loc, - const_index_0, - C_val_size, - ValueRange{const_index_0}); - - /// for (int i_idx = 0; i_idx < M + 1; ++i_idx) { - /// int curr = C.rowptr[i_idx]; - /// C.rowptr[i_idx] = C_val_size; - /// C_val_size += curr; - /// } - Value &num_rows_plus_one = symbolicInfo.mtxC_rowptr_size; - scf::ForOp reduce_forLoop = builder.create(loc, - const_index_0 /* lowerBound */, - num_rows_plus_one /* upperBound */, - const_index_1 /* step */); - builder.setInsertionPointToStart(reduce_forLoop.getBody()); - Value i_idx = reduce_forLoop.getInductionVar(); - Value curr = builder.create(loc, mtxC_rowptr, ValueRange{i_idx}); - Value size_val = builder.create(loc, C_val_size, ValueRange{const_index_0}); - builder.create(loc, - size_val, - mtxC_rowptr, - ValueRange{i_idx}); - Value new_val = builder.create(loc, curr, size_val); - builder.create(loc, - new_val, - C_val_size, - ValueRange{const_index_0}); + else if (semiringFirst == "minxy") { - 
comet_vdump(reduce_forLoop); + Value cmp = builder.create(loc, CmpFPredicate::OLT, Input0, Input1); + reduceResult = builder.create(loc, cmp, Input0, Input1); } - builder.setInsertionPointAfter(reduce_forLoop); - Value mtxC_val_size = builder.create(loc, C_val_size, ValueRange{const_index_0}); - symbolicInfo.mtxC_val_size = mtxC_val_size; - - /// Allocate new C.col and new C.val - MemRefType memTy_alloc_dynamic_index = MemRefType::get({ShapedType::kDynamic}, builder.getIndexType()); - MemRefType memTy_alloc_dynamic_f64 = MemRefType::get({ShapedType::kDynamic}, builder.getF64Type()); - Value new_mtxC_col = builder.create(loc, - memTy_alloc_dynamic_index, - ValueRange{mtxC_val_size}); - Value new_mtxC_val = builder.create(loc, - memTy_alloc_dynamic_f64, - ValueRange{mtxC_val_size}); - symbolicInfo.mtxC_col = new_mtxC_col; - symbolicInfo.mtxC_val = new_mtxC_val; + else if (semiringFirst == "max") { - comet_vdump(mtxC_val_size); - comet_vdump(new_mtxC_col); - comet_vdump(new_mtxC_val); + Value cmp = builder.create(loc, CmpFPredicate::OGT, Input0, Input1); + reduceResult = builder.create(loc, cmp, Input0, Input1); } - } - - /// ----------------- /// - /// Store new mtxC_val_size to the old mtxC's C_col_size (A2crd_size) and C_val_size (Aval_size). - /// Just in case for safety. - /// ----------------- /// - void storeNewMtxCValeSizeToOldMtxC(OpBuilder &builder, - Location &loc, - SymbolicInfo &symbolicInfo) - { - Value &mtxC = symbolicInfo.mtxC; - Value &mtxC_val_size = symbolicInfo.mtxC_val_size; - Value const_index_0 = builder.create(loc, 0); - + else if (semiringFirst == "land") { - comet_vdump(mtxC); - comet_vdump(mtxC_val_size); + /// land requires integer type input + llvm::errs() << "Not supported semiring operator (only works for int datatypes): " + << "land" + << "\n"; + /// we should not proceed forward from this point to avoid faulty behavior. 
} - - /// Find the alloc of C_col_size (Arcrd_size) - /// %66 = memref.load %alloc_153[%c0_128] : memref<1xindex> - Value C_col_size_alloc = mtxC.getDefiningOp()->getOperand(CSR_A2CRD_SIZE).getDefiningOp()->getOperand(0); /// 8 - /// Store the new mtxC_val_size to C_col_size -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto store_C_col_size_alloc = builder.create(loc, - mtxC_val_size, - C_col_size_alloc, - ValueRange{const_index_0}); - comet_vdump(C_col_size_alloc); - comet_vdump(store_C_col_size_alloc); -#else - builder.create(loc, - mtxC_val_size, - C_col_size_alloc, - ValueRange{const_index_0}); -#endif - - /// Find the alloc of C_val_size (Aval_size) - /// %67 = memref.load %alloc_154[%c0_128] : memref<1xindex> - Value C_val_size_alloc = mtxC.getDefiningOp()->getOperand(CSR_AVAL_SIZE).getDefiningOp()->getOperand(0); /// 9 - /// Store the new mtxC_val_size to C_val_size -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - auto store_C_val_size_alloc = builder.create(loc, - mtxC_val_size, - C_val_size_alloc, - ValueRange{const_index_0}); - comet_vdump(C_val_size_alloc); - comet_vdump(store_C_val_size_alloc); -#else - builder.create(loc, - mtxC_val_size, - C_val_size_alloc, - ValueRange{const_index_0}); -#endif - } - - /// Dealloc the old C.val and C.col before the outermost_forLoop. - /// Replace the old C.val and C.col with new ones. 
- void deallocMtxCColCVal(OpBuilder &builder, - Location &loc, - scf::ForOp &outermost_forLoop, - SymbolicInfo &symbolicInfo) - { - /// Find old C.col and C.val - Value &mtxC = symbolicInfo.mtxC; - Value old_C_col = mtxC.getDefiningOp()->getOperand(CSR_A2CRD).getDefiningOp()->getOperand(0); - Value old_C_val = mtxC.getDefiningOp()->getOperand(CSR_AVAL).getDefiningOp()->getOperand(0); - - /// Dealloc old C.col and C.val - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion point before the symbolic outermost_forloop - builder.setInsertionPoint(outermost_forLoop); - - builder.create(loc, old_C_col); - builder.create(loc, old_C_val); - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - - /// -------------- /// - /// Remove mtxC_col's user who is a memref.store operation - /// This is very ad-hoc, just to avoid segmentation fault for old very large C.val array and C.col array. - /// -------------- /// - removeMemrefStoreUser(old_C_col); - removeMemrefStoreUser(old_C_col); - - /// Replace old C.col and C.val - /// Just in case of safety. - replaceOldValueToNewValue(old_C_col, symbolicInfo.mtxC_col); - replaceOldValueToNewValue(old_C_val, symbolicInfo.mtxC_val); - } - - /// Generate a new sparse tensor to replace the old output sparse tensor after the numeric outermost for-loop. 
- /// (e.g., ta.print(old_tensor) -> ta.print(new_tensor) - void genReplaceOutputSparseTensorToNewSparseTensor(OpBuilder &builder, - Location &loc, - scf::ForOp &numeric_outermost_forLoop, - SymbolicInfo &symbolicInfo) - { - /// Set the insertion point after the outermost_forloop - builder.setInsertionPointAfter(numeric_outermost_forLoop); - - Value &mtxC = symbolicInfo.mtxC; - Value &mtxC_col = symbolicInfo.mtxC_col; - Value &mtxC_val = symbolicInfo.mtxC_val; - - /// Generate the new mtxC_col and new mtxC_val bufferization.to_tensor - Value mtxC_col_buffer = builder.create(loc, mtxC_col); - Value mtxC_val_buffer = builder.create(loc, mtxC_val); - - auto sp_op = cast(mtxC.getDefiningOp()); - int tensorRanks = sp_op.getTensorRank(); - - /// Get the operands and their types for the sparse tensor ta.sptensor_construct() (which is mtxC). - SmallVector operands; - operands.insert(operands.end(), - mtxC.getDefiningOp()->getOperands().begin(), - mtxC.getDefiningOp()->getOperands().end()); - operands[CSR_A2CRD] = mtxC_col_buffer; /// 3 (A2crd) - operands[CSR_AVAL] = mtxC_val_buffer; /// 4 (AVal) - SmallVector elementTypes; - for (Value &opd : operands) + else if (semiringFirst == "lor") { - elementTypes.push_back(opd.getType()); + /// lor requires integer type input + llvm::errs() << "Not supported semiring operator (only works for int datatypes): " + << "lor" + << "\n"; + /// we should not proceed forward from this point to avoid faulty behavior. } - auto ty = tensorAlgebra::SparseTensorType::get(elementTypes); - Value sptensor = builder.create(loc, - ty, - operands, - tensorRanks); + else if (semiringFirst == "any") { - comet_vdump(mtxC_col_buffer); - comet_vdump(mtxC_val_buffer); - comet_vdump(sptensor); + reduceResult = Input1; } - - /// ----------------- /// - /// Find all users of the old sparse tensor mtxC, and replace those users' corresponding operands - /// to the new sparse tensor (sptensor). 
For example, - /// "ta.print"(%mtxC) => "ta.print"(%sptensor) - /// ----------------- /// - replaceOldValueToNewValue(mtxC, sptensor); - } - - /// Logistics of memory about old mtxC, mtxC.col, and mtxC.val - /// 1. Dealloc the old C.val and C.col before the outermost_forLoop. - /// 2. Change mtxC's old value in C_col_size (A2crd_size) and C_val_size (Aval_size) to new mtxC_val_size. - /// 3. Generate a new sparse tensor to replace the old output sparse tensor after the numeric outermost for-loop. - void logisticsForMtxCColCVal(OpBuilder &builder, - Location &loc, - scf::ForOp &symbolic_outermost_forLoop, - SymbolicInfo &symbolicInfo, - scf::ForOp &numeric_outermost_forLoop) - { - - /// Dealloc old C.col and C.val - /// Replace the old C.val and C.col with new ones. - deallocMtxCColCVal(builder, - loc, - symbolic_outermost_forLoop, - symbolicInfo); - - /// Change mtxC's old value in C_col_size (A2crd_size) and C_val_size (Aval_size) to new mtxC_val_size. - /// Just in case for safety. - storeNewMtxCValeSizeToOldMtxC(builder, - loc, - symbolicInfo); - - /// Generate a new sparse tensor to replace the old output sparse tensor after the numeric outermost for-loop. 
- /// (e.g., ta.print(old_tensor) -> ta.print(new_tensor) - genReplaceOutputSparseTensorToNewSparseTensor(builder, - loc, - numeric_outermost_forLoop, - symbolicInfo); - - /// builder.restoreInsertionPoint(last_insertion_point); - } - - /// Initialize the mark-array according to the mask at the beginning of the symbolic outermost for-loop - /// ----------------- /// - /// %j_loc_start = memref.load %mask_rowptr[%i_idx] : memref - /// %j_loc_bound = memref.load %mask_rowptr[%i_idx_plus_one] : memref - /// scf.for %j_loc = %j_loc_start to %j_loc_bound step %c1 { - /// %val = memref.load %mask_val[%j_loc] : memref - /// %70 = arith.cmpf une, %val, %cst : f64 - /// scf.if %70 { - /// %j_idx = memref.load %mask_col[%arg1] : memref - /// memref.store %mark, %mark_array[%j_idx] : memref - /// } - /// } - void genSymbolicInitMarkArrayByMask(OpBuilder &builder, - Location &loc, - scf::ForOp &symbolic_outermost_forLoop, - Value &outermost_forLoop_valueAccessIdx, - Value &mark_array_alloc, - Value &mark_new_val, - MaskingInfo &maskingInfo) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the Insertion Point at the beginning of the symbolic outermost for-loop but AFTER the mark_new_val - builder.setInsertionPointAfter(mark_new_val.getDefiningOp()); - - /// Generate the for-loop entry - Value &mask_rowptr = maskingInfo.mask_rowptr; - Value &mask_col = maskingInfo.mask_col; - Value &mask_val = maskingInfo.mask_val; - Value const_index_1 = builder.create(loc, 1); - Value &i_idx = outermost_forLoop_valueAccessIdx; - Value i_idx_plus_one = builder.create(loc, i_idx, const_index_1); - Value j_loc_start = builder.create(loc, mask_rowptr, ValueRange{i_idx}); - Value j_loc_bound = builder.create(loc, mask_rowptr, ValueRange{i_idx_plus_one}); - auto for_loop = builder.create(loc, - j_loc_start /* lower_bound */, - j_loc_bound /* upper_bound*/, - const_index_1 /* step */); + else if (semiringFirst == "noop") { - 
comet_vdump(j_loc_start); - comet_vdump(j_loc_bound); - comet_vdump(for_loop); + reduceResult = Input1; } - - /// Generate the for-loop body - builder.setInsertionPointToStart(for_loop.getBody()); - Value const_f64_0 = builder.create(loc, builder.getF64Type(), builder.getF64FloatAttr(0)); - Value j_loc = for_loop.getInductionVar(); - Value val = builder.create(loc, mask_val, ValueRange{j_loc}); - Value not_zero = builder.create(loc, CmpFPredicate::UNE, val, const_f64_0); - auto if_not_zero = builder.create(loc, not_zero, false /*NoElseRegion*/); - builder.setInsertionPointToStart(&if_not_zero.getThenRegion().front()); - Value j_idx = builder.create(loc, mask_col, ValueRange{j_loc}); - builder.create(loc, - mark_new_val, - mark_array_alloc, - ValueRange{j_idx}); - + else { - comet_vdump(val); - comet_vdump(if_not_zero); - comet_vdump(for_loop); - comet_vdump(symbolic_outermost_forLoop); + llvm::errs() << "Not supported semiring operator: " << semiringFirst << "\n"; + /// we should not proceed forward from this point to avoid faulty behavior. 
} - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); + return reduceResult; } - /// Generate the symbolic phase's kernel to compute the rowptr[i_idx] - void genSymbolicSemiringLoopBody(OpBuilder &builder, - Location &loc, - int lhs_loc, - std::vector> &tensors_lhs_Allocs, - std::vector &symbolic_nested_forops, - std::vector &symbolic_nested_AccessIdx, - std::vector> &symbolic_allValueAccessIdx, - SymbolicInfo &symbolicInfo, - std::vector &numeric_nested_forops, - MaskingInfo &maskingInfo) + struct LowerIndexTreeToSCFPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerIndexTreeToSCFPass) + void runOnOperation() override; - scf::ForOp &outermost_forLoop = symbolic_nested_forops.back(); - Value &outermost_forLoop_valueAccessIdx = symbolic_nested_AccessIdx.back(); - scf::ForOp &semiringLoop = symbolic_nested_forops[0]; - Value &mark_array = tensors_lhs_Allocs[1][0]; - Value &W_id_list_size = tensors_lhs_Allocs[3][0]; - Value &semiringLoop_valueAccessIdx = symbolic_allValueAccessIdx[lhs_loc][0]; + SetVector collectChildren(IndexTreeIndicesOp root); + void fillSubtree(Location loc, IRRewriter &rewriter, + const SetVector& subtree, + const SmallVector& old_outputs, + ValueRange new_outputs, + IRMapping& map); + void deleteDomain(Operation* op, IRRewriter &rewriter); + Value convertOperand(IndexTreeLHSOperandOp op, IRRewriter &rewriter); + Value convertOperand(IndexTreeOperandOp op, IRRewriter &rewriter); + mlir::LogicalResult convertCompute(Operation* op, IRRewriter &rewriter); + mlir::LogicalResult convertIndexNode(Operation* op, IRRewriter &rewriter); + }; +} - /// Generate mark before symbolic outer-most for-loop - Value mark_alloc; - Value mark_new_val; - genSymbolicMarkAndUpdate(builder, - loc, - outermost_forLoop, /// the outermost for-loop - mark_alloc /* output */, - mark_new_val /* output */); +Value +LowerIndexTreeToSCFPass::convertOperand(IndexTreeOperandOp op, IRRewriter &rewriter) +{ + 
Location loc = op->getLoc(); + Value tensor = op.getTensor(); + auto crds = op.getCrds(); + auto positions = op.getPos(); - if (PUSH_BASED_MASKING == maskingInfo.mask_type) + TensorType tensor_type; + if((tensor_type = llvm::dyn_cast(tensor.getType()))){ + return rewriter.create(loc, tensor_type.getElementType(), tensor, crds); + } else { + Type element_type; + if(llvm::isa(tensor.getType())) { - assert(symbolic_nested_forops.size() >= 2 && symbolic_allValueAccessIdx.size() >= 2 && - "Error: The symbolic for-loops should be at least 2 level.\n"); - - /// Initialize the mark-array according to the mask at the beginning of the symbolic outermost for-loop - genSymbolicInitMarkArrayByMask(builder, - loc, - outermost_forLoop, - outermost_forLoop_valueAccessIdx, - mark_array, - mark_new_val, - maskingInfo); + element_type = llvm::cast(tensor.getType()).getElementType(); + } else if(llvm::isa(tensor.getType())) + { + element_type = llvm::cast(tensor.getType()).getElementType(); } + Value pos = positions[positions.size() - 1]; + return rewriter.create(loc, element_type, tensor, pos); + } +} - /// Generate if statement condition - /// if (mark_array[j_idx] != mark) { - /// mark_array[j_idx] = mark; /// C[i_idx, j_idx] has been visited - /// W_id_list_size += 1; - /// } - scf::IfOp if_statement; - genSymbolicIfStatementCondition(builder, - loc, - semiringLoop, /// the inner-most for-loop (SemiringLoop) - mark_array, /// mark-array - semiringLoop_valueAccessIdx, /// value access index j_idx - mark_new_val, - if_statement /* output */, - maskingInfo); - - /// Generate if statement then region - /// if (mark_array[j_idx] != mark) { - /// mark_array[j_idx] = mark; /// C[i_idx, j_idx] has been visited - /// W_id_list_size += 1; - /// } - genSymbolicIfStatementThenRegion(builder, - loc, - if_statement, - mark_array, /// mark-array - semiringLoop_valueAccessIdx, /// value access index j_idx - W_id_list_size, /// W_id_list_size - mark_new_val, - maskingInfo); - - /// Updating output 
- /// C.rowptr[idx] = W_id_list_size; - Value i_idx = outermost_forLoop.getInductionVar(); - genSymbolicUpdateCRowptr(builder, - loc, - outermost_forLoop, - symbolicInfo.mtxC_rowptr, /// mtxC_rowptr - i_idx, /// value access index i_idx - W_id_list_size /* W_id_list_size */); - - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - /// Set the insertion point after the outermost_forloop - builder.setInsertionPointAfter(outermost_forLoop); - - /// Generate the reduce of output C.rowptr and new C.col and new C.val - /// C.rowptr[M] = 0; - /// int C_val_size = 0; - /// for (int i_idx = 0; i_idx < M + 1; ++i_idx) { - /// int curr = C.rowptr[i_idx]; - /// C.rowptr[i_idx] = C_val_size; - /// C_val_size += curr; - /// } - genSymbolicReduceOutputCRowptrCColCVal(builder, - loc, - outermost_forLoop, - symbolicInfo /* output */); +Value +LowerIndexTreeToSCFPass::convertOperand(IndexTreeLHSOperandOp op, IRRewriter &rewriter) +{ + Location loc = op->getLoc(); + Value tensor = op.getTensor(); + auto crds = op.getCrds(); + auto positions = op.getPos(); + + TensorType tensor_type; + if((tensor_type = llvm::dyn_cast(tensor.getType()))){ + return rewriter.create(loc, tensor_type.getElementType(), tensor, crds); + } else { + // LHS may not be constant (i.e. if we are inserting into a tensor that we need to resize), + // so cannot directly lower like we can the RHS + Value pos = positions[positions.size() - 1]; + return rewriter.create(loc, rewriter.getF64Type(), tensor, pos); + } +} - /// Logistics of memory about old mtxC, mtxC.col, and mtxC.val - /// 1. Dealloc the old C.val and C.col before the outermost_forLoop. - /// 2. Change mtxC's old value in C_col_size (A2crd_size) and C_val_size (Aval_size) to new mtxC_val_size. - /// 3. Generate a new sparse tensor to replace the old output sparse tensor after the numeric outermost for-loop. 
- scf::ForOp &numeric_outermost_forLoop = numeric_nested_forops.back(); - logisticsForMtxCColCVal(builder, - loc, - outermost_forLoop, /// symbolic_outermost_forLoop - symbolicInfo, - numeric_outermost_forLoop); +mlir::LogicalResult +LowerIndexTreeToSCFPass::convertCompute(Operation *op, + IRRewriter &rewriter) +{ + Location loc = op->getLoc(); + IndexTreeComputeOp compute_op = llvm::cast(op); + auto semiringParts = compute_op.getSemiring().split('_'); - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); + Value elementwise_result; + for(auto rhs = compute_op.getRhs().begin(); rhs != compute_op.getRhs().end(); rhs++) + { + Value rhs_value = convertOperand(llvm::cast((*rhs).getDefiningOp()), rewriter); + if(rhs == compute_op.getRhs().begin()){ + elementwise_result = rhs_value; + } else { + elementwise_result = getSemiringSecondVal(rewriter, loc, semiringParts.second, + elementwise_result, rhs_value); + } } - /// 1. Get the nested loops - /// ---1.1 the nested loops corresponding indices can be infered from ancestors_wp - /// 2. get lhs and rhs. if only 1 rhs, then it's a fill op; otherwise, binary op - /// Note: 1. The auxiliary arrays does not contain the perms/formats information - /// 2. We only apply the compressed workspace on the output of the tensor, then in this case, the workspace tensors will not be in the same side with the main tensors. - /// (main tensors: such as A, B, C, w; auxiliary tensors: such as w_index_list ...) 
- void genCmptOps(indexTree::IndexTreeComputeOp &cur_op, - indexTree::IndexTreeOp &rootOp, - /// PatternRewriter &rewriter, - OpBuilder &builder, - OpsTree *opstree, - std::vector &ancestorsWps, - std::vector &wp_ops, - SymbolicInfo &symbolicInfo, - NumericInfo &numericInfo) - { - comet_debug() << " calling genCmptOps\n"; - Location loc = rootOp.getLoc(); - comet_debug() << " \n"; + IndexTreeLHSOperandOp lhs = llvm::cast(compute_op.getLhs().getDefiningOp()); + Value reduce_result = convertOperand(lhs, rewriter); + reduce_result = getSemiringFirstVal(rewriter, loc, semiringParts.first, + reduce_result, elementwise_result); - comet_debug() << " Current IndexTreeComputeOp:"; - comet_vdump(cur_op); + Value old_tensor = lhs.getTensor(); + Value output_tensor; - const bool comp_worksp_opt(cur_op.getCompWorkspOpt()); - comet_debug() << " comp_worksp_opt (bool: true is compressed): " << comp_worksp_opt << "\n"; + // LLVM_DEBUG({ + // lhs->emitOpError() << "Creating Tensor Insert Op" << "\n"; + // }); + + if(llvm::isa(old_tensor.getType())) + { + output_tensor = rewriter.create(loc, old_tensor.getType(), reduce_result, old_tensor, lhs.getCrds()); + } else { + output_tensor = rewriter.create(loc, old_tensor.getType(), old_tensor, lhs.getPos(), lhs.getCrds(), reduce_result); + } + rewriter.replaceAllUsesWith(op->getResult(0), output_tensor); + ValueRange rhs = compute_op.getRhs(); + rewriter.eraseOp(compute_op); + for(Value operand : rhs){ + rewriter.eraseOp(operand.getDefiningOp()); + } + rewriter.eraseOp(lhs); - /// Two cases: - /// 1. for the initial workspace, only 1 auxiliary vector w - /// 2. 
for the compressed workspace, there are 4 auxiliaty vectors, w, w_already_set, w_index_list, w_index_list_size + return success(); +} - /// The insertion location should be "the end of the body of parent loop" - std::vector ancestorsOps; - getAncestorsOps(opstree, ancestorsOps); - comet_debug() << " ancestorsOps.size(): " << ancestorsOps.size() << "\n"; - for (unsigned int i = 0; i < ancestorsOps.size(); i++) +SetVector +LowerIndexTreeToSCFPass::collectChildren(IndexTreeIndicesOp root) +{ + SetVector result; + for(Operation* user : root->getUsers()) + { + if(llvm::isa(user)) { - comet_debug() << " ancestorsOps[i]->id:" << ancestorsOps[i]->id << "\n"; + Value domain = llvm::cast(user).getDomain(); + Operation* domain_op = domain.getDefiningOp(); + if(llvm::isa(domain_op)) + { + auto intersection_domain_op = llvm::cast(domain_op); + for(Value subdomain : intersection_domain_op.getDomains()){ + result.insert(subdomain.getDefiningOp()); + } + } else if(llvm::isa(domain_op)) + { + auto union_domain_op = llvm::cast(domain_op); + for(Value subdomain : union_domain_op.getDomains()){ + result.insert(subdomain.getDefiningOp()); + } + } + result.insert(domain_op); } - - /// 1. 
get the nested loops, from innermost to outermost order - std::vector nested_forops; - std::vector nested_AccessIdx; - std::vector nested_forops_indices; /// Each nested indexOp's index value (e.g., indices=[0]) - getNumericNestedForOpsAndAccessIdx(ancestorsWps, - ancestorsOps, - nested_forops /* output */, - nested_AccessIdx /* output */, - nested_forops_indices /* output */); - - comet_debug() << " nested_forops_indices.size(): " << nested_forops_indices.size() << "\n"; - assert( - nested_forops.size() == nested_forops_indices.size() && "nested_forops.size() != nested_forops_indices.size()"); - - /// Reset the insertion point: the body of the innermost loop - assert(nested_forops.size() > 0 && "No loops\n"); - comet_debug() << " "; - comet_pdump(nested_forops[0].getBody()); - comet_debug() << " "; - comet_pdump(nested_forops[0].getBody()->getTerminator()); - builder.setInsertionPoint(nested_forops[0].getBody()->getTerminator()); + if(llvm::isa(user)) { - comet_vdump(nested_forops[0]); + IndexTreeComputeOp compute_op = llvm::cast(user); + Value lhs = compute_op.getLhs(); + result.insert(lhs.getDefiningOp()); + for(Value rhs_operand : compute_op.getRhs()) + { + result.insert(rhs_operand.getDefiningOp()); + } } - - /// Analyze the leafop, Get the tensors, rhs, lhs, and operator_type - /// --- only one rhs, it will be a fill op; if two, check op_type (+, +=, *=) - /// Check the indices contained in each tensor - /// Generate loadOp, compute ops, StoreOp. 
- std::vector tensors_rhs; - std::vector> tensors_lhs_Allocs; - std::vector> tensors_rhs_Allocs; - std::vector> allFormats; - std::vector> allPerms; - std::vector> allPerms_rhs; - std::vector main_tensors_all; /// main_tensors_all has first RHS tensors then LHS tensors - std::vector main_tensors_rhs; - getNumericTensors(cur_op, - tensors_rhs /* output */, - tensors_lhs_Allocs /* output */, - tensors_rhs_Allocs /* output */, - allFormats /* output */, - allPerms /* output */, - allPerms_rhs /* output */, - main_tensors_all /* output */, - main_tensors_rhs /* output */); - - /// ----------------- /// - /// Get main_tensors_all_Allocs - /// ----------------- /// - int main_tensor_nums = main_tensors_all.size(); /// output - comet_debug() << " main_tensor_nums: " << main_tensor_nums << "\n"; - /// Check the loop arg in each tensor - std::vector> main_tensors_all_Allocs = getAllAllocs(main_tensors_all); /// output - comet_debug() << " main_tensors_all_Allocs.size(): " << main_tensors_all_Allocs.size() << "\n"; - - /// ----------------- /// - /// Get allValueAccessIdx - /// ----------------- /// - /// For every main_tensors_all[i], allAccessIdx[i] is the for-loop's induction variable. - /// However, allValueAccessIdx[i] is not necessarily the induction variable. 
- /// For CSR, for example, - /// for (j_loc = A.rowptr[idx]; j_loc < A.rowptr[idx + 1]; ++j_loc) { j_idx = A.col[j_loc]; } - /// j_idx is allValueAccessIdx[i], and j_loc is allAccessIdx[i] - std::vector> allAccessIdx(main_tensor_nums); - std::vector> allValueAccessIdx(main_tensor_nums); - getForLoopsValueAccessIdx(builder, - loc, - main_tensor_nums, - allPerms, - allFormats, - main_tensors_all, - nested_forops, - nested_AccessIdx, - nested_forops_indices, - main_tensors_all_Allocs, - allAccessIdx /* output */, - allValueAccessIdx /* output */); - - /// Symbolic Phase preparation - std::vector symbolic_nested_forops; - std::vector symbolic_nested_AccessIdx; - std::vector symbolic_nested_forops_indices; - std::vector> symbolic_allAccessIdx(main_tensor_nums); - std::vector> symbolic_allValueAccessIdx(main_tensor_nums); - if (symbolicInfo.has_symbolic_phase) + result.insert(user); + if(llvm::isa(user)) { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - getSymbolicNestedForOpsAndAccessIdx(ancestorsWps, - ancestorsOps, - symbolic_nested_forops /* output */, - symbolic_nested_AccessIdx /* output */, - symbolic_nested_forops_indices /* output */); - - /// Set the insertion point - builder.setInsertionPoint(symbolic_nested_forops[0].getBody()->getTerminator()); - - getForLoopsValueAccessIdx(builder, - loc, - main_tensor_nums, - allPerms, - allFormats, - main_tensors_all, - symbolic_nested_forops, - symbolic_nested_AccessIdx, - symbolic_nested_forops_indices, - main_tensors_all_Allocs, - symbolic_allAccessIdx /* output */, - symbolic_allValueAccessIdx /* output */); - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); + auto slice = collectChildren(llvm::cast(user)); + result.insert(slice.begin(), slice.end()); } + } - int rhs_loc = 0; - int lhs_loc = main_tensors_rhs.size(); /// lhs_loc is the location of the first LHS tensor in main_tensors_all + return result; +} - /// New version - 
Value lhs = cur_op.getLhs().getDefiningOp()->getOperand(0); - comet_vdump(lhs); -/// lhs is TensorLoadOp -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - Value lhs_alloc = (lhs.getDefiningOp())->getOperand(0); - comet_vdump(lhs_alloc); -#endif - if (main_tensors_rhs.size() == 1) - { /// Generate "a = b" - if (ConstantOp cstop = dyn_cast(main_tensors_rhs[0].getDefiningOp())) - { /// "a = 1.0" - comet_vdump(cstop); - if (comp_worksp_opt) /// true attr means compressed workspace - { - /// Symbolic Phase - if (symbolicInfo.has_symbolic_phase) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); +void +LowerIndexTreeToSCFPass::fillSubtree(Location loc, + IRRewriter &rewriter, + const SetVector& subtree, + const SmallVector& old_outputs, + ValueRange new_outputs, + IRMapping& map) +{ + for(Operation* child : subtree) + { + rewriter.clone(*child, map); + } - /// Set the insertion point - builder.setInsertionPoint(symbolic_nested_forops[0].getBody()->getTerminator()); + // Create yield + SmallVector yield_args; + for(Value result : old_outputs) + { + Value new_result = map.lookup(result); + yield_args.push_back(new_result); + } + rewriter.create(loc, yield_args); - /// Symbolic Phase uses the W_id_list_size in the Index Tree (main_tensors_all_Allocs[lhs_loc].back) - /// to record the current row size. - /// W_id_list_size = 0; - /// However, Numeric Phase should use C.rowptr[i_idx] to initialize W_id_list_size. 
- /// W_id_list_size = C.rowptr[i_idx]; - genWorkspaceCmptOpInitialAssignment(builder, - loc, - lhs_loc, - cstop, - symbolic_nested_forops, - tensors_lhs_Allocs, - main_tensors_all_Allocs, - false /* use_dynamic_init */, - symbolicInfo); + auto replacement = new_outputs.begin(); + for(auto old = old_outputs.begin(); old != old_outputs.end(); old++, replacement++) + { + rewriter.replaceAllUsesWith(*old, *replacement); + } - /// Prepare C, C.rowptr - if (symbolicInfo.mtxC_rowptr == nullptr) - { - Value &W_id_list_size = lhs; - { - comet_vdump(W_id_list_size); - } - getOutputMtxCRowptrAndDims(cur_op, - W_id_list_size, - symbolicInfo /* output */); - } + for(auto r_iter = subtree.rbegin(); r_iter != subtree.rend(); r_iter++) + { + Operation* remove = *r_iter; + // LLVM_DEBUG({ + // remove->emitOpError() << "Trying to erase op" << "\n"; + // for(auto user : remove->getUsers()) { + // user->emitOpError() << "Still used by op" << "\n"; + // } + // }); + rewriter.eraseOp(remove); + } +} - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } /// End symbolic phase - if (allFormats[lhs_loc].empty()) - { - /// The computeOp node is W_id_list_size = 0, - /// then do W_id_list_size = symbolicInfo.mtxC_rowptr[idx] - genWorkspaceCmptOpInitialAssignment(builder, - loc, - lhs_loc, - cstop, - nested_forops, - tensors_lhs_Allocs, - main_tensors_all_Allocs, - true /* use_dynamic_init */, - symbolicInfo); - } - else - { - /// The computeOp node is V[j] = 0, - /// then do V[j] = 0.0 - genWorkspaceCmptOpInitialAssignment(builder, - loc, - lhs_loc, - cstop, - nested_forops, - tensors_lhs_Allocs, - main_tensors_all_Allocs, - false /* use_dynamic_init */, - symbolicInfo); - } - } - else - { /// initial workspace - /// Generate Store 1.0, A[...] 
this op - /// this case: allPerms[0] is empty, allFormats[0] is empty +void +LowerIndexTreeToSCFPass::deleteDomain(Operation* op, IRRewriter &rewriter) { + + ValueRange subdomains; + if(llvm::isa(op)){ + subdomains = llvm::cast(op).getDomains(); + } else if(llvm::isa(op)){ + subdomains = llvm::cast(op).getDomains(); + } + rewriter.eraseOp(op); + for(Value subdomain : subdomains) { + deleteDomain(subdomain.getDefiningOp(), rewriter); + } +} - genCmptOpGeneralInitialAssignment(builder, - loc, - lhs_loc, - cstop, - nested_forops, - main_tensors_all_Allocs, - allValueAccessIdx); +mlir::LogicalResult +LowerIndexTreeToSCFPass::convertIndexNode(Operation *op, + IRRewriter &rewriter) { + // Generate loop + auto loc = op->getLoc(); + auto context = rewriter.getContext(); + IndexTreeIndicesOp index_node_op = llvm::cast(op); + Operation* domain_op = index_node_op.getDomain().getDefiningOp(); + auto index_type = rewriter.getIndexType(); + + Value crd; + Value induction_var; + OpBuilder::InsertPoint before; + OpBuilder::InsertPoint loop_end; + + SetVector subtree = collectChildren(index_node_op); + subtree = topologicalSort(subtree); + SmallVector loop_outputs; + SmallVector loop_init_args; + SmallDenseMap, std::pair> tensor_access_map; + IRMapping map; + + + IndexTreeComputeOp compute_op; + ComputeSymbolicDomainOp symbolic_op = nullptr; + ComputeSymbolicDomainRowOp end_row_op = nullptr; + for(Operation* child : subtree){ + + if((compute_op = llvm::dyn_cast(child))) + { + Value loop_output = compute_op->getResult(0); + Value lhs_tensor = compute_op.getLhs().getDefiningOp()->getOperand(0); + loop_outputs.push_back(loop_output); + + if(auto clean_op = lhs_tensor.getDefiningOp()){ + loop_init_args.push_back(clean_op.getWorkspace()); + } else { + loop_init_args.push_back(lhs_tensor); + } + compute_op = nullptr; + } else if((symbolic_op = llvm::dyn_cast(child))) + { + loop_init_args.push_back(symbolic_op.getDomain()); + } else if((end_row_op = llvm::dyn_cast(child))) + { + 
loop_outputs.push_back(end_row_op->getResult(0)); + } + } + // TODO: Determining loop outputs won't work if computing two domains at once!!! + // If we are computing a symbolic domain, but do not see the end of the row, + // yield symbolic domain + if(symbolic_op && !end_row_op) { + loop_outputs.push_back(symbolic_op->getResult(0)); + } + + + before = rewriter.saveInsertionPoint(); + + if(llvm::isa(domain_op)) + { + // Dense domain + Value lb = rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(0)); + Value ub = domain_op->getOperand(0); + Value step = rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(1)); + scf::ForOp for_loop = rewriter.create(loc, lb, ub, step, loop_init_args); + crd = for_loop.getInductionVar(); + induction_var = crd; + + unsigned init_arg_idx = 0; + for(Value init_arg : loop_init_args){ + map.map(init_arg, for_loop.getRegionIterArg(init_arg_idx)); + init_arg_idx += 1; + } + rewriter.setInsertionPointToStart(for_loop.getBody()); + fillSubtree(loc, rewriter, subtree, loop_outputs, for_loop.getResults(), map); + rewriter.setInsertionPoint(for_loop.getBody()->getTerminator()); + loop_end = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(for_loop); + + + } else if(llvm::isa(domain_op)) + { + // Sparse domain + IndexTreeSparseDomainOp sparse_domain = llvm::cast(domain_op); + TensorFormatEnum format = (TensorFormatEnum)sparse_domain.getFormat(); + switch(format) + { + case TensorFormatEnum::CU: + case TensorFormatEnum::CN: //TODO: Figure out difference + { + Value start_idx = sparse_domain.getParent(); + if(!start_idx){ + start_idx = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); } + Value inc = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + Value end_idx = rewriter.create(loc, index_type, start_idx, inc); + Value lb = rewriter.create(loc, index_type, sparse_domain.getPos(), start_idx); + Value ub = rewriter.create(loc, index_type, sparse_domain.getPos(), end_idx); + 
Value step = rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(1)); + scf::ForOp for_loop = rewriter.create(loc, lb, ub, step, loop_init_args); + Block* loop_body = for_loop.getBody(); + + rewriter.setInsertionPointToStart(loop_body); + Value crd_idx = for_loop.getInductionVar(); + induction_var = crd_idx; + crd = rewriter.create(loc, index_type, sparse_domain.getCrd(), crd_idx); + + + unsigned init_arg_idx = 0; + for(Value init_arg : loop_init_args){ + map.map(init_arg, for_loop.getRegionIterArg(init_arg_idx)); + init_arg_idx += 1; + } + + fillSubtree(loc, rewriter, subtree, loop_outputs, for_loop.getResults(), map); + rewriter.setInsertionPoint(for_loop.getBody()->getTerminator()); + loop_end = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(for_loop); + + tensor_access_map.insert(std::make_pair( + std::make_pair(sparse_domain.getTensor(), sparse_domain.getDim()), + std::make_pair(crd_idx, crd) + )); + break; + } + case TensorFormatEnum::S: + { + Value crd_idx = sparse_domain.getParent(); + induction_var = crd_idx; + crd = rewriter.create(loc, index_type, sparse_domain.getCrd(), crd_idx); + loop_end = rewriter.saveInsertionPoint(); + tensor_access_map.insert(std::make_pair( + std::make_pair(sparse_domain.getTensor(), sparse_domain.getDim()), + std::make_pair(crd_idx, crd) + )); + break; } - else if (main_tensors_rhs[0].getType().isa()) - { /// Cij = Wj - /// When Cij is dense type - if (lhs.getType().isa()) + } + } else if(llvm::isa(domain_op)) { + auto workspace_domain_op = llvm::cast(domain_op); + Value start_idx = workspace_domain_op.getParent(); + if(!start_idx){ + start_idx = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + } + Value inc = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + Value end_idx = rewriter.create(loc, index_type, start_idx, inc); + + /** TODO: Sort crd array? 
**/ + Value lb = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + Value ub = rewriter.create(loc, index_type, workspace_domain_op.getTensor()); + Value step = rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(1)); + scf::ForOp for_loop = rewriter.create(loc, lb, ub, step, loop_init_args); + Block* loop_body = for_loop.getBody(); + + rewriter.setInsertionPointToStart(loop_body); + Value crd_idx = for_loop.getInductionVar(); + induction_var = crd_idx; + crd = rewriter.create(loc, index_type, workspace_domain_op.getTensor(), crd_idx, nullptr); + + unsigned init_arg_idx = 0; + for(Value init_arg : loop_init_args){ + map.map(init_arg, for_loop.getRegionIterArg(init_arg_idx)); + init_arg_idx += 1; + } + + fillSubtree(loc, rewriter, subtree, loop_outputs, for_loop.getResults(), map); + rewriter.setInsertionPoint(for_loop.getBody()->getTerminator()); + loop_end = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(for_loop); + + tensor_access_map.insert(std::make_pair( + std::make_pair(workspace_domain_op.getTensor(), 0), + std::make_pair(crd_idx, crd) + )); + } else if(llvm::isa(domain_op)) { + // Intersection between sparse domains + auto domains = llvm::cast(domain_op).getDomains(); + Value inc = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + SmallVector loop_conditions; + SmallVector array_crds; + + Block* cond_block = new Block(); + Block* body_block = new Block(); + + // Create loop carried arguments for output tensors and iteration counter + for(Value init_arg : loop_init_args){ + cond_block->addArgument(init_arg.getType(), loc); + BlockArgument body_arg = body_block->addArgument(init_arg.getType(), loc); + map.map(init_arg, body_arg); + } + + Value loop_ctr_init = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + loop_init_args.push_back(loop_ctr_init); + cond_block->addArgument(index_type, loc); + body_block->addArgument(index_type, loc); + unsigned loop_carry_args = loop_init_args.size(); + 
+ // Create control iterators for each of the tensors + for(Value domain : domains) + { + IndexTreeSparseDomainOp sparse_domain = llvm::cast(domain.getDefiningOp()); + TensorFormatEnum format = (TensorFormatEnum)sparse_domain.getFormat(); + switch(format) + { + case TensorFormatEnum::CN: + case TensorFormatEnum::S: { - /// %1 = load b[...] - /// store %1, a[...] - genCmptOpGatherFromDenseToDense(builder, - loc, - rhs_loc, - lhs_loc, - main_tensors_all_Allocs, - allValueAccessIdx); + // Not yet supported!!! + break; } - /// Cij = Wj - else if (lhs.getType().isa()) + case TensorFormatEnum::CU: { - - unsigned int lhs_2crd_size_loc; - unsigned int lhs_2pos_size_loc; - Value lhs_nnz; - Value lhs_nnz_alloc; - Value lhs_val; - getLHSBeforeGatherFromWorkspace(builder, - loc, - lhs_loc, - lhs, - main_tensors_all_Allocs, - lhs_2crd_size_loc /* output */, - lhs_2pos_size_loc /* output */, - lhs_nnz /* output */, - lhs_nnz_alloc /* output */, - lhs_val /* output */); - - if (comp_worksp_opt) /// true attr means compressed workspace - { - /// Gather results from Workspace to the sparse output - genWorkspaceCmptOpGatherFromWorkspaceToOutput(builder, - loc, - tensors_rhs_Allocs, - nested_forops, - nested_AccessIdx, - symbolicInfo, - numericInfo); - /// } - } - else - { - /// %1 = load b[...] 
- /// if(%1 != 0) { - /// Cnnz = load Cop.operand(4d+1) - /// store %1, cval[Cnnz] - /// store Cnnz+1, Cop.operand(4d+1) - /// } - genCmptOpGatherFromDenseToOutput(builder, - loc, - rhs_loc, - lhs_loc, - lhs_2crd_size_loc, - lhs_2pos_size_loc, - lhs, - lhs_nnz, - lhs_nnz_alloc, - lhs_val, - allFormats, - main_tensors_all_Allocs, - allAccessIdx, - allValueAccessIdx, - nested_forops); + rewriter.restoreInsertionPoint(before); + Value start_idx = sparse_domain.getParent(); + if(!start_idx){ + start_idx = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); } + Value end_idx = rewriter.create(loc, index_type, start_idx, inc); + Value start = rewriter.create(loc, index_type, sparse_domain.getPos(), start_idx); + Value end = rewriter.create(loc, index_type, sparse_domain.getPos(), end_idx); + loop_init_args.push_back(start); + before = rewriter.saveInsertionPoint(); + + Value crd_idx = cond_block->addArgument(start.getType(), loc); + rewriter.setInsertionPointToStart(cond_block); + Value cnd = rewriter.create( + loc, rewriter.getI1Type(), + arith::CmpIPredicateAttr::get(context, arith::CmpIPredicate::ult), + crd_idx, end + ); + loop_conditions.push_back(cnd); + + crd_idx = body_block->addArgument(start.getType(), loc); + rewriter.setInsertionPointToStart(body_block); + Value array_crd = rewriter.create(loc, index_type, sparse_domain.getCrd(), crd_idx); + array_crds.push_back(array_crd); + + tensor_access_map.insert(std::make_pair( + std::make_pair(sparse_domain.getTensor(), sparse_domain.getDim()), + std::make_pair(crd_idx, array_crd) + )); } } - /// Vj = Bij - else if (main_tensors_rhs[0].getType().isa()) - { - /// %Bvalue = load %Bval[..] 
- /// store %Bvalue, %v[%j] - - /// Symbolic Phase - if (symbolicInfo.has_symbolic_phase) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion point - builder.setInsertionPoint(symbolic_nested_forops[0].getBody()->getTerminator()); - - genWorkspaceCmptOpScatterInputToWorkspace(builder, - loc, - main_tensor_nums, - main_tensors_all_Allocs, - symbolic_allValueAccessIdx); - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } /// End symbolic phase - - genWorkspaceCmptOpScatterInputToWorkspace(builder, - loc, - main_tensor_nums, - main_tensors_all_Allocs, - allValueAccessIdx); - } } - else if (main_tensors_rhs.size() == 2) - { /// Generate " a = b * c" binary op - - comet_debug() << "No masking codegen...\n"; + rewriter.restoreInsertionPoint(before); - auto semiringParts = cur_op.getSemiring().split('_'); - /// check validity of semiring provided by user. - if (!Semiring_reduceOps.contains(semiringParts.first) || !Semiring_ops.contains(semiringParts.second)) - { - llvm::errs() << "Not supported semiring operator: " - << semiringParts.first << " or " << semiringParts.second << " \n"; - llvm::errs() << "Please report this error to the developers!\n"; - /// we should not proceed forward from this point to avoid faults. 
- } + // Create while loop + scf::WhileOp while_loop = rewriter.create(loc, cond_block->getArgumentTypes(), loop_init_args); + while_loop.getBefore().push_front(cond_block); - MaskingInfo maskingInfo; - maskingInfo.mask_type = NO_MASKING; - if (symbolicInfo.has_symbolic_phase) + rewriter.setInsertionPointToEnd(cond_block); + Value loop_condition = nullptr; + for(Value cnd : loop_conditions) + { + if(loop_condition == nullptr) { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); - - /// Set the insertion point - builder.setInsertionPoint(symbolic_nested_forops[0].getBody()->getTerminator()); - - genSymbolicSemiringLoopBody(builder, - loc, - lhs_loc, - tensors_lhs_Allocs, - symbolic_nested_forops, - symbolic_nested_AccessIdx, - symbolic_allValueAccessIdx, - symbolicInfo, - nested_forops /* numeric_nested_forops= */, - maskingInfo); - - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); + loop_condition = cnd; + } else { + loop_condition = rewriter.create(loc, rewriter.getI1Type(), loop_condition, cnd); } - - formSemiringLoopBody(cur_op, - comp_worksp_opt, - semiringParts.first, semiringParts.second, - builder, loc, lhs_loc, - main_tensors_all_Allocs, - tensors_lhs_Allocs, - tensors_rhs_Allocs, - allValueAccessIdx, - allAccessIdx, - nested_forops, - nested_AccessIdx, - symbolic_nested_forops, - allPerms_rhs, - symbolicInfo, - numericInfo, - maskingInfo); } - else if (main_tensors_rhs.size() == 3) - { /// Generate " a = b * c" binary op with masking + rewriter.create(loc, loop_condition, cond_block->getArguments()); + while_loop.getAfter().push_front(body_block); + // Create intersection + rewriter.setInsertionPointToEnd(body_block); + crd = nullptr; + for(Value array_crd : array_crds){ + if(crd == nullptr) { - /// comet_pdump(rootOp.getOperation()->getParentOfType()); - comet_pdump(rootOp->getParentOfType()); - } - auto semiringParts = cur_op.getSemiring().split('_'); - /// check validity of 
semiring provided by user. - if (!Semiring_reduceOps.contains(semiringParts.first) || !Semiring_ops.contains(semiringParts.second)) - { - llvm::errs() << "Not supported semiring operator: " - << semiringParts.first << " or " << semiringParts.second << " \n"; - llvm::errs() << "Please report this error to the developers!\n"; - /// we should not proceed forward from this point to avoid faults. + crd = array_crd; + } else { + crd = rewriter.create(loc, index_type, crd, array_crd); } + } - auto maskingAttr = cur_op.getMaskType(); - std::string maskingAttrStr(maskingAttr.data()); - comet_debug() << "mask attr: " << maskingAttrStr << "\n"; - - MASKING_TYPE mask_type; - if (maskingAttrStr == "push") - mask_type = MASKING_TYPE::PUSH_BASED_MASKING; - else if (maskingAttrStr == "pull") - mask_type = MASKING_TYPE::PULL_BASED_MASKING; - else if (maskingAttrStr == "auto") - mask_type = MASKING_TYPE::PUSH_BASED_MASKING; - else /// none - mask_type = MASKING_TYPE::NO_MASKING; - - switch (mask_type) + Value intersection_cnd = nullptr; + SmallVector intersections; + for(Value array_crd : array_crds) + { + Value is_intersect = rewriter.create( + loc, rewriter.getI1Type(), + arith::CmpIPredicateAttr::get(context, arith::CmpIPredicate::eq), + crd, array_crd + ); + if(intersection_cnd == nullptr) { - case NO_MASKING: - { /// Use no masking; we should not hit this case because it is handled - /// by the previous if-else branch when main_tensors_rhs.size() == 2 - break; + intersection_cnd = is_intersect; + } else { + intersection_cnd = rewriter.create(loc, rewriter.getI1Type(), intersection_cnd, is_intersect); } - case PUSH_BASED_MASKING: - { /// Use push-based masking - /// mask_tensor should be the 3rd operand of ComputeRHS (tensors_rhs[2]). 
- mlir::Value mask_tensor = tensors_rhs[2]; - { - comet_debug() << "mask_tensor\n"; - comet_vdump(mask_tensor); - } - - MaskingInfo maskingInfo; - maskingInfo.mask_type = PUSH_BASED_MASKING; - maskingInfo.mask_tensor = mask_tensor; - - /// Get mask_rowptr, mask_col, and mask_val arrays - getMaskSparseTensorInfo(maskingInfo /* contents updated after call*/); + intersections.push_back(is_intersect); + } - if (symbolicInfo.has_symbolic_phase) - { - /// Store the insertion point - auto last_insertion_point = builder.saveInsertionPoint(); + SmallVector if_types; + for(unsigned i = 0; i < loop_carry_args; i++) + { + if_types.push_back(loop_init_args[i].getType()); + } - /// Set the insertion point - builder.setInsertionPoint(symbolic_nested_forops[0].getBody()->getTerminator()); + scf::IfOp if_op = rewriter.create(loc, if_types, intersection_cnd, true); + rewriter.setInsertionPointToStart(if_op.elseBlock()); + rewriter.create( + loc, + std::vector( + body_block->args_begin(), + body_block->args_begin() + if_op->getNumResults()) + ); + rewriter.setInsertionPointToStart(if_op.thenBlock()); + fillSubtree(loc, rewriter, subtree, loop_outputs, while_loop.getResults(), map); + Operation* yield_op = if_op.thenBlock()->getTerminator(); + rewriter.setInsertionPoint(yield_op); + loop_end = rewriter.saveInsertionPoint(); - genSymbolicSemiringLoopBody(builder, - loc, - lhs_loc, - tensors_lhs_Allocs, - symbolic_nested_forops, - symbolic_nested_AccessIdx, - symbolic_allValueAccessIdx, - symbolicInfo, - nested_forops /* numeric_nested_forops= */, - maskingInfo); + // Increment the induction variable + induction_var = body_block->getArgument(loop_outputs.size()); + Value step = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + Value loop_ctr = rewriter.create(loc, index_type, induction_var, step); + yield_op->insertOperands(yield_op->getNumOperands(), loop_ctr); - /// Restore the insertion point - builder.restoreInsertionPoint(last_insertion_point); - } - 
formSemiringLoopBody(cur_op, - comp_worksp_opt, - semiringParts.first, semiringParts.second, - builder, loc, lhs_loc, - main_tensors_all_Allocs, - tensors_lhs_Allocs, - tensors_rhs_Allocs, - allValueAccessIdx, - allAccessIdx, - nested_forops, - nested_AccessIdx, - symbolic_nested_forops, - allPerms_rhs, - symbolicInfo, - numericInfo, - maskingInfo); - break; - } - case PULL_BASED_MASKING: /// Use pull-based masking - llvm::errs() << "Error: mask type PULL_BASED_MASKING is not supported, yet.\n"; - } + // Increment each argument + rewriter.setInsertionPointAfter(if_op); + SmallVector yield_args; + for(auto result : if_op.getResults()) { + yield_args.push_back(result); } - else + auto cntrl_arg = body_block->args_begin() + loop_outputs.size() + 1; + for(Value cnd : intersections) { - llvm::errs() << "No support for operation with greater than two operands in workspace transforms!" - << "\n"; + Value inc = rewriter.create(loc, index_type, cnd); + yield_args.push_back(rewriter.create(loc, index_type, *cntrl_arg, inc)); + cntrl_arg += 1; } - } - - /// ----------------- /// - /// Get the itree roots - /// ----------------- /// - void getIndexTreeOps(func::FuncOp &function, - std::vector &iTreeRoots /* output */) - { - function.walk([&](indexTree::IndexTreeOp op) - { iTreeRoots.push_back(op); }); - } - /// ----------------- /// - /// Delete every objects in opstree_vec, preventing memory leak. - /// ----------------- /// - void cleanOpstreeVec(std::vector &opstree_vec) - { - for (auto &t : opstree_vec) - { - delete t; - } + // Create YieldOp + rewriter.create(loc, yield_args); + rewriter.setInsertionPointAfter(while_loop); } - - /// ----------------- /// - /// Check if the Index Tree inputs are all sparse - /// All inputs are sparse if and only if all computeOp nodes are using workspace transformation. 
- /// ----------------- /// - void checkIfAllSparse(std::vector &wp_ops, - SymbolicInfo &symbolicInfo /* output */) + OpBuilder::InsertPoint after = rewriter.saveInsertionPoint(); + + for(Operation* user : op->getUsers()) { - for (Value &op : wp_ops) + if(llvm::isa(user)) { - if (indexTree::IndexTreeComputeOp cur_op = dyn_cast(op.getDefiningOp())) + IndexTreeIndexToTensorOp access_op = llvm::cast(user); + Value tensor = access_op.getTensor(); + + if(llvm::isa(tensor.getType())) { - bool comp_worksp_opt(cur_op.getCompWorkspOpt()); - if (!comp_worksp_opt) + // Tensor is Sparse + auto tensor_type = llvm::cast(tensor.getType()); + auto dim = access_op.getDim(); + TensorFormatEnum format = (TensorFormatEnum) tensor_type.getFormat()[2 * dim]; + Value access_pos; + Value access_crd; + switch(format) { - symbolicInfo.are_inputs_sparse = false; - return; + case TensorFormatEnum::CU: + case TensorFormatEnum::CN: //TODO: Figure out difference + case TensorFormatEnum::S: + { + auto key = std::make_pair(tensor, dim); + if(tensor_access_map.find(key) != tensor_access_map.end()){ + auto access_pair = tensor_access_map[key]; + access_pos = access_pair.first; + access_crd = access_pair.second; + } else { + // TODO: Determine weather of not the sequence of accesses is linear or not + // Right now we blindly say that it will be linear. But this will cause problems + // A better temporary solution would be to determine that it is not linear, + // then raise a not implemented error when trying to lower the resulting op + rewriter.restoreInsertionPoint(loop_end); + access_crd = crd; + access_pos = rewriter.create(loc, index_type, tensor, crd, rewriter.getI32IntegerAttr(dim), rewriter.getBoolAttr(true)); + } + break; + } + case TensorFormatEnum::D: + { + // TODO: This is incorrect, deal with reordering!!!! 
+ if(!access_op.getPrevDim()){ + access_pos = crd; + } else { + rewriter.restoreInsertionPoint(before); + Value dim_idx = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + Value dim_size = rewriter.create(loc, index_type, tensor, rewriter.getI32IntegerAttr(dim)); + Value pos_start = rewriter.create(loc, index_type, dim_size, access_op.getPrevDim()); + rewriter.restoreInsertionPoint(loop_end); + access_pos = rewriter.create(loc, index_type, pos_start,crd); + } + access_crd = crd; + } } + rewriter.replaceAllUsesWith(access_op.getPos(), access_pos); + rewriter.replaceAllUsesWith(access_op.getCrd(), access_crd); + } else { + rewriter.replaceAllUsesWith(access_op.getPos(), crd); + rewriter.replaceAllUsesWith(access_op.getCrd(), crd); } + rewriter.eraseOp(access_op); } - - symbolicInfo.are_inputs_sparse = true; - } - - //===----------------------------------------------------------------------===// - /// LowerIndexTreeIRToSCF PASS - //===----------------------------------------------------------------------===// - - /// Lower the ta.tc (tensor contraction operation in TA dialect) into scf dialect. - struct LowerIndexTreeToSCFPass - : public PassWrapper> - { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerIndexTreeToSCFPass) - void runOnOperation() override; - - void doLoweringIndexTreeToSCF(indexTree::IndexTreeOp &rootOp, - OpBuilder &builder); - }; - -} /// end anonymous namespace. - -/** - * @brief : - * Goal: IndexTreeOp(i.e. a tree structure), convert into OpsTree(also tree structure) - * Steps: 1.Iterate over IndexTreeOptree - * 2.pass info to opsGen(), including tensors, current workspacetreeop, parent OpsTree node - * -- the parent of "current workspacetreeop" can get from getUser(). Only one user(tree structure) - * -- DFS traverse the workspacetreeop. How? 
- * */ -void LowerIndexTreeToSCFPass::doLoweringIndexTreeToSCF(indexTree::IndexTreeOp &rootOp, - OpBuilder &builder) -{ - assert(isa(rootOp)); - comet_debug() << "\ndoLoweringIndexTreeToSCF in LowerIndexTreeIRToSCF\n"; - /// auto module = rootOp->getParentOfType(); - { - /// comet_pdump(rootOp.getOperation()->getParentOfType()); - comet_pdump(rootOp->getParentOfType()); - } - - /// comet_pdump(rootOp.getOperation()->getParentOp()); - /// Here, should check the operands, at least one operand should be sparse; - /// Otherwise, if all dense operands, just return. - /// rootOp only contains one workspace child, no indices - - std::vector wp_ops; - dfsRootOpTree(rootOp.getChildren(), wp_ops); -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - comet_debug() << " wp_ops.size(): " << wp_ops.size() << "\n"; - for (auto n : wp_ops) - { - comet_debug() << " "; - comet_vdump(n); - /// Declare opsTree } -#endif - /// In ops vector, for each op, the parent of each op can get from getUsers() - /// Since it's a tree structure, only one user ==> which is the parent - /// We can initialize the OpsTree structure with this relationship. - /// Search the location of the parent of current op, if rootOp, return ops.size; - /// Otherwise, return the location index. 
- std::vector parent_idx; - for (unsigned int i = 0; i < wp_ops.size(); i++) + rewriter.restoreInsertionPoint(loop_end); + auto users = topologicalSort(llvm::SetVector(op->user_begin(), op->user_end())); + llvm::SetVector toRemove; + for(Operation* user : users) { - mlir::Value wp_op = wp_ops[i]; - mlir::Value wp_parent; - - for (auto n : wp_op.getDefiningOp()->getUsers()) + LLVM_DEBUG({ + user->emitOpError() << "Converting node from " << op << "\n"; + }); + if (llvm::isa(user)) { - comet_debug() << " " << i << " "; - comet_pdump(n); - wp_parent = n->getResult(0); - - comet_debug() << " parent: " << findIndexInVector_Value(wp_ops, wp_parent) << "\n"; - bool isInTree = false; - if (findIndexInVector_Value(wp_ops, wp_parent) < wp_ops.size()) - { - isInTree = true; - } - - if (isInTree || isRealRoot(wp_op.getDefiningOp())) - parent_idx.push_back(findIndexInVector_Value(wp_ops, wp_parent)); + auto clean_workspace_op = llvm::cast(user); + Value workspace = rewriter.create(loc, user->getResultTypes(), clean_workspace_op.getWorkspace()); + rewriter.replaceOp(user, {workspace}); + toRemove.insert(user); } } - -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - comet_debug() << " parent_idx: " << parent_idx.size() << "\n"; - for (auto n : parent_idx) - { - comet_debug() << " " << n << " \n"; - /// Declare opsTree - } -#endif - - std::vector opstree_vec; - for (unsigned int i = 0; i < wp_ops.size(); i++) - { - std::vector forOps; - std::vector accessIdx; - - OpsTree *parent = nullptr; - if (i >= 1) - { /// Not rootop - parent = opstree_vec[parent_idx[i]]; - } - comet_debug() << " \n"; - OpsTree *ops = new OpsTree(forOps, accessIdx, parent, i); - if (parent != nullptr) - { /// add child to the parent - parent->addChild(ops); - } - - opstree_vec.push_back(ops); - } - -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - { - int opstree_i = 0; - for (auto n : opstree_vec) - { - comet_debug() << " " << n->id << "\n"; - comet_debug() << "opstree_vec[" << opstree_i << "] " - << "forOps.size():" 
<< n->forOps.size() << " " - << "accessIdx.size():" << n->accessIdx.size() << "\n"; - /// << "cmptOps.size():" << n->cmptOps.size() << "\n"; - if (n->parent != nullptr) - { - comet_debug() << "parent: " << n->parent->id << "\n"; - } - else - { - comet_debug() << "parent: null \n"; - } - ++opstree_i; - } + for(Operation * user : toRemove) { + users.remove(user); } -#endif - - SymbolicInfo symbolicInfo; - NumericInfo numericInfo; - checkIfAllSparse(wp_ops, - symbolicInfo /* output */); - if (symbolicInfo.are_inputs_sparse) + + for(Operation* user : users) { - symbolicInfo.has_symbolic_phase = true; - } + LLVM_DEBUG({ + user->emitOpError() << "Converting node from " << op << "\n"; + }); - for (unsigned int i = 0; i < wp_ops.size(); i++) - { - comet_debug() << " i: " << i << "\n"; - comet_vdump(wp_ops[i]); - if (indexTree::IndexTreeIndicesOp cur_op = dyn_cast(wp_ops[i].getDefiningOp())) + if(llvm::isa(user)) { - /// Get indices - ArrayAttr op_indices = cur_op.getIndices(); - comet_debug() << "curOp is IndexTreeIndicesOp\n"; - comet_vdump(cur_op); - - /// cur_op's index attribute, e.g., "indices = [0]" - std::vector indices; - for (unsigned int j = 0; j < op_indices.size(); j++) - { - /// Get the indices; - int idx = op_indices[j].cast().getInt(); - indices.push_back(idx); - } - comet_debug() << " indices.size(): " << indices.size() << "\n"; - - /// Leaves are the computeOp nodes and the children of cur_op (an index node) - std::vector leafs; - - /// Find leaves of cur_op in the Index Tree (wp_ops). - /// A leaf is a computeOp node and cur_op is one its ancestors. - findLeafs(cur_op, indices, wp_ops, leafs /* output leaves*/); -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - comet_debug() << " leafs.size(): " << leafs.size() << "\n"; - for (auto n : leafs) - { - comet_debug() << " "; - comet_vdump(n); - } -#endif - - /// tensors: the tensors that uses the cur_op (index node) as their iterative index. - /// ids: An id is the location (0, 1, 2, ...) 
of the cur_op (index node) in the tensor's Perms. - /// formats: The format (e.g., "D", "CU", "CN", etc.) for the id-th dimension of the tensor. - /// tensors[i] uses the cur_op (index node) as its iterative index (e.g., [0], [1], etc.), and - /// ids[i] is the location of the iterative index in tensors[i]'s Perms. - /// For example, - /// allPerms = [[0, 1]]; the tensor[i]'s Perms is [0, 1]. If cur_op (iterative index) is indices = [1], then - /// ids[i] = 1, because [1] is at location 1 in [0, 1], i.e., the 1-st dimension of the tensors[i]. - /// formats[i] is "CU" if tensors[i]'s Formats = ["D", "CU"]. - std::vector tensors; - std::vector ids; - std::vector formats; - - comet_vdump(cur_op); - - getFormatsInfo(cur_op, - indices, - leafs, - tensors /* output */, - ids /* output */, - formats /* output */); - - comet_debug() << " indices.size(): " << indices.size() << " tensors.size(): " << tensors.size() << "\n"; - for (unsigned int m = 0; m < tensors.size(); m++) - { - comet_debug() << " Formats:" << formats[m] << " " << ids[m] << " "; - comet_vdump(tensors[m]); - } + // Recurse down tree + if(mlir::failed(convertIndexNode(user, rewriter))) + return failure(); + // loop_end = rewriter.saveInsertionPoint(); - comet_debug() << " call genForOps, i = " << i << "\n"; - genForOps(tensors, ids, formats, rootOp, builder, opstree_vec[i], symbolicInfo); - { - comet_pdump(rootOp->getParentOfType()); - } - comet_debug() << " finished call genForOps, i = " << i << "\n"; - } - else if (indexTree::IndexTreeComputeOp cur_op = dyn_cast(wp_ops[i].getDefiningOp())) + } else if (llvm::isa(user)) { - /// Generate computation ops. 
- std::vector ancestors_wp; /// workspace tree ancestor - getAncestorsWp(cur_op, ancestors_wp, wp_ops); -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - comet_debug() << " Current Op (IndexTreeComputeOp):"; - comet_vdump(cur_op); - for (auto n : ancestors_wp) - { - comet_debug() << " "; - comet_vdump(n); - } -#endif - - comet_debug() << " call genCmptOps, i = " << i << "\n"; - /// ancestors_wp can give all the indices of the nested loops - genCmptOps(cur_op, rootOp, builder, opstree_vec[i], ancestors_wp, - wp_ops, symbolicInfo, numericInfo); - { - comet_pdump(rootOp->getParentOfType()); - } - comet_debug() << " finished call genCmptOps, i = " << i << "\n"; - } - } - - { - comet_debug() << "End of doLoweringIndexTreeToSCF()\n"; - comet_pdump(rootOp->getParentOfType()); - } - - comet_debug() << "Cleaning up IndexTree Operations\n"; - comet_vdump(rootOp); - std::vector operations_dumpster; - rootOp.erase(); - for (auto itOp : wp_ops) - { - if (indexTree::IndexTreeComputeOp cur_op = dyn_cast(itOp.getDefiningOp())) + // Generate available compute expressions + if(mlir::failed(convertCompute(user, rewriter))) + return failure(); + } else if(llvm::isa(user)) { - comet_pdump(itOp.getDefiningOp()->getOperand(0).getDefiningOp()); /// RHS - comet_pdump(itOp.getDefiningOp()->getOperand(1).getDefiningOp()); /// LHS - operations_dumpster.push_back(cur_op.getOperand(0).getDefiningOp()); - operations_dumpster.push_back(cur_op.getOperand(1).getDefiningOp()); + ComputeSymbolicDomainOp symbolic_domain_op = llvm::cast(user); + Value symbolic_domain = symbolic_domain_op.getDomain(); + Value new_domain = rewriter.create(loc, + symbolic_domain.getType(), + symbolic_domain, + crd, + symbolic_domain_op.getIsUniqueAttr()); + rewriter.replaceOp(user, {new_domain}); } - comet_pdump(itOp.getDefiningOp()); - itOp.getDefiningOp()->erase(); - } - for (auto op : operations_dumpster) - { - op->erase(); - } - -#ifdef DEBUG_MODE_LowerIndexTreeToSCFPass - { - int opstree_i = 0; - for (auto n : 
opstree_vec) + else if(llvm::isa(user)) { - comet_debug() << " " << n->id << "\n"; - comet_debug() << "opstree_vec[" << opstree_i << "] " - << "forOps.size():" << n->forOps.size() << " " - << "accessIdx.size():" << n->accessIdx.size() << "\n"; - /// << "cmptOps.size():" << n->cmptOps.size() << "\n"; - if (n->parent != nullptr) - { - comet_debug() << "parent: " << n->parent->id << "\n"; - } - else - { - comet_debug() << "parent: null \n"; - } - ++opstree_i; + ComputeSymbolicDomainRowOp symbolic_domain_op = llvm::cast(user); + Value symbolic_domain = symbolic_domain_op.getDomain(); + Value new_domain = rewriter.create(loc, + symbolic_domain.getType(), + symbolic_domain, + symbolic_domain_op.getNeedsMarkAttr()); + rewriter.replaceOp(user, {new_domain}); } } -#endif - /// ----------------- /// - /// Free the memory occupied by each element in opstree_vec. - /// ----------------- /// - cleanOpstreeVec(opstree_vec); - -} /// End doLoweringIndexTreeToSCF() + rewriter.eraseOp(op); + deleteDomain(domain_op, rewriter); + rewriter.restoreInsertionPoint(after); + return success(); +} void LowerIndexTreeToSCFPass::runOnOperation() { - comet_debug() << "LowerIndexTreeToSCFPass\n"; - func::FuncOp function = getOperation(); - auto module = function.getOperation()->getParentOfType(); - auto *ctx = &getContext(); - /// Declare comet_sort_index() - declareSortFunc(module, - ctx, - function.getLoc()); - - std::vector iTreeRoots; - getIndexTreeOps(function, iTreeRoots /* output */); - for (auto root : iTreeRoots) - { - comet_vdump(root); - OpBuilder builder(root); - doLoweringIndexTreeToSCF(root, builder); + // Convert all the index trees to loops, + // Happens outside conversion pattern rewriter for convenience + // TODO: Should this also be part of the conversion pattern rewriter? 
+ std::vector iTreeRoots; + func::FuncOp funcOp = getOperation(); + auto *context = &getContext(); + funcOp.walk([&](IndexTreeRootOp op){ iTreeRoots.push_back(op); }); + + for(auto op : iTreeRoots) + { + OpBuilder builder(op); + IRRewriter rewriter(builder); + for(Operation* user : op->getUsers()) + { + if(llvm::isa(user)) + convertIndexNode(user, rewriter); + else if (llvm::isa(user)) + convertCompute(user, rewriter); + } + rewriter.eraseOp(op); } } diff --git a/lib/Conversion/IndexTreeToSCF/SymbolicDomainConversion.cpp b/lib/Conversion/IndexTreeToSCF/SymbolicDomainConversion.cpp new file mode 100644 index 00000000..46c9b9ae --- /dev/null +++ b/lib/Conversion/IndexTreeToSCF/SymbolicDomainConversion.cpp @@ -0,0 +1,411 @@ +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" +#include "comet/Dialect/Utils/Utils.h" +#include "comet/Dialect/TensorAlgebra/IR/TADialect.h" +#include "comet/Conversion/IndexTreeToSCF/IndexTreeToSCF.h" +#include "comet/Dialect/IndexTree/Patterns.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Index/IR/IndexAttrs.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using llvm::SmallVector; + +namespace mlir { + namespace comet{ + #define GEN_PASS_DEF_CONVERTSYMBOLICDOMAINS + #include "comet/Conversion/Passes.h.inc" + } +} + +struct SymbolicDomain { + Value pos_size; + Value pos_alloc_size; + Value crd_size; + Value dim_size; + Value pos; + Value mark_array; +}; + +static bool unpack_symbolic_domain(Value 
symbolic_domain, SymbolicDomain& result) +{ + if (auto cast = symbolic_domain.getDefiningOp()) { + result.pos_size = cast->getOperand(0); + result.pos_alloc_size = cast->getOperand(1); + result.crd_size = cast->getOperand(2); + result.dim_size = cast->getOperand(3); + result.pos = cast->getOperand(4); + result.mark_array = cast->getOperand(5); + return true; + } + return false; +} + +namespace { +struct ConvertDomainInsertOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(indexTree::SymbolicDomainInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + auto context = op.getContext(); + Type index_type = rewriter.getIndexType(); + SymbolicDomain domain; + if(!unpack_symbolic_domain(llvm::cast(adaptor).getDomain(), domain)){ + return failure(); + } + + Value one = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + if(op.getIsUnique()) + { + // If we know the crd is unique, we can just increment the crd_size value + domain.crd_size = rewriter.create(loc, index_type, domain.crd_size, one); + } else + { + + Value mark = rewriter.create(loc, index_type, domain.pos_size, one); + Value mark_val = rewriter.create(loc, index_type, domain.mark_array, op.getCrd()); + Value is_marked = rewriter.create(loc, + rewriter.getI1Type(), + index::IndexCmpPredicateAttr::get(context, index::IndexCmpPredicate::EQ), + mark, + mark_val); + scf::IfOp if_op = rewriter.create(loc, index_type, is_marked, true); + // We have seen this crd before + rewriter.setInsertionPointToStart(if_op.thenBlock()); + rewriter.create(loc, domain.crd_size); + + // We haven't seen this crd before + rewriter.setInsertionPointToStart(if_op.elseBlock()); + rewriter.create(loc, TypeRange(), mark, domain.mark_array, op.getCrd()); + Value new_crd_size = rewriter.create(loc, index_type, domain.crd_size, one); + rewriter.create(loc, new_crd_size); + 
rewriter.setInsertionPointAfter(if_op); + domain.crd_size = if_op.getResult(0); + } + + Value materialized = getTypeConverter()->materializeArgumentConversion( + rewriter, + op.getLoc(), + op.getDomain().getType(), + { + domain.pos_size, + domain.pos_alloc_size, + domain.crd_size, + domain.dim_size, + domain.pos, + domain.mark_array + } + ); + rewriter.replaceOp(op, {materialized}); + return success(); + } +}; + +struct ConvertDomainEndRowOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(indexTree::SymbolicDomainEndRowOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + Type index_type = rewriter.getIndexType(); + SymbolicDomain domain; + if(!unpack_symbolic_domain(llvm::cast(adaptor).getDomain(), domain)){ + return failure(); + } + + Value inc = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + Value new_pos_size = rewriter.create(loc, index_type, domain.pos_size, inc); + + // TODO: Dynamically resize array? 
+ rewriter.create(loc, TypeRange(), domain.crd_size, domain.pos, new_pos_size); + Value materialized = getTypeConverter()->materializeArgumentConversion( + rewriter, + op.getLoc(), + op.getDomain().getType(), + { + new_pos_size, + domain.pos_alloc_size, + domain.crd_size, + domain.dim_size, + domain.pos, + domain.mark_array + } + ); + rewriter.replaceOp(op, {materialized}); + return success(); + } +}; + +struct ConvertDomainDeclarationOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(indexTree::DeclDomainOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + Type index_type = rewriter.getIndexType(); + Type memref_type = MemRefType::get({ShapedType::kDynamic,}, index_type); + + Value zero = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + Value inc = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + Value pos_alloc_size = rewriter.create(loc, index_type, op.getNumRows(), inc); + Value pos = rewriter.create(loc, memref_type, ValueRange{pos_alloc_size}, ValueRange(), nullptr); + rewriter.create(loc, zero, pos, zero); + Value mark_array = rewriter.create(loc, memref_type, ValueRange{op.getDimSize()}, ValueRange(), nullptr); + auto new_op = rewriter.create( + loc, + op->getResultTypes(), + ValueRange({zero, pos_alloc_size, zero, op.getDimSize(), pos, mark_array}) + ); + rewriter.replaceOp(op, new_op->getResults()); + return success(); + } +}; + +struct ConvertSparseTensorOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(indexTree::IndexTreeSparseTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + + for(Value domain : llvm::cast(adaptor).getDomains()) + { + if(llvm::isa(domain.getType())){ + if(!domain.getDefiningOp()) + return failure(); + } + } + + llvm::SmallVector arrays; + llvm::SmallVector array_sizes; + 
llvm::SmallVector dim_sizes; + + auto ctx = op.getContext(); + auto format_unk = tensorAlgebra::TensorFormatEnumAttr::get(ctx, tensorAlgebra::TensorFormatEnum::UNK); + auto format_dense = tensorAlgebra::TensorFormatEnumAttr::get(ctx, tensorAlgebra::TensorFormatEnum::D); + auto format_compressed = tensorAlgebra::TensorFormatEnumAttr::get(ctx, tensorAlgebra::TensorFormatEnum::CU); + SmallVector dim_format; + uint32_t rank = 0; + + auto loc = op.getLoc(); + Type index_type = rewriter.getIndexType(); + Type memref_type = MemRefType::get({ShapedType::kDynamic,}, index_type); + Value zero = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + Value one = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + Value nnz = one; + + for(Value domain : llvm::cast(adaptor).getDomains()) + { + rank += 1; + if(llvm::isa(domain.getType())) + { + Operation* domain_op = domain.getDefiningOp(); + if(llvm::isa(domain_op)) + { + auto dense_domain_op = llvm::cast(domain_op); + Value dim_size = dense_domain_op.getDimSize(); + + Value pos = rewriter.create(loc, MemRefType::get({1,}, index_type)); + rewriter.create(loc, TypeRange(), dim_size, pos, zero); + Value crd = rewriter.create(loc, MemRefType::get({0,}, index_type)); + Value pos_tile = rewriter.create(loc, MemRefType::get({0,}, index_type)); + Value crd_tile = rewriter.create(loc, MemRefType::get({0,}, index_type)); + + arrays.push_back(rewriter.create(loc, pos, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + arrays.push_back(rewriter.create(loc, crd, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + arrays.push_back(rewriter.create(loc, pos_tile, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + arrays.push_back(rewriter.create(loc, crd_tile, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + + array_sizes.push_back(one); + array_sizes.push_back(zero); + array_sizes.push_back(zero); + array_sizes.push_back(zero); + + dim_sizes.push_back(dim_size); + nnz = rewriter.create(loc, index_type, nnz, dim_size); + 
dim_format.push_back(format_dense); + dim_format.push_back(format_unk); + } else if(llvm::isa(domain_op)) + { + auto sparse_domain_op = llvm::cast(domain_op); + Value dim_size = sparse_domain_op.getDimSize(); + Value pos_size = sparse_domain_op.getPosSize(); + Value crd_size = sparse_domain_op.getCrdSize(); + + Value pos = sparse_domain_op.getPos(); + Value crd = sparse_domain_op.getCrd(); + Value pos_tile = rewriter.create(loc, MemRefType::get({0,}, index_type)); + Value crd_tile = rewriter.create(loc, MemRefType::get({0,}, index_type)); + + arrays.push_back(pos); + arrays.push_back(crd); + arrays.push_back(rewriter.create(loc, pos_tile, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + arrays.push_back(rewriter.create(loc, crd_tile, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + + array_sizes.push_back(pos_size); + array_sizes.push_back(crd_size); + array_sizes.push_back(zero); + array_sizes.push_back(zero); + + dim_sizes.push_back(dim_size); + nnz = crd_size; + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + } else + return failure(); + } else if(llvm::isa(domain.getType())) + { + SymbolicDomain domain_struct; + assert(unpack_symbolic_domain(domain, domain_struct)); + Value crd = rewriter.create(loc, memref_type, ValueRange{domain_struct.crd_size}, ValueRange(), nullptr); + Value pos_tile = rewriter.create(loc, MemRefType::get({0,}, index_type)); + Value crd_tile = rewriter.create(loc, MemRefType::get({0,}, index_type)); + + arrays.push_back(rewriter.create(loc, domain_struct.pos, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + arrays.push_back(rewriter.create(loc, crd, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + arrays.push_back(rewriter.create(loc, pos_tile, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + arrays.push_back(rewriter.create(loc, crd_tile, rewriter.getUnitAttr(), rewriter.getUnitAttr())); + + array_sizes.push_back(domain_struct.pos_size); + array_sizes.push_back(domain_struct.crd_size); + 
array_sizes.push_back(zero); + array_sizes.push_back(zero); + + dim_sizes.push_back(domain_struct.dim_size); + nnz = domain_struct.crd_size; + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + } + } + + //Allocate values array and initialize + Type float_type = llvm::cast(op.getResult().getType()).getElementType(); + Value val_array = rewriter.create(loc, MemRefType::get({ShapedType::kDynamic,}, float_type), ValueRange{nnz}, ValueRange(), nullptr); + Value float_zero = rewriter.create(loc, float_type, rewriter.getFloatAttr(float_type, 0.0)); + auto for_loop = rewriter.create(loc, zero, nnz, one); + rewriter.setInsertionPointToStart(for_loop.getBody()); + auto induction_var = for_loop.getInductionVar(); + rewriter.create(loc, TypeRange(), float_zero, val_array, induction_var); + rewriter.setInsertionPointAfter(for_loop); + val_array = rewriter.create(loc, val_array, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + + std::vector args; + args.insert(args.end(), arrays.begin(), arrays.end()); + args.push_back(val_array); + args.insert(args.end(), array_sizes.begin(), array_sizes.end()); + args.push_back(nnz); + args.insert(args.end(), dim_sizes.begin(), dim_sizes.end()); + + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), args, rank, rewriter.getArrayAttr(dim_format)); + return success(); + } +}; + +class EraseDenseDomainOp : public OpConversionPattern{ + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(indexTree::IndexTreeDenseDomainOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +class EraseSparseDomainOp : public OpConversionPattern{ + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(indexTree::IndexTreeSparseDomainOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); 
+ } +}; + +} //namespace + +struct ConvertSymbolicDomainsPass + : public PassWrapper> +{ + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertSymbolicDomainsPass) + + void runOnOperation() override + { + // Convert the rest of the index tree dialect to SCF + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [](indexTree::SymbolicDomainType domainType, SmallVectorImpl &types) { + auto context = domainType.getContext(); + Type index_type = IndexType::get(context); + Type memref_type = MemRefType::get({ShapedType::kDynamic,}, index_type); + types.push_back(index_type); + types.push_back(index_type); + types.push_back(index_type); + types.push_back(index_type); + types.push_back(memref_type); + types.push_back(memref_type); + return success(); + }); + + typeConverter.addSourceMaterialization( + [](OpBuilder &builder, indexTree::SymbolicDomainType resultType, ValueRange inputs, + Location loc) -> Optional { + assert(inputs.size() == 6); + Value value = builder.create(loc, resultType, inputs)->getResult(0); + return value; + }); + + typeConverter.addArgumentMaterialization( + [](OpBuilder &builder, indexTree::SymbolicDomainType resultType, ValueRange inputs, + Location loc) -> Optional { + assert(inputs.size() == 6); + Value value = builder.create(loc, resultType, inputs)->getResult(0); + return value; + }); + + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + + + mlir::RewritePatternSet patterns(&getContext()); + populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, target); + indexTree::populateIndexTreeTypeConversionPatterns(&getContext(), patterns, typeConverter, target); + patterns.add(typeConverter, 
&getContext()); + patterns.add(typeConverter, &getContext()); + + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); + } +}; + +/// Lower sparse tensor algebra operation to loops +std::unique_ptr mlir::comet::createConvertSymbolicDomainsPass() +{ + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Conversion/TensorAlgebraToIndexTree/TensorAlgebraToIndexTree.cpp b/lib/Conversion/TensorAlgebraToIndexTree/TensorAlgebraToIndexTree.cpp index 26a4a58a..7ba799a2 100644 --- a/lib/Conversion/TensorAlgebraToIndexTree/TensorAlgebraToIndexTree.cpp +++ b/lib/Conversion/TensorAlgebraToIndexTree/TensorAlgebraToIndexTree.cpp @@ -41,7 +41,6 @@ using namespace mlir::tensorAlgebra; // *********** For debug purpose *********// //#define COMET_DEBUG_MODE #include "comet/Utils/debug.h" -#undef COMET_DEBUG_MODE // *********** For debug purpose *********// using namespace mlir; @@ -106,387 +105,197 @@ Value getRealRhs(Operation *op) return op->getResult(0); } -void buildDefUseInfo(UnitExpression *e) -{ - auto lhs = e->getLHS(); - lhs->setDefiningExpr(e); - for (auto operand : e->getOperands()) - { - if (auto def = operand->getDefiningExpr()) - { - def->addUser(e); - } - } -} +template +mlir::LogicalResult generalIndexOperationRewrite( + mlir::Operation* op, + ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) { -IndicesType getUnion(IndicesType indices1, IndicesType indices2) -{ - sort(indices1.begin(), indices1.end()); - sort(indices2.begin(), indices2.end()); - - IndicesType allIndices(indices1.size() * 4); - - IndicesType::iterator it = set_union(indices1.begin(), indices1.end(), indices2.begin(), indices2.end(), allIndices.begin()); - allIndices.resize(it - allIndices.begin()); - return allIndices; -} + auto loc = op->getLoc(); + auto context = rewriter.getContext(); + TATensorOp mult_op = llvm::dyn_cast(op); -void doTensorMultOp(TensorMultOp op, unique_ptr &tree) -{ - Value 
rhs1_tensor = getRealRhs(op.getRhs1().getDefiningOp()); - Value rhs2_tensor = getRealRhs(op.getRhs2().getDefiningOp()); + Value rhs1_tensor = getRealRhs(mult_op.getRhs1().getDefiningOp()); + Value rhs2_tensor = getRealRhs(mult_op.getRhs2().getDefiningOp()); Value lhs_tensor = getRealLhs(op); - Value mask_tensor = op.getMask(); - - comet_debug() << "LowerTensorAlgebraToIndexTreePass: doTensorMultOp\n"; - comet_debug() << "rhs1-tensor\n"; - comet_vdump(rhs1_tensor); - comet_debug() << "rhs2-tensor\n"; - comet_vdump(rhs2_tensor); - comet_debug() << "lhs-tensor\n"; - comet_vdump(lhs_tensor); - comet_debug() << "mask-tensor\n"; - comet_vdump(mask_tensor); - - auto allPerms = getAllPerms(op.getIndexingMaps()); - auto allFormats = getAllFormats(op.getFormatsAttr(), allPerms); - auto SemiringOp = op.getSemiringAttr(); - auto MaskingTypeAttr = op.getMaskTypeAttr(); - - assert(allPerms.size() == 3); - - auto B = tree->getOrCreateTensor(rhs1_tensor, allFormats[0]); - auto C = tree->getOrCreateTensor(rhs2_tensor, allFormats[1]); - auto A = tree->getOrCreateTensor(lhs_tensor, allFormats[2]); - Tensor *M; - std::unique_ptr e; - if (mask_tensor != nullptr) /// mask is an optional input - { - comet_debug() << "mask input provided by user\n"; - M = tree->getOrCreateTensor(mask_tensor, allFormats[2]); /// format same as lhs_tensor - e = make_unique(A, B, C, M, "*"); - } - else - { - comet_debug() << "no mask input provided by user\n"; - e = make_unique(A, B, C, "*"); - } - - e->setSemiring(SemiringOp.cast().getValue()); - e->setMaskType(MaskingTypeAttr.cast().getValue()); - - e->setOperation(op); - buildDefUseInfo(e.get()); - - auto inputDomains = e->computeInputIterDomains(); - auto outputDomains = e->computeOutputIterDomains(); - - IndicesType rhs1_indices = tree->getIndices(rhs1_tensor); - IndicesType rhs2_indices = tree->getIndices(rhs2_tensor); - IndicesType allIndices = getUnion(rhs1_indices, rhs2_indices); - - auto lhsIndices = A->getIndices(); - - TreeNode *parent = 
tree->getRoot(); - for (unsigned long i = 0; i < allIndices.size(); i++) - { - int index = allIndices[i]; - auto &idomain = inputDomains.at(index); - auto node = tree->addIndexNode(index, parent, idomain); - - /// If this index appears on the lhs too, set output domain for the index node - if (std::find(lhsIndices.begin(), lhsIndices.end(), index) != lhsIndices.end()) - { - auto &odomain = outputDomains.at(index); - node->setOutputDomain(odomain); - } - - parent = node; + Value mask_tensor = nullptr; + if(llvm::isa(op)){ + mask_tensor = llvm::cast(op).getMask(); } - tree->addComputeNode(std::move(e), parent); -} - -template -void doElementWiseOp(T op, unique_ptr &tree) -{ - Value rhs1_tensor = getRealRhs(op.getRhs1().getDefiningOp()); - Value rhs2_tensor = getRealRhs(op.getRhs2().getDefiningOp()); - Value lhs_tensor = getRealLhs(op); - - comet_debug() << "LowerTensorAlgebraToIndexTreePass: doElementWiseMultOp\n"; - comet_debug() << "rhs1-tensor\n"; - comet_vdump(rhs1_tensor); - comet_debug() << "rhs2-tensor\n"; - comet_vdump(rhs2_tensor); - comet_debug() << "lhs-tensor\n"; - comet_vdump(lhs_tensor); - - auto allPerms = getAllPerms(op.getIndexingMaps()); - auto allFormats = getAllFormats(op.getFormatsAttr(), allPerms); - auto SemiringOp = op.getSemiringAttr(); - auto maskAttr = "none"; - - assert(allPerms.size() == 3); - - auto B = tree->getOrCreateTensor(rhs1_tensor, allFormats[0]); - auto C = tree->getOrCreateTensor(rhs2_tensor, allFormats[1]); - auto A = tree->getOrCreateTensor(lhs_tensor, allFormats[2]); + auto indexing_maps = mult_op.getIndexingMaps(); + auto semiring = mult_op.getSemiringAttr().template cast().getValue(); + auto MaskingTypeAttr = mult_op.getMaskTypeAttr(); - auto e = make_unique(A, B, C, "*"); + auto tensor_type = op->getResultTypes()[0]; + auto itree_op = rewriter.create(loc, tensor_type); + Region* body = &itree_op.getRegion(); + loc = body->getLoc(); + Block* block = rewriter.createBlock(body); - e->setOperation(op); - 
e->setSemiring(SemiringOp.template cast().getValue()); /// for element-wise multiplication - e->setMaskType(maskAttr); /// for element-wise multiplication - buildDefUseInfo(e.get()); + indexTree::IndexTreeType tree_type = indexTree::IndexTreeType::get(context); + Value parent = rewriter.create(loc, tree_type); - auto inputDomains = e->computeInputIterDomains(); - auto outputDomains = e->computeOutputIterDomains(); - - /// RHS and LHS indices must be the same for elementwise multiplication - IndicesType allIndices = tree->getIndices(rhs1_tensor); - - auto lhsIndices = A->getIndices(); - TreeNode *parent = tree->getRoot(); - for (unsigned long i = 0; i < allIndices.size(); i++) + //Construct each index variable + auto lhsMap = indexing_maps[2].template cast().getValue(); + indexTree::IndexNodeType index_node_type = indexTree::IndexNodeType::get(context); + std::vector index_nodes; + for (unsigned i = 0; i < lhsMap.getNumDims(); i++) { - int index = allIndices[i]; - auto &idomain = inputDomains.at(index); - - auto node = tree->addIndexNode(index, parent, idomain); - - /// If this index appears on the lhs too, set output domain for the index node - if (std::find(lhsIndices.begin(), lhsIndices.end(), index) != lhsIndices.end()) - { - auto &odomain = outputDomains.at(index); - node->setOutputDomain(odomain); - } - - parent = node; + parent = rewriter.create(loc, index_node_type, parent); + index_nodes.push_back(parent); } - tree->addComputeNode(std::move(e), parent); - /// cout << "print tree after tc\n"; - /// tree->print(); -} -/// helper for treeToDialect() -Operation *getSetOpForTC(Operation *op) -{ - assert(isa(op) || isa(op) || isa(op) || isa(op)); - /// TODO(gkestor): fix the issue with getUsers() after getRealRhs(). 
- comet_debug() << "The following loop may cause issue!\n"; - Operation *firstUser; - for (auto user : op->getResult(0).getUsers()) + //Construct LHS Operand + if(mask_tensor != nullptr) { - firstUser = user; - break; + lhs_tensor = rewriter.create(loc, tensor_type, lhs_tensor, mask_tensor, MaskingTypeAttr); } - - assert(isa(firstUser)); - return firstUser; -} - -/// helper for treeToDialect() -IndexTreeComputeOp createComputeNodeOp(OpBuilder &builder, TreeNode *node, Location &loc) -{ - auto context = builder.getContext(); - IntegerType i64Type = IntegerType::get(context, 64); - auto expr = node->getExpression(); - - SmallVector allIndices_rhs; - for (auto t : expr->getOperands()) + llvm::SmallVector pos; + llvm::SmallVector crds; + Value prev_dim = nullptr; + auto access_type = rewriter.getIndexType(); + for (size_t i = 0; i < lhsMap.getNumResults(); i++) { - SmallVector indices; - for (auto index : t->getIndices()) - { - indices.push_back(index); - } - allIndices_rhs.push_back(builder.getI64ArrayAttr(indices)); + auto expr = lhsMap.getResult(i); + IndexTreeIndexToTensorOp access_op = rewriter.create( + loc, + TypeRange({access_type, access_type}), + lhs_tensor, + index_nodes[expr.template cast().getPosition()], + rewriter.getUI32IntegerAttr((unsigned)i), + prev_dim + ); + pos.push_back(access_op.getPos()); + crds.push_back(access_op.getCrd()); + prev_dim = pos[pos.size() - 1]; } - SmallVector allIndices_lhs; - for (auto t : expr->getResults()) + indexTree::OperandType operand_type = indexTree::OperandType::get(context); + Value lhs_operand = rewriter.create(loc, operand_type, + lhs_tensor, pos, + crds); + + //Construct RHS operands + std::vector rhs_operands; + pos.clear(); + crds.clear(); + prev_dim = nullptr; + auto affineMap = indexing_maps[0].template cast().getValue(); + for (size_t i = 0; i < affineMap.getNumResults(); i++) { - SmallVector indices; - for (auto index : t->getIndices()) - { - indices.push_back(index); - } - 
allIndices_lhs.push_back(builder.getI64ArrayAttr(indices)); + auto expr = affineMap.getResult(i); + IndexTreeIndexToTensorOp access_op = rewriter.create( + loc, + TypeRange({access_type, access_type}), + rhs1_tensor, + index_nodes[expr.template cast().getPosition()], + rewriter.getUI32IntegerAttr((unsigned)i), + prev_dim + ); + pos.push_back(access_op.getPos()); + crds.push_back(access_op.getCrd()); + prev_dim = pos[pos.size() - 1]; } - - SmallVector allFormats_rhs; - for (auto t : expr->getOperands()) + rhs_operands.push_back(rewriter.create( + loc, operand_type, rhs1_tensor, pos, crds)); + + pos.clear(); + crds.clear(); + prev_dim = nullptr; + affineMap = indexing_maps[1].template cast().getValue(); + for (size_t i = 0; i < affineMap.getNumResults(); i++) { - SmallVector formats; - for (auto &f : t->getFormats()) - { - formats.push_back(f); - } - allFormats_rhs.push_back(builder.getStrArrayAttr(formats)); - } - SmallVector allFormats_lhs; - for (auto t : expr->getResults()) - { - SmallVector formats; - for (auto &f : t->getFormats()) - { - formats.push_back(f); - } - allFormats_lhs.push_back(builder.getStrArrayAttr(formats)); + auto expr = affineMap.getResult(i); + IndexTreeIndexToTensorOp access_op = rewriter.create( + loc, + TypeRange({access_type, access_type}), + rhs2_tensor, + index_nodes[expr.template cast().getPosition()], + rewriter.getUI32IntegerAttr((unsigned)i), + prev_dim + ); + pos.push_back(access_op.getPos()); + crds.push_back(access_op.getCrd()); + prev_dim = pos[pos.size() - 1]; } + rhs_operands.push_back(rewriter.create( + loc, operand_type, rhs2_tensor, pos, crds)); + + Value compute_op = rewriter.create( + loc, + tensor_type, + parent, + lhs_operand, + rhs_operands, + rewriter.getStringAttr(semiring) + ); + + rewriter.create(loc, TypeRange(), compute_op); + rewriter.replaceOp(op, itree_op->getResults()); + return success(); +} - std::vector t_rhs; - Value t_lhs = expr->getLHS()->getValue(); - for (auto o : expr->getOperands()) - { - 
t_rhs.push_back(o->getValue()); - } +struct TensorMultOpLowering : public mlir::ConversionPattern { + TensorMultOpLowering(mlir::MLIRContext *ctx) + : mlir::ConversionPattern(TensorMultOp::getOperationName(), 1, ctx) {} - /// check if mask exists and add to t_rhs - if (expr->getMask() != nullptr) - { - comet_debug() << "user has provided mask input\n"; - t_rhs.push_back(expr->getMask()->getValue()); /// add mask to IndexTreeComputeRHSOp + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + return generalIndexOperationRewrite(op, operands, rewriter); } +}; - Value leafop_rhs = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), t_rhs, - builder.getArrayAttr(allIndices_rhs), - builder.getArrayAttr(allFormats_rhs)); - comet_vdump(leafop_rhs); - Value leafop_lhs = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), t_lhs, - builder.getArrayAttr(allIndices_lhs), - builder.getArrayAttr(allFormats_lhs)); - bool comp_worksp_opt = false; /// non-compressed workspace, this is a place-holder and it is updated in workspace transform pass. 
- llvm::StringRef semiring = expr->getSemiring(); - llvm::StringRef maskType = expr->getMaskType(); - auto leafop = builder.create(loc, i64Type, leafop_rhs, leafop_lhs, builder.getBoolAttr(comp_worksp_opt), builder.getStringAttr(semiring), builder.getStringAttr(maskType)); - - comet_pdump(leafop); - return leafop; -} - -/** - * This function performs the actual removal of the ta operations in the tree, - * and add corresponding ta.itree operations.› - * @param tree - */ -void treeToDialect(Index_Tree *tree) -{ - vector TAOps = tree->getContainingTAOps(); - unsigned int TAOpsID = 0; - OpBuilder builder(TAOps[TAOpsID]); - auto loc = TAOps[TAOpsID]->getLoc(); - auto context = builder.getContext(); +struct TensorElewsMultOpLowering : public mlir::ConversionPattern { + TensorElewsMultOpLowering(mlir::MLIRContext *ctx) + : mlir::ConversionPattern(TensorElewsMultOp::getOperationName(), 1, ctx) {} - std::map nodeToOp; + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + return generalIndexOperationRewrite(op, operands, rewriter); + } +}; - IntegerType i64Type = IntegerType::get(context, 64); +struct TensorAddOpLowering : public mlir::ConversionPattern { + TensorAddOpLowering(mlir::MLIRContext *ctx) + : mlir::ConversionPattern(TensorAddOp::getOperationName(), 1, ctx) {} - for (auto &node : tree->getNodesInReverseTopoOrder()) - { - if (node->isComputeNode()) - { - assert(nodeToOp.count(node) == 0); - builder.setInsertionPoint(TAOps[TAOpsID]); - nodeToOp[node] = createComputeNodeOp(builder, node, loc); - TAOpsID++; - } - else if (node->isRealIndexNode()) - { - if (node->getChildren().empty()) - { - continue; /// to skip nodes that become childless after fusion - } - SmallVector children; - for (auto c : node->getChildren()) - { - assert(nodeToOp.count(c) > 0); - children.push_back(nodeToOp[c]); - } - /// assert(!children.empty()); - SmallVector indices; - 
indices.push_back(node->getIndex()); - auto indicesAttr = builder.getI64ArrayAttr(indices); - - SmallVector ids; - ids.push_back(node->getId()); - - Value indexNodeOp = builder.create(loc, - i64Type, - children, - indicesAttr); - - nodeToOp[node] = indexNodeOp; - - if (node->getParent() != nullptr && node->getParent()->isFillerIndexNode()) - { -#ifdef DEBUG_MODE_LowerTensorAlgebraToIndexTreePass - Value op = builder.create(loc, i64Type, indexNodeOp); - comet_vdump(op); -#else - builder.create(loc, i64Type, indexNodeOp); -#endif - } - } + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + return generalIndexOperationRewrite(op, operands, rewriter); } +}; - for (auto op : tree->getContainingTAOps()) - { - auto setOp = getSetOpForTC(op); - setOp->erase(); - op->erase(); +struct TensorSubtractOpLowering : public mlir::ConversionPattern { + TensorSubtractOpLowering(mlir::MLIRContext *ctx) + : mlir::ConversionPattern(TensorSubtractOp::getOperationName(), 1, ctx) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + return generalIndexOperationRewrite(op, operands, rewriter); } -} +}; void LowerTensorAlgebraToIndexTreePass::runOnOperation() { - unique_ptr tree; - func::FuncOp func = getOperation(); - - tree = Index_Tree::createTreeWithRoot(); - bool formIndexTreeDialect = false; - - comet_debug() << "IndexTree pass running on Function\n"; - for (Block &B : func.getBody()) - { - for (Operation &op : B) - { - if (isa(&op)) - { - doTensorMultOp(cast(&op), tree); - formIndexTreeDialect = true; - } - else if (isa(&op)) - { - doElementWiseOp(cast(&op), tree); - formIndexTreeDialect = true; - } - else if (isa(&op) || isa(&op)) - { - /// elementwise addition and subtraction - if (isa(&op)) - { - doElementWiseOp(cast(&op), tree); - } - - if (isa(&op)) - { - doElementWiseOp(cast(&op), tree); - } - 
formIndexTreeDialect = true; - } - } - } - - if (formIndexTreeDialect) - { - comet_debug() << " Dumping Index tree IR\n"; - /// only do this for TensorMultOp or TensorElewsMultOp - treeToDialect(tree.get()); - } + mlir::ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalOp(); + + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); } /// create all the passes. diff --git a/lib/Conversion/TensorAlgebraToSCF/LowerPCToLoops.cpp b/lib/Conversion/TensorAlgebraToSCF/.LowerPCToLoops.cpp similarity index 100% rename from lib/Conversion/TensorAlgebraToSCF/LowerPCToLoops.cpp rename to lib/Conversion/TensorAlgebraToSCF/.LowerPCToLoops.cpp diff --git a/lib/Conversion/TensorAlgebraToSCF/CMakeLists.txt b/lib/Conversion/TensorAlgebraToSCF/CMakeLists.txt index 4a563c79..0572b56c 100644 --- a/lib/Conversion/TensorAlgebraToSCF/CMakeLists.txt +++ b/lib/Conversion/TensorAlgebraToSCF/CMakeLists.txt @@ -1,10 +1,9 @@ add_mlir_conversion_library(COMETTensorAlgebraToSCF - EarlyLowering.cpp - LateLowering.cpp LowerFunc.cpp - LowerPCToLoops.cpp TensorAlgebraToSCF.cpp - + EarlyLowering.cpp + LateLowering.cpp + SparseTensorConversionPass.cpp ADDITIONAL_HEADER_DIRS ${COMET_MAIN_INCLUDE_DIR}/comet/Conversion/TensorAlgebraToIndexTree diff --git a/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp b/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp index fef490e7..138bec30 100644 --- a/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/EarlyLowering.cpp @@ -90,8 +90,9 @@ namespace tensorLoadOp = cast(tensorOperand.getDefiningOp()); auto memref = tensorLoadOp.getMemref(); auto valueAttr = tensorFillOp.getValue(); + + rewriter.setInsertionPoint(tensorLoadOp); Value constantOp = rewriter.create(loc, valueAttr); - rewriter.create(loc, constantOp, memref); rewriter.eraseOp(op); 
diff --git a/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp b/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp index d532a72b..d082db46 100644 --- a/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/LateLowering.cpp @@ -72,6 +72,7 @@ namespace FloatType f64Type = FloatType::getF64(ctx); IndexType indexType = IndexType::get(ctx); Type unrankedMemrefType_f64 = UnrankedMemRefType::get(f64Type, 0); + Type unrankedMemref_index = mlir::UnrankedMemRefType::get(indexType, 0); auto printTensorF64Func = FunctionType::get(ctx, {mlir::UnrankedMemRefType::get(f64Type, 0)}, {}); auto printTensorIndexFunc = FunctionType::get(ctx, {mlir::UnrankedMemRefType::get(indexType, 0)}, {}); @@ -112,60 +113,38 @@ namespace module.push_back(print_func); } + std::string comet_print_i64Str = "comet_print_memref_i64"; + if (!hasFuncDeclaration(module, comet_print_i64Str)) + { + print_func = func::FuncOp::create(loc, comet_print_i64Str, printTensorIndexFunc, ArrayRef{}); + print_func.setPrivate(); + module.push_back(print_func); + } + if (inputType.isa()) { auto alloc_op = cast(op->getOperand(0).getDefiningOp()); comet_vdump(alloc_op); auto u = rewriter.create(loc, unrankedMemrefType_f64, alloc_op); rewriter.create(loc, comet_print_f64Str, SmallVector{}, ValueRange{u}); - } - else + }else if (inputType.isa()) { - /// If the Input type is tensor - if (inputType.isa()) - { - auto rhs = op->getOperand(0).getDefiningOp(); - auto alloc_op = cast(rhs->getOperand(0).getDefiningOp()); - comet_vdump(alloc_op); - auto u = rewriter.create(loc, unrankedMemrefType_f64, alloc_op); - rewriter.create(loc, comet_print_f64Str, SmallVector{}, ValueRange{u}); - } - else if (inputType.isa()) - { - std::string comet_print_i64Str = "comet_print_memref_i64"; - if (!hasFuncDeclaration(module, comet_print_i64Str)) - { - print_func = func::FuncOp::create(loc, comet_print_i64Str, printTensorIndexFunc, ArrayRef{}); - print_func.setPrivate(); - module.push_back(print_func); - } 
- - auto sp_op = cast(op->getOperand(0).getDefiningOp()); - Type unrankedMemref_index = mlir::UnrankedMemRefType::get(indexType, 0); - - auto rhs = op->getOperand(0).getDefiningOp(); - for (int rsize = 0; rsize < sp_op.getDimArrayCount(); rsize += 2) - { - /// accessing xD_pos array and creating cast op for its alloc - auto xD_pos = rhs->getOperand(rsize).getDefiningOp(); - auto alloc_rhs = cast(xD_pos->getOperand(0).getDefiningOp()); - auto u = rewriter.create(loc, unrankedMemref_index, alloc_rhs); - rewriter.create(loc, comet_print_i64Str, SmallVector{}, ValueRange{u}); - - /// accessing xD_crd array and creating cast op for its alloc - auto xD_crd = rhs->getOperand(rsize + 1).getDefiningOp(); - alloc_rhs = cast(xD_crd->getOperand(0).getDefiningOp()); - u = rewriter.create(loc, unrankedMemref_index, alloc_rhs); - rewriter.create(loc, comet_print_i64Str, SmallVector{}, ValueRange{u}); - } - - auto xD_value = rhs->getOperand(sp_op.getValueArrayPos()).getDefiningOp(); - auto alloc_rhs = cast(xD_value->getOperand(0).getDefiningOp()); - auto u = rewriter.create(loc, unrankedMemrefType_f64, alloc_rhs); + auto rhs = op->getOperand(0); + auto tensor_type = llvm::cast(inputType); + auto memref_type = MemRefType::get(tensor_type.getShape(), tensor_type.getElementType()); + auto buffer = rewriter.create(loc, memref_type, rhs); + + if(llvm::isa(tensor_type.getElementType())){ + auto u = rewriter.create(loc, unrankedMemref_index, buffer); + rewriter.create(loc, comet_print_i64Str, SmallVector{}, ValueRange{u}); + } else { + auto u = rewriter.create(loc, unrankedMemrefType_f64, buffer); rewriter.create(loc, comet_print_f64Str, SmallVector{}, ValueRange{u}); } - else - llvm::errs() << __FILE__ << " " << __LINE__ << "Unknown Data type\n"; + } + else + { + llvm::errs() << __FILE__ << " " << __LINE__ << "Unknown Data type\n"; } } diff --git a/lib/Conversion/TensorAlgebraToSCF/SparseTensorConversionPass.cpp b/lib/Conversion/TensorAlgebraToSCF/SparseTensorConversionPass.cpp new file 
mode 100644 index 00000000..1a5137c9 --- /dev/null +++ b/lib/Conversion/TensorAlgebraToSCF/SparseTensorConversionPass.cpp @@ -0,0 +1,952 @@ +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" +#include "comet/Dialect/Utils/Utils.h" +#include "comet/Dialect/TensorAlgebra/IR/TADialect.h" +#include "comet/Conversion/IndexTreeToSCF/IndexTreeToSCF.h" +#include "comet/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.h" +#include "comet/Dialect/IndexTree/Patterns.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Index/IR/IndexAttrs.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/Support/Debug.h" +#include "llvm/Support/ScopedPrinter.h" + +using namespace mlir; +using namespace mlir::tensorAlgebra; + +#define DEBUG_TYPE "sparse_tensor" + +namespace mlir { + namespace comet{ + #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS + #include "comet/Conversion/Passes.h.inc" + } +} + +/** Helper structures to turn sparse tensor into pointers */ +struct Dimension { + Value dim_size; + Value insert_pos; + TensorFormatEnum format; + + Value pos; + Value pos_size; + Value crd; + Value crd_size; + + bool has_block; + Value block_pos; + Value block_pos_size; + Value block_crd; + Value block_crd_size; + +}; + +struct SparseTensor { + SmallVector dims; + Value vals; + Value val_size; +}; + +struct Workspace { + Value workspace; + Value mark_value; + Value mark_array; + Value num_crds; + Value crds; +}; + +static bool unpack_sparse_tensor(Value sparse_tensor, 
SparseTensor& result) +{ + /** Helper function to turn arguments from an unrealized cast to sparse tensor */ + if (auto cast = + sparse_tensor.getDefiningOp()) { + SparseTensorType type = llvm::dyn_cast(sparse_tensor.getType()); + if(!type) + return false; + + auto format = type.getFormat(); + auto dim_sizes = type.getDims(); + auto cur_arg = cast.getInputs().begin(); + for(unsigned i = 0; i < dim_sizes.size(); i++){ + Dimension d; + d.dim_size = *cur_arg; + cur_arg++; + d.insert_pos = *cur_arg; + cur_arg++; + d.format = (TensorFormatEnum)format[2 * i]; + switch(d.format){ + case TensorFormatEnum::D: { + d.pos = *cur_arg; + cur_arg++; + d.pos_size = *cur_arg; + cur_arg++; + break; + } + case TensorFormatEnum::CU: + case TensorFormatEnum::CN: { + d.pos = *cur_arg; + cur_arg++; + d.pos_size = *cur_arg; + cur_arg++; + + d.crd = *cur_arg; + cur_arg++; + d.crd_size = *cur_arg; + cur_arg++; + break; + } + case TensorFormatEnum::S: { + d.crd = *cur_arg; + cur_arg++; + d.crd_size = *cur_arg; + cur_arg++; + break; + } + default: { + assert(false && "Could not unpack unknown format to sparse tensor."); + } + } + result.dims.push_back(d); + } + result.vals = *cur_arg; + cur_arg++; + result.val_size = *cur_arg; + return true; + } + return false; +} + +static void pack_sparse_tensor(SparseTensorType type, SparseTensor& sparse_tensor, SmallVectorImpl& result) +{ + for(Dimension d : sparse_tensor.dims) + { + result.push_back(d.dim_size); + result.push_back(d.insert_pos); + switch(d.format){ + case TensorFormatEnum::D: { + result.push_back(d.pos); + result.push_back(d.pos_size); + break; + } + case TensorFormatEnum::CU: + case TensorFormatEnum::CN: { + result.push_back(d.pos); + result.push_back(d.pos_size); + result.push_back(d.crd); + result.push_back(d.crd_size); + break; + } + case TensorFormatEnum::S: { + result.push_back(d.crd); + result.push_back(d.crd_size); + break; + } + default: { + assert(false && "Could not unpack unknown format to sparse tensor."); + } + } + } + 
result.push_back(sparse_tensor.vals); + result.push_back(sparse_tensor.val_size); + return; +} + +static bool unpack_workspace(Value workspace_val, Workspace& result) +{ + if (auto cast = + workspace_val.getDefiningOp()) { + auto cur_arg = cast.getInputs().begin(); + result.workspace = *cur_arg; + cur_arg++; + result.mark_value = *cur_arg; + cur_arg++; + result.mark_array = *cur_arg; + cur_arg++; + result.num_crds = *cur_arg; + cur_arg++; + result.crds = *cur_arg; + return true; + } + return false; + +} +static void pack_workspace(WorkspaceType type, Workspace& workspace, SmallVectorImpl& result) +{ + result.push_back(workspace.workspace); + result.push_back(workspace.mark_value); + result.push_back(workspace.mark_array); + result.push_back(workspace.num_crds); + result.push_back(workspace.crds); +} + +namespace { +class ConvertSpTensorConstructOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(SparseTensorConstructOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // This is unnecessarily complicated because sptensor_construct does not have named arguments + SparseTensor sp_tensor; + SparseTensorType sp_tensor_type = llvm::cast(op->getResult(0).getType()); + auto inputs = op.getIndices(); + unsigned rank = sp_tensor_type.getDims().size(); + for(unsigned i = 0; i < rank; i++) + { + Dimension d; + d.format = (TensorFormatEnum) sp_tensor_type.getFormat()[2 * i]; + d.dim_size = inputs[(8 * rank) + 2 + i]; + d.insert_pos = rewriter.create(op.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(0)); + d.pos = inputs[(4 * i)]; + d.crd = inputs[(4 * i) + 1]; + d.pos_size = inputs[(4 * rank) + 1 + (4 * i)]; + d.crd_size = inputs[(4 * rank) + 1 + (4 * i) + 1]; + sp_tensor.dims.push_back(d); + } + + sp_tensor.vals = inputs[(4 * rank)]; + sp_tensor.val_size = inputs[(8*rank)+1]; + + SmallVector cast_args; + pack_sparse_tensor(sp_tensor_type, sp_tensor, cast_args); + 
rewriter.replaceOpWithNewOp(op, sp_tensor_type, cast_args); + return success(); + } +}; + +class ConvertSpTensorInsertOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertSpTensorInsertOp(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(tensorAlgebra::TensorInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if(!llvm::isa(op.getTensor().getType())){ + return failure(); + } + + SparseTensor sp_tensor; + TensorInsertOpAdaptor insertAdpator = llvm::cast(adaptor); + if(!unpack_sparse_tensor(insertAdpator.getTensor(), sp_tensor)) { + return failure(); + } + + // Match successful! + auto loc = op.getLoc(); + Type index_type = rewriter.getIndexType(); + unsigned i = 0; + for(Dimension& dim : sp_tensor.dims) { + if(dim.format != TensorFormatEnum::D) { + Value crd_idx = insertAdpator.getPos()[i]; + Value crd = insertAdpator.getCrds()[i]; + Value crd_tensor = dim.crd; + Value crd_size = dim.crd_size; + crd_tensor = rewriter.create( + loc, + crd_tensor.getType(), + crd, + crd_tensor, + crd_idx); + dim.crd = crd_tensor; + + // TODO: This is wrong if we insert the same crd multiple times but format is CU? + Value inc = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + dim.insert_pos = rewriter.create(loc, index_type, dim.insert_pos, inc); + + /** TODO: Implement tensor resize */ + /** TODO: Insert into CSR only has to be done once per idx? 
*/ + } + + i++; + } + Value vals = sp_tensor.vals; + Value val_idx = insertAdpator.getPos()[insertAdpator.getPos().size() - 1]; + vals = rewriter.create(loc, + vals.getType(), + insertAdpator.getValue(), + vals, + val_idx); + sp_tensor.vals = vals; + SparseTensorType sp_tensor_type = llvm::cast(op.getTensor().getType()); + + SmallVector cast_args; + pack_sparse_tensor(sp_tensor_type, sp_tensor, cast_args); + rewriter.replaceOpWithNewOp(op, sp_tensor_type, cast_args); + return success(); + } +}; + +class ConvertSpTensorExtractOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertSpTensorExtractOp(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(tensorAlgebra::TensorExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TensorExtractOpAdaptor extractAdaptor = llvm::cast(adaptor); + if(!llvm::isa(extractAdaptor.getTensor().getType())){ + return failure(); + } + + SparseTensor sp_tensor; + if(!unpack_sparse_tensor(extractAdaptor.getTensor(), sp_tensor)) { + return failure(); + } + + llvm::ScopedPrinter logger{llvm::dbgs()}; + LLVM_DEBUG({ + logger.startLine() << "Unpacked sparse tensor: " << extractAdaptor.getTensor().getDefiningOp() << "\n"; + }); + // Match successful! 
+ auto loc = op.getLoc(); + Type float_type = llvm::cast(sp_tensor.vals.getType()).getElementType(); + Value result = rewriter.create(loc, float_type, sp_tensor.vals, extractAdaptor.getPos()); + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class ConvertSpTensorInsertCrd + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertSpTensorInsertCrd(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(SpTensorInsertCrd op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + SparseTensor sp_tensor; + auto opAdaptor = llvm::cast(adaptor); + if(!unpack_sparse_tensor(opAdaptor.getTensor(), sp_tensor)) { + return failure(); + } + + // Match successful! + auto loc = op.getLoc(); + Type index_type = rewriter.getIndexType(); + Dimension& dim = sp_tensor.dims[opAdaptor.getDim()]; + if(dim.format != TensorFormatEnum::D) { + Value crd_idx = opAdaptor.getIdx(); + Value crd = opAdaptor.getCrd(); + Value crd_tensor = dim.crd; + Value crd_size = dim.crd_size; + crd_tensor = rewriter.create(loc, + crd_tensor.getType(), + crd, + crd_tensor, + crd_idx); + dim.crd = crd_tensor; + + // Update tensor insert state + Value inc = rewriter.create(loc, index_type, rewriter.getIndexAttr(1)); + dim.insert_pos = rewriter.create(loc, index_type, dim.insert_pos, inc); + } + + SparseTensorType sp_tensor_type = llvm::cast(opAdaptor.getTensor().getType()); + SmallVector cast_args; + pack_sparse_tensor(sp_tensor_type, sp_tensor, cast_args); + rewriter.replaceOpWithNewOp(op, sp_tensor_type, cast_args); + return success(); + } +}; + +class ConvertSpTensorGetDimSize + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertSpTensorGetDimSize(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(SpTensorGetDimSize op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SpTensorGetDimSizeAdaptor 
tensorAdaptor = llvm::cast(adaptor); + SparseTensor sp_tensor; + if(!unpack_sparse_tensor(tensorAdaptor.getTensor(), sp_tensor)) { + return failure(); + } + rewriter.replaceOp(op, {sp_tensor.dims[tensorAdaptor.getDim()].dim_size}); + return success(); + } +}; + +class ConvertSpTensorFindPos + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertSpTensorFindPos(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(TensorFindPos op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TensorFindPosAdaptor tensorAdaptor = llvm::cast(adaptor); + SparseTensor sp_tensor; + if(!unpack_sparse_tensor(tensorAdaptor.getTensor(), sp_tensor)) { + return failure(); + } + + if(tensorAdaptor.getIsLinear()) + { + rewriter.replaceOp(op, {sp_tensor.dims[tensorAdaptor.getDim()].insert_pos}); + } else { + assert(false && "Lowering non-unique inserts is not yet supported, please use workspace transform"); + } + + return success(); + } +}; + +class ConvertAllocWorkspaceOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertAllocWorkspaceOp(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(AllocWorkspaceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto alloc_adaptor = llvm::cast(adaptor); + auto loc = op.getLoc(); + Type index_type = rewriter.getIndexType(); + + Value sp_tensor = op.getTensor(); + auto sp_tensor_type = llvm::cast(sp_tensor.getType()); + auto dims = alloc_adaptor.getDims(); + SmallVector dim_attrs(dims.size(), ShapedType::kDynamic); + SmallVector sizes; + for(auto dim : dims) + { + Value dim_size = rewriter.create(loc, index_type, sp_tensor, llvm::cast(dim)); + sizes.push_back(dim_size); + } + + Workspace workspace; + auto workspace_tensor_type = RankedTensorType::get(dim_attrs, sp_tensor_type.getElementType()); + workspace.workspace = rewriter.create(loc, 
workspace_tensor_type, sizes); + workspace.mark_value = rewriter.create(loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); + auto workspace_mark_type = RankedTensorType::get(dim_attrs, rewriter.getI32Type()); + workspace.mark_array = rewriter.create(loc, workspace_mark_type, sizes); + workspace.num_crds = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + auto crds_type = RankedTensorType::get({ShapedType::kDynamic,}, index_type); + workspace.crds = rewriter.create(loc, crds_type, sizes); + + auto workspace_type = llvm::cast(op->getResult(0).getType()); + /** TODO: Support higher dimensional workspaces! */ + assert(workspace_type.getDims().size() == 1 && "Workspace dimensions > 1 are currently unsupported."); + + SmallVector cast_args; + pack_workspace(workspace_type, workspace, cast_args); + rewriter.replaceOpWithNewOp(op, workspace_type, cast_args); + return success(); + } +}; + +class ConvertWorkspaceGetNNZ + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertWorkspaceGetNNZ(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(SpTensorGetNNZ op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto opAdaptor = llvm::cast(adaptor); + if(!llvm::isa(opAdaptor.getTensor().getType())){ + return failure(); + } + Workspace workspace; + if(!unpack_workspace(opAdaptor.getTensor(), workspace)){ + return failure(); + } + + rewriter.replaceOp(op, {workspace.num_crds,}); + return success(); + } +}; + +class ConvertWorkspaceGetCrds + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertWorkspaceGetCrds(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(SpTensorGetCrd op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto opAdaptor = llvm::cast(adaptor); + if(!llvm::isa(opAdaptor.getTensor().getType())){ + return failure(); + } + Workspace 
workspace; + if(!unpack_workspace(opAdaptor.getTensor(), workspace)){ + return failure(); + } + + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), workspace.crds, opAdaptor.getIdx()); + return success(); + } +}; + +class ConvertWorkspaceGetDimSize + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertWorkspaceGetDimSize(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(SpTensorGetDimSize op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto opAdaptor = llvm::cast(adaptor); + if(!llvm::isa(opAdaptor.getTensor().getType())){ + return failure(); + } + + Workspace workspace; + if(!unpack_workspace(opAdaptor.getTensor(), workspace)) { + return failure(); + } + Value dim = rewriter.create(op->getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(opAdaptor.getDim())); + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), workspace.workspace, dim); + return success(); + } +}; + +class ConvertWorkspaceTensorInsertOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertWorkspaceTensorInsertOp(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(TensorInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto opAdaptor = llvm::cast(adaptor); + if(!llvm::isa(opAdaptor.getTensor().getType())){ + return failure(); + } + WorkspaceType workspace_type = llvm::cast(opAdaptor.getTensor().getType()); + assert(workspace_type.getDims().size() == 1 && "Workspace dimensions > 1 are currently unsupported."); + + Workspace workspace; + if(!unpack_workspace(opAdaptor.getTensor(), workspace)) { + return failure(); + } + + auto loc = op.getLoc(); + auto context = op.getContext(); + ValueRange crds = opAdaptor.getCrds(); + Value crd = crds[opAdaptor.getCrds().size() - 1]; + Value mark_at_crd = rewriter.create( + loc, + rewriter.getI32Type(), + workspace.mark_array, 
+ crd + ); + Value not_seen = rewriter.create( + loc, + rewriter.getI1Type(), + arith::CmpIPredicateAttr::get(context, arith::CmpIPredicate::ne), + mark_at_crd, + workspace.mark_value + ); + + Operation* if_op = rewriter.create( + loc, + not_seen, + [workspace, crd] (OpBuilder& builder, Location loc) { + Type index_type = builder.getIndexType(); + Value new_mark = builder.create(loc, workspace.mark_array.getType(), workspace.mark_value, workspace.mark_array, crd); + Value new_crds = builder.create(loc, workspace.crds.getType(), crd, workspace.crds, workspace.num_crds); + Value inc = builder.create(loc, index_type, builder.getIndexAttr(1)); + Value new_crd_size = builder.create(loc, index_type, workspace.num_crds, inc); + builder.create(loc, ArrayRef({new_mark, new_crd_size, new_crds})); + }, + [workspace] (OpBuilder& builder, Location loc) { + builder.create(loc, ArrayRef({workspace.mark_array, workspace.num_crds, workspace.crds})); + } + ); + workspace.mark_array = if_op->getResult(0); + workspace.num_crds = if_op->getResult(1); + workspace.crds = if_op->getResult(2); + workspace.workspace = rewriter.create( + loc, + workspace.workspace.getType(), + opAdaptor.getValue(), + workspace.workspace, + crd + ); + + SmallVector cast_args; + pack_workspace(workspace_type, workspace, cast_args); + rewriter.replaceOpWithNewOp(op, workspace_type, cast_args); + return success(); + } +}; + +class ConvertWorkspaceTensorExtractOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertWorkspaceTensorExtractOp(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(TensorExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto opAdaptor = llvm::cast(adaptor); + if(!llvm::isa(opAdaptor.getTensor().getType())){ + return failure(); + } + + Workspace workspace; + if(!unpack_workspace(opAdaptor.getTensor(), workspace)) { + return failure(); + } + + auto loc = op.getLoc(); + auto 
context = op.getContext(); + Value pos = opAdaptor.getPos(); + Value mark_at_pos = rewriter.create( + loc, + rewriter.getI32Type(), + workspace.mark_array, + pos + ); + Value seen = rewriter.create( + loc, + rewriter.getI1Type(), + arith::CmpIPredicateAttr::get(context, arith::CmpIPredicate::eq), + mark_at_pos, + workspace.mark_value + ); + + + Operation* if_op = rewriter.create( + loc, + seen, + [op, workspace, pos] (OpBuilder& builder, Location loc) { + Value extracted = builder.create(loc, op->getResultTypes(), workspace.workspace, pos); + builder.create(loc, ArrayRef({extracted})); + }, + [op] (OpBuilder& builder, Location loc) { + // TODO: Does the zero value depend on the semi-ring? + Type result_type = op->getResult(0).getType(); + Value zero = builder.create(loc, result_type, builder.getFloatAttr(result_type, 0)); + builder.create(loc, ArrayRef({zero})); + } + ); + + rewriter.replaceOp(op, if_op->getResults()); + return success(); + } +}; + +class ConvertWorkspaceClearOp + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ConvertWorkspaceClearOp(MLIRContext *context) + : OpConversionPattern(context) {} + LogicalResult + matchAndRewrite(WorkspaceClearOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto opAdaptor = llvm::cast(adaptor); + if(!llvm::isa(opAdaptor.getTensor().getType())){ + return failure(); + } + WorkspaceType workspace_type = llvm::cast(opAdaptor.getTensor().getType()); + Workspace workspace; + if(!unpack_workspace(opAdaptor.getTensor(), workspace)) { + return failure(); + } + auto loc = op.getLoc(); + Value inc = rewriter.create(loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); + workspace.mark_value = rewriter.create(loc, rewriter.getI32Type(), workspace.mark_value, inc); + workspace.num_crds = rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + + SmallVector cast_args; + pack_workspace(workspace_type, workspace, cast_args); + 
rewriter.replaceOpWithNewOp(op, workspace_type, cast_args); + return success(); + } +}; + +class PrintOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + PrintOpLowering(MLIRContext *context) : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + Location loc = op->getLoc(); + auto inputType = adaptor.getInput().getType(); + Type index_type = rewriter.getIndexType(); + SmallVector empty_size(1, 1); + auto empty_type = RankedTensorType::get(empty_size, index_type); + Value empty_tensor = rewriter.create(loc, empty_type, ValueRange(), (Value)nullptr); + Value neg = rewriter.create(loc, index_type, rewriter.getIndexAttr(-1)); + Value zero = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + empty_tensor = rewriter.create(loc, empty_type, neg, empty_tensor, zero); + + if (inputType.isa()) + { + SparseTensor sp_tensor; + if(!unpack_sparse_tensor(adaptor.getInput(), sp_tensor)) { + return failure(); + } + for (Dimension& dim : sp_tensor.dims) + { + switch(dim.format){ + case TensorFormatEnum::D: { + rewriter.create(loc, dim.pos); + rewriter.create(loc, empty_tensor); + break; + } + case TensorFormatEnum::CU: + case TensorFormatEnum::CN: { + rewriter.create(loc, dim.pos); + rewriter.create(loc, dim.crd); + break; + } + case TensorFormatEnum::S: { + rewriter.create(loc, empty_tensor); + rewriter.create(loc, dim.crd); + break; + } + default: { + assert(false && "Could not print unknown format to sparse tensor."); + } + } + } + rewriter.create(loc, sp_tensor.vals); + rewriter.eraseOp(op); + return success(); + } + return failure(); + } +}; + +class GetTimeLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + GetTimeLowering(MLIRContext *context) : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(GetTimeOp op, OpAdaptor adaptor, + ConversionPatternRewriter 
&rewriter) const override + { + auto ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + auto f64Type = rewriter.getF64Type(); + std::string getTimeStr = "getTime"; + + if (!hasFuncDeclaration(module, getTimeStr)) + { + auto getTimeFunc = FunctionType::get(ctx, {}, {FloatType::getF64(ctx)}); + /// func @getTime() -> f64 + func::FuncOp func1 = func::FuncOp::create(op->getLoc(), getTimeStr, + getTimeFunc, ArrayRef{}); + func1.setPrivate(); + module.push_back(func1); + } + + rewriter.replaceOpWithNewOp(op, getTimeStr, SmallVector{f64Type}); + + return success(); + } +}; + +class PrintElapsedTimeLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + PrintElapsedTimeLowering(MLIRContext *context) : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(PrintElapsedTimeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + auto ctx = rewriter.getContext(); + auto module = op->getParentOfType(); + + auto start = adaptor.getStart(); + auto end = adaptor.getEnd(); + std::string printElapsedTimeStr = "printElapsedTime"; + auto f64Type = rewriter.getF64Type(); + + if (!hasFuncDeclaration(module, printElapsedTimeStr)) + { + auto printElapsedTimeFunc = FunctionType::get(ctx, {f64Type, f64Type}, {}); + /// func @printElapsedTime(f64, f64) -> () + func::FuncOp func1 = func::FuncOp::create(op->getLoc(), printElapsedTimeStr, + printElapsedTimeFunc, ArrayRef{}); + func1.setPrivate(); + module.push_back(func1); + } + + rewriter.replaceOpWithNewOp(op, printElapsedTimeStr, SmallVector{}, ValueRange{start, end}); + + return success(); + } +}; + +} + +void mlir::comet::populateSparseTensorConversionPatterns(MLIRContext *context, RewritePatternSet &patterns, TypeConverter &typeConverter) { + typeConverter.addConversion( + [](tensorAlgebra::SparseTensorType type, SmallVectorImpl &types) { + ArrayRef dim_sizes = type.getDims(); + ArrayRef format = type.getFormat(); + + auto context = 
type.getContext(); + Type index_type = IndexType::get(context); + bool is_known_size = true; + int known_size = 1; + for(unsigned i = 0; i < dim_sizes.size(); i++) { + types.push_back(index_type); //Dimension size + types.push_back(index_type); //Insert pos + switch((TensorFormatEnum)format[2 * i]) + { + case TensorFormatEnum::D: + { + if(dim_sizes[i] != ShapedType::kDynamic) { + known_size *= dim_sizes[i]; + } else { + is_known_size = false; + } + auto pos_type = mlir::RankedTensorType::get({1,}, index_type); + types.push_back(pos_type); //Pos tensor + types.push_back(index_type); //Pos size + break; + } + case TensorFormatEnum::CU: + case TensorFormatEnum::CN: + { + Type pos_type = mlir::RankedTensorType::get({ShapedType::kDynamic,}, + index_type); + Type crd_type = mlir::RankedTensorType::get({ShapedType::kDynamic,}, + index_type); + is_known_size = false; + + types.push_back(pos_type); //Pos tensor + types.push_back(index_type); //Pos size + types.push_back(crd_type); //Crd tensor + types.push_back(index_type); //Crd size + break; + } + case TensorFormatEnum::S: + { + Type crd_type = mlir::RankedTensorType::get({ShapedType::kDynamic,}, index_type); + types.push_back(crd_type); //Crd tensor + types.push_back(index_type); //Crd size + break; + } + default: { + assert(false && "Could not unpack unknown format to sparse tensor."); + } + } + } + Type element_type = type.getElementType(); + Type value_type = mlir::RankedTensorType::get({ShapedType::kDynamic,}, element_type); + types.push_back(value_type); //Value tensor + types.push_back(index_type); //Value size + return success(); + }); + + typeConverter.addConversion( + [](WorkspaceType type, SmallVectorImpl &types) { + Type element_type = type.getElementType(); + ArrayRef dim_sizes = type.getDims(); + auto context = type.getContext(); + types.push_back(RankedTensorType::get(dim_sizes, element_type)); // Workspace + types.push_back(IntegerType::get(context, 32)); // Mark Value + 
types.push_back(RankedTensorType::get(dim_sizes, IntegerType::get(context, 32))); // Mark array + types.push_back(IndexType::get(context)); // Crd Size + types.push_back(RankedTensorType::get({ShapedType::kDynamic,}, IndexType::get(context)));// Crd tensors + return success(); + }); + + typeConverter.addArgumentMaterialization( + [](OpBuilder &builder, SparseTensorType resultType, ValueRange inputs, + Location loc) -> Optional { + auto op = builder.create(loc, resultType, inputs); + return op->getResult(0); + }); + + typeConverter.addSourceMaterialization( + [](OpBuilder &builder, SparseTensorType resultType, ValueRange inputs, + Location loc) -> Optional { + auto op = builder.create(loc, resultType, inputs); + return op->getResult(0); + }); + + typeConverter.addArgumentMaterialization( + [](OpBuilder &builder, WorkspaceType resultType, ValueRange inputs, + Location loc) -> Optional { + auto op = builder.create(loc, resultType, inputs); + return op->getResult(0); + }); + + typeConverter.addSourceMaterialization( + [](OpBuilder &builder, WorkspaceType resultType, ValueRange inputs, + Location loc) -> Optional { + auto op = builder.create(loc, resultType, inputs); + return op->getResult(0); + }); + + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); +} + +struct SparseTensorConversionPass : comet::impl::SparseTensorConversionPassBase { + using SparseTensorConversionPassBase::SparseTensorConversionPassBase; + + SparseTensorConversionPass() = default; + SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default; + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + TypeConverter typeConverter; + ConversionTarget target(*ctx); + + // Everything in the TADialect must go + target.addIllegalDialect(); + + // The following operations and dialects may be introduced by the + // rewriting rules, and are therefore marked as legal. 
+ target.addLegalOp(); + target.addLegalDialect< + arith::ArithDialect, bufferization::BufferizationDialect, + tensor::TensorDialect, memref::MemRefDialect, scf::SCFDialect, + func::FuncDialect, index::IndexDialect + >(); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](tensorAlgebra::PrintOp op) { + return typeConverter.isLegal(op->getOperandTypes()); + }); + + typeConverter.addConversion([](Type type) { return type; }); + + + // Populate with rules and apply rewriting rules. + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, + target); + mlir::indexTree::populateIndexTreeTypeConversionPatterns(ctx, patterns, typeConverter, target); + mlir::comet::populateSparseTensorConversionPatterns(ctx, patterns, typeConverter); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +std::unique_ptr mlir::comet::createSparseTensorConversionPass() +{ + return std::make_unique(); +} diff --git a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp index 314dd8db..8fe25da2 100644 --- a/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp +++ b/lib/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.cpp @@ -25,10 +25,10 @@ //===----------------------------------------------------------------------===// #include "comet/Dialect/TensorAlgebra/IR/TADialect.h" -#include "comet/Dialect/TensorAlgebra/IR/TATypes.h" #include "comet/Dialect/Utils/Utils.h" #include "comet/Conversion/TensorAlgebraToSCF/TensorAlgebraToSCF.h" #include "comet/Dialect/TensorAlgebra/Passes.h" +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -43,7 +43,7 @@ using namespace mlir::bufferization; using 
namespace mlir::tensorAlgebra; // *********** For debug purpose *********// -//#define COMET_DEBUG_MODE +// #define COMET_DEBUG_MODE #include "comet/Utils/debug.h" #undef COMET_DEBUG_MODE // *********** For debug purpose *********// @@ -196,8 +196,9 @@ namespace { /// for dense comet_debug() << "Dense transpose\n"; - auto inputTensorLoadOp = cast(op->getOperand(0).getDefiningOp()); - auto inputMemref = inputTensorLoadOp.getMemref(); + auto inputTensor = op->getOperand(0); + // auto inputTensorLoadOp = cast(op->getOperand(0).getDefiningOp()); + // auto inputMemref = inputTensorLoadOp.getMemref(); for (auto u : op.getOperation()->getResult(0).getUsers()) { @@ -219,13 +220,10 @@ namespace } comet_vdump(lhs); - auto outputMemref = lhs.getDefiningOp()->getOperand(0); - rewriter.create(loc, inputMemref, outputMemref, llvm::ArrayRef(allPerms[1])); - Value res_value = rewriter.create(loc, outputMemref); - - op.replaceAllUsesWith(res_value); - rewriter.eraseOp(op); - + // auto outputMemref = lhs.getDefiningOp()->getOperand(0); + auto la_transpose = rewriter.create(loc, inputTensor, lhs, llvm::ArrayRef(allPerms[1])); + // Value res_value = rewriter.create(loc, outputMemref, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + rewriter.replaceOp(op, la_transpose.getResults()); return success(); } else @@ -700,6 +698,26 @@ namespace } }; /// ScalarOpsLowering +class ConvertSetOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TensorSetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + + auto opAdaptor = llvm::cast(adaptor); + Value lhs = opAdaptor.getLhs(); + Value rhs = opAdaptor.getRhs(); + rewriter.replaceUseIf(rhs, lhs, [&](OpOperand& use) { + auto user = use.getOwner(); + auto ancestor = op->getBlock()->findAncestorOpInBlock(*user); + return (ancestor && op->isBeforeInBlock(ancestor)); + }); + rewriter.eraseOp(op); + return success(); + } +}; + } /// end anonymous namespace. 
/// This is a partial lowering to linear algebra of the tensor algebra operations that are @@ -743,6 +761,12 @@ void LowerTensorAlgebraToSCFPass::runOnOperation() bufferization::BufferizationDialect>(); target.addLegalOp(); + target.addLegalDialect(); + target.addIllegalOp(); /// Now that the conversion target has been defined, we just need to provide /// the set of patterns that will lower the TA operations. @@ -751,7 +775,8 @@ void LowerTensorAlgebraToSCFPass::runOnOperation() patterns.insert(&getContext()); + ConstantOpLowering, + ConvertSetOp>(&getContext()); /// With the target and rewrite patterns defined, we can now attempt the /// conversion. The conversion will signal failure if any of our `illegal` /// operations were not converted successfully. diff --git a/lib/Dialect/IndexTree/CMakeLists.txt b/lib/Dialect/IndexTree/CMakeLists.txt index 3c0f114c..93c4badd 100644 --- a/lib/Dialect/IndexTree/CMakeLists.txt +++ b/lib/Dialect/IndexTree/CMakeLists.txt @@ -1,24 +1,23 @@ -add_llvm_library(COMETIndexTreeDialect +add_mlir_dialect_library(COMETIndexTreeDialect IR/IndexTreeDialect.cpp IR/IndexTree.cpp - Transforms/IterationDomain.cpp - Transforms/Tensor.cpp - Transforms/UnitExpression.cpp - Transforms/WorkspaceTransforms.cpp - Transforms/Fusion.cpp + Transforms/IterationDomainInference.cpp + Transforms/DomainConcretization.cpp + Transforms/SymbolicCompute.cpp + Transforms/WorkspaceTransforms.cpp + + # Transforms/Fusion.cpp ADDITIONAL_HEADER_DIRS ${COMET_MAIN_INCLUDE_DIR}/comet/Dialect/IndexTree - ) - -add_dependencies( - COMETIndexTreeDialect + DEPENDS COMETIndexTreeOpsIncGen + COMETIndexTreeTypesIncGen COMETIndexTreePassIncGen MLIRSupport - ) - -target_link_libraries(COMETIndexTreeDialect MLIRIR) + LINK_LIBS PUBLIC + MLIRIR +) diff --git a/lib/Dialect/IndexTree/IR/IndexTreeDialect.cpp b/lib/Dialect/IndexTree/IR/IndexTreeDialect.cpp index fa05336a..c82847ac 100644 --- a/lib/Dialect/IndexTree/IR/IndexTreeDialect.cpp +++ 
b/lib/Dialect/IndexTree/IR/IndexTreeDialect.cpp @@ -25,14 +25,17 @@ // //===----------------------------------------------------------------------===// #include -#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" - #include "mlir/IR/DialectImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" +#include "llvm/ADT/TypeSwitch.h" +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" +#include "comet/Dialect/TensorAlgebra/IR/TADialect.h" using namespace mlir; using namespace mlir::indexTree; +using namespace mlir::tensorAlgebra; //===----------------------------------------------------------------------===// // IndexTreeDialect @@ -40,26 +43,16 @@ using namespace mlir::indexTree; #include "comet/Dialect/IndexTree/IR/IndexTreeDialect.cpp.inc" -Type mlir::indexTree::IndexTreeDialect::parseType(DialectAsmParser &parser) const -{ - /// Parse the main keyword for the type. - StringRef keyword; - /// for "range" and "sptensor" type - if (parser.parseKeyword(&keyword)) - return Type(); +//===----------------------------------------------------------------------===// +// Tablegen Type Definitions +//===----------------------------------------------------------------------===// - parser.emitError(parser.getNameLoc(), - "unknown IndexTree type: " + keyword); - return Type(); -} +#define GET_TYPEDEF_CLASSES +#include "comet/Dialect/IndexTree/IR/IndexTreeTypes.cpp.inc" + +// Include the op interface definitions +#include "comet/Dialect/IndexTree/IR/IndexTreeOpInterfaces.cpp.inc" -/// Print an instance of a type registered to the index tree dialect. 
-/// No type definition yet -void mlir::indexTree::IndexTreeDialect::printType(mlir::Type type, - mlir::DialectAsmPrinter &printer) const -{ - return; -} #define GET_OP_CLASSES #include "comet/Dialect/IndexTree/IR/IndexTreeOps.cpp.inc" @@ -68,8 +61,16 @@ void mlir::indexTree::IndexTreeDialect::printType(mlir::Type type, /// the point of registration of types and operations for the dialect. void IndexTreeDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "comet/Dialect/IndexTree/IR/IndexTreeTypes.cpp.inc" + >(); + addOperations< #define GET_OP_LIST #include "comet/Dialect/IndexTree/IR/IndexTreeOps.cpp.inc" >(); -} \ No newline at end of file +} + +using namespace mlir; +using namespace mlir::indexTree; diff --git a/lib/Dialect/IndexTree/Transforms/DomainConcretization.cpp b/lib/Dialect/IndexTree/Transforms/DomainConcretization.cpp new file mode 100644 index 00000000..4f178ff1 --- /dev/null +++ b/lib/Dialect/IndexTree/Transforms/DomainConcretization.cpp @@ -0,0 +1,486 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/IndexedMap.h" + +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" +#include "comet/Dialect/TensorAlgebra/IR/TADialect.h" +#include "comet/Dialect/IndexTree/Passes.h" +#include "comet/Dialect/IndexTree/Patterns.h" + +using namespace mlir; +using namespace mlir::indexTree; +using namespace mlir::tensorAlgebra; + +namespace mlir { + namespace comet{ + #define GEN_PASS_DEF_INDEXTREEDOMAINCONCRETIZATION + #include "comet/Dialect/IndexTree/Passes.h.inc" + } +} + +struct ConcretizeTensorDomain : public OpRewritePattern { + ConcretizeTensorDomain(MLIRContext *context) + : 
OpRewritePattern(context, /*benefit=*/1) {} + + mlir::LogicalResult + liftAccessOp(mlir::Operation *dependent_op, + IndexTreeIndexToTensorOp access_op) const { + if(access_op->isBeforeInBlock(dependent_op)) + return success(); + + Value prev_access_value; + if((prev_access_value = access_op.getPrevDim())) + { + if(mlir::failed( + liftAccessOp( + dependent_op, + llvm::cast(prev_access_value.getDefiningOp()) + ) + )) + return failure(); + } + access_op->moveBefore(dependent_op); + return success(); + } + + mlir::LogicalResult + matchAndRewrite(IndexTreeTensorDomainOp domain_op, + mlir::PatternRewriter &rewriter) const override { + auto loc = domain_op->getLoc(); + auto context = rewriter.getContext(); + indexTree::DomainType domain_type = indexTree::DomainType::get(context); + uint32_t dim = domain_op.getDim(); + Value new_domain; + + Value tensor = domain_op.getTensor(); + SparseTensorConstructOp construct_op = tensor.getDefiningOp(); + if(construct_op) + { + //Domain comes from a sparse tensor (may still be dense) + int32_t rank = construct_op.getTensorRank(); + TensorFormatEnum format = construct_op.getDimensionFormats()[2 * dim].cast().getValue(); + + if(format == TensorFormatEnum::D) + { + Value max = construct_op.getOperand((8*rank) + 2 + dim); //TODO: Fix magic numbers + new_domain = rewriter.create(loc, domain_type, max, tensor, rewriter.getI32ArrayAttr({static_cast(dim)})); + } else + { + Value pos = construct_op.getOperand(4 * dim); + Value crd = construct_op.getOperand((4 * dim) + 1); + Value pos_size = construct_op.getOperand((4 * rank) + (4 * dim) + 1); + Value crd_size = construct_op.getOperand((4 * rank) + (4 * dim) + 2); + Value dim_size = construct_op.getOperand((8*rank) + 2 + dim); + Value parent = domain_op.getParent(); + if(!parent) + { + // Get associated index + IndexTreeIndicesOp index_op; + Operation* use = *(domain_op->user_begin()); + // TODO: Fix danger of infinite loop!!! 
+ while(!(index_op = llvm::dyn_cast(use))) + { + use = *(use->user_begin()); + } + assert(index_op); + + if(dim == 0) + { + parent = nullptr; + } else + { + // Infer parent index variable + for(Operation* use : index_op->getUsers()) + { + IndexTreeIndexToTensorOp access_op = llvm::dyn_cast(use); + if(!access_op || access_op.getTensor() != tensor || access_op.getDim() != dim) + continue; + + parent = access_op.getPrevDim(); + IndexTreeIndexToTensorOp prev_access_op = + llvm::cast(parent.getDefiningOp()); + if(mlir::failed(this->liftAccessOp(domain_op, prev_access_op))) + return failure(); + + break; + } + } + } + new_domain = rewriter.create( + loc, domain_type, tensor, domain_op.getDimAttr(), + TensorFormatEnumAttr::get(context, format), + pos, crd, pos_size, crd_size, dim_size, parent); + } + } else if(llvm::isa(tensor.getType())) { + auto index_type = rewriter.getIndexType(); + Value dim_size = rewriter.create(loc, index_type, tensor, rewriter.getI32IntegerAttr(dim)); + + Value parent = domain_op.getParent(); + if(!parent) + { + // Get associated index + IndexTreeIndicesOp index_op; + Operation* use = *(domain_op->user_begin()); + // TODO: Fix danger of infinite loop!!! 
+ while(!(index_op = llvm::dyn_cast(use))) + { + use = *(use->user_begin()); + } + assert(index_op); + + if(dim == 0) + { + parent = nullptr; + } else + { + // Infer parent index variable + for(Operation* use : index_op->getUsers()) + { + IndexTreeIndexToTensorOp access_op = llvm::dyn_cast(use); + if(!access_op || access_op.getTensor() != tensor || access_op.getDim() != dim) + continue; + + parent = access_op.getPrevDim(); + IndexTreeIndexToTensorOp prev_access_op = + llvm::cast(parent.getDefiningOp()); + if(mlir::failed(this->liftAccessOp(domain_op, prev_access_op))) + return failure(); + + break; + } + } + } + + new_domain = rewriter.create( + loc, + domain_type, + tensor, + dim_size, + rewriter.getUI32IntegerAttr(dim), + parent + ); + } else { + //Domain is dense + //TODO (alokvk2): Figure out if we need to take the root index variable or the allocation + //Right now I don't know how to get back to the root index variable. + auto tensor_type = llvm::cast(tensor.getType()); + auto max = tensor_type.getShape()[dim]; + Value max_val; + if(max < 0) { + auto prev = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(tensor.getDefiningOp()); + Value dim_val = rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(dim)); + max_val = rewriter.create(loc, rewriter.getIndexType(), tensor, dim_val); + rewriter.restoreInsertionPoint(prev); + } else { + max_val = rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(max)); + } + new_domain = rewriter.create(loc, domain_type, max_val, tensor, rewriter.getI32ArrayAttr({static_cast(dim)})); + } + rewriter.replaceOp(domain_op, new_domain); + return success(); + } +}; + + +struct SimplifyIntersectionOp : public OpRewritePattern { + SimplifyIntersectionOp(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + mlir::LogicalResult + matchAndRewrite(IndexTreeDomainIntersectionOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current 
transpose. + Value first_domain = op->getOperand(0); + SmallVector domains; + SmallVector to_remove; + SmallVector tensors; + SmallVector dims; + SmallVector maximums; + IndexTreeDenseDomainOp operand_op; + for(auto operand : op.getDomains()) + { + if(!llvm::isa(operand.getDefiningOp())){ + return failure(); + } + + if((operand_op = operand.getDefiningOp())) + { + to_remove.push_back(operand_op.getOperation()); + auto operand_tensors = operand_op.getTensors(); + tensors.insert(tensors.end(), operand_tensors.begin(), operand_tensors.end()); + auto tensor_dims = operand_op.getDimsAttr(); + dims.insert(dims.end(), tensor_dims.begin(), tensor_dims.end()); + maximums.push_back(operand_op.getDimSize()); + } + else + { + domains.push_back(operand); + } + } + + if(domains.size() == 0) // All domains are dense + { + if(to_remove.size() > 1){ + auto loc = op->getLoc(); + auto context = rewriter.getContext(); + auto index_type = rewriter.getIndexType(); + Value max = maximums[0]; + // TODO: Do we need this to check if the domains are compatible? 
+ // for(Value new_max : maximums){ + // max = rewriter.create(loc, index_type, max, new_max); + // } + indexTree::DomainType domain_type = indexTree::DomainType::get(context); + Value new_domain = rewriter.create(loc, domain_type, max, tensors, rewriter.getArrayAttr(dims)); + rewriter.replaceOp(op, {new_domain}); + } else { + rewriter.replaceOp(op, {first_domain}); + } + } else if(domains.size() == 1) + { + // Remove intersection op completely + rewriter.replaceOp(op, {domains[0]}); + } else + { + // Keep only non-dense operands + if(domains.size() == op.getDomains().size() && op.getDimSize() != nullptr){ + return failure(); + } + auto loc = op->getLoc(); + auto context = rewriter.getContext(); + indexTree::DomainType domain_type = indexTree::DomainType::get(context); + Value dim_size = llvm::dyn_cast(domains[0].getDefiningOp()).getDimensionSize(); + Value new_domain = rewriter.create(loc, domain_type, domains, dim_size); + rewriter.replaceOp(op, {new_domain}); + } + + // Delete newly unused values + for(Operation* unused_op : to_remove) + { + if(unused_op->use_empty()) + rewriter.eraseOp(unused_op); + } + + return success(); + } +}; + +struct SimplifyUnionOp : public mlir::OpRewritePattern { + SimplifyUnionOp(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + mlir::LogicalResult + matchAndRewrite(IndexTreeDomainUnionOp op, + mlir::PatternRewriter &rewriter) const override { + bool can_replace = false; + Operation* operand_op; + for(auto operand : op.getDomains()) + { + if(!llvm::isa(operand.getDefiningOp())){ + return failure(); + } + + if((operand_op = operand.getDefiningOp())){ + rewriter.replaceOp(op, {operand}); + can_replace = true; + break; + } + } + + if(can_replace) + { + for(auto operand : op.getDomains()) + { + operand_op = operand.getDefiningOp(); + if(operand_op->use_empty()) + rewriter.eraseOp(operand_op); + } + } else { + if(op.getDimSize() != nullptr){ + return failure(); + } + + Value dim_size = 
llvm::dyn_cast(op.getDomains()[0].getDefiningOp()).getDimensionSize(); + auto loc = op->getLoc(); + auto context = rewriter.getContext(); + indexTree::DomainType domain_type = indexTree::DomainType::get(context); + Value new_op = rewriter.create(loc, domain_type, op.getDomains(), dim_size); + rewriter.replaceOp(op, {new_op}); + } + return success(); + } +}; + +struct InferOutputDomains : public OpRewritePattern { + InferOutputDomains(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + Value copyDomain(Value domain, + mlir::PatternRewriter &rewriter, + IRMapping& map, + Location loc, + llvm::SmallDenseMap& index_vars) const + { + Value new_domain; + Operation* domain_op = domain.getDefiningOp(); + if(llvm::isa(domain_op)) + { + auto intersection_domain_op = llvm::cast(domain_op); + for(Value subdomain : intersection_domain_op.getDomains()){ + copyDomain(subdomain, rewriter, map, loc, index_vars); + } + } + if(llvm::isa(domain_op)) + { + auto union_domain_op = llvm::cast(domain_op); + for(Value subdomain : union_domain_op.getDomains()){ + copyDomain(subdomain, rewriter, map, loc, index_vars); + } + } + + if(llvm::isa(domain_op)) + { + auto sparse_domain_op = llvm::cast(domain_op); + + // Ensure parent domain will also be copied. 
Otherwise create it + Value new_parent_domain = nullptr; + if(sparse_domain_op.getParent()) + { + auto index_to_tensor_op = sparse_domain_op.getParent().getDefiningOp(); + auto index_var = index_to_tensor_op.getIndex().getDefiningOp(); + if(index_vars.find(index_var) == index_vars.end()){ + new_parent_domain = copyDomain(index_var.getDomain(), rewriter, map, loc, index_vars); + index_vars.insert(std::make_pair(index_var.getResult(), index_var.getDomain())); + } + } + + // Clone without parent + new_domain = rewriter.create(loc, + domain_op->getResultTypes(), + sparse_domain_op.getTensor(), + sparse_domain_op.getDimAttr(), + sparse_domain_op.getFormatAttr(), + sparse_domain_op.getPos(), + sparse_domain_op.getCrd(), + sparse_domain_op.getPosSize(), + sparse_domain_op.getCrdSize(), + sparse_domain_op.getDimSize(), + nullptr); + map.map(sparse_domain_op, new_domain); + + if(new_parent_domain) + { + // Create or fold so multiple levels of nested domains are foleded into one + new_domain = rewriter.createOrFold(loc, + domain_op->getResultTypes(), + llvm::SmallVector{new_parent_domain, new_domain}, + sparse_domain_op.getDimSize()); + } + } else { + // Clone + new_domain = rewriter.clone(*domain_op, map)->getResult(0); + } + + auto new_domain_op = new_domain.getDefiningOp(); + for(auto arg : new_domain_op->getOperands()) + { + Operation* origin = arg.getDefiningOp(); + if(new_domain_op->getBlock() == origin->getBlock() && new_domain_op->isBeforeInBlock(origin)) + { + rewriter.updateRootInPlace(new_domain_op, [&]() { new_domain_op->moveAfter(origin); }); + rewriter.setInsertionPointAfter(new_domain_op); + } + } + return new_domain; + } + + mlir::LogicalResult + matchAndRewrite(IndexTreeSparseTensorOp op, + mlir::PatternRewriter &rewriter) const override { + for(auto domain : op.getDomains()) + { + if(!llvm::isa(domain.getDefiningOp())) + { + return failure(); + } + } + + // Get the LHSOperandOp which creates thise tensor + Value tensor = op->getResult(0); + 
IndexTreeLHSOperandOp lhs_op = nullptr; + for(Operation* op : tensor.getUsers()) + { + if(llvm::isa(op)) + { + lhs_op = llvm::cast(op); + } + } + + if(lhs_op == nullptr) + return failure(); + + + auto crds = lhs_op.getCrds(); + unsigned dims = (lhs_op.getNumOperands() - 1) / 2; + Value empty_domain = lhs_op.getOperand(0); + llvm::IndexedMap domains(empty_domain); + llvm::SmallDenseMap index_vars; + domains.resize(dims); + for(Value crd : crds){ + auto access_op = llvm::dyn_cast(crd.getDefiningOp()); + if(access_op == nullptr){ + return failure(); + } + auto index_op = llvm::dyn_cast(access_op.getIndex().getDefiningOp()); + if(index_op == nullptr){ + return failure(); + } + + Value domain = index_op.getDomain(); + index_vars.insert(std::make_pair(index_op.getResult(), domain)); + domains[access_op.getDim()] = domain; + } + + // Successfully matched! Cannot fail after this point. + auto loc = op->getLoc(); + auto context = rewriter.getContext(); + SmallVector new_args; + IRMapping map; + for(unsigned dim = 0; dim < dims; dim++){ + Value domain_copy = copyDomain(domains[dim], rewriter, map, loc, index_vars); + new_args.push_back(domain_copy); + } + auto new_tensor = rewriter.create(loc, op->getResult(0).getType(), new_args); + rewriter.replaceOp(op, new_tensor->getResults()); + return success(); + } +}; + +void indexTree::populateDomainConcretizationPatterns( + MLIRContext *context, RewritePatternSet &patterns) { + patterns.add(context); +} + +struct IndexTreeDomainConcretization : comet::impl::IndexTreeDomainConcretizationBase { + using IndexTreeDomainConcretizationBase::IndexTreeDomainConcretizationBase; + + void runOnOperation() { + mlir::RewritePatternSet domain_concretization_patterns(&getContext()); + indexTree::populateDomainConcretizationPatterns(&getContext(), domain_concretization_patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(domain_concretization_patterns)); + } +}; + +/// Apply the compressed workspace transformations on the 
index tree IR +std::unique_ptr mlir::comet::createIndexTreeDomainConcretizationPass() +{ + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Dialect/IndexTree/Transforms/IterationDomain.cpp b/lib/Dialect/IndexTree/Transforms/IterationDomain.cpp deleted file mode 100644 index 949ea52b..00000000 --- a/lib/Dialect/IndexTree/Transforms/IterationDomain.cpp +++ /dev/null @@ -1,190 +0,0 @@ -// -// Copyright 2022 Battelle Memorial Institute -// -// Redistribution and use in source and binary forms, with or without modification, -// are permitted provided that the following conditions are met: -// -// 1. Redistributions of source code must retain the above copyright notice, this list of conditions -// and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions -// and the following disclaimer in the documentation and/or other materials provided with the distribution. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED -// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE -// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-// - -#include "comet/Dialect/IndexTree/Transforms/IterationDomain.h" -#include "comet/Dialect/IndexTree/Transforms/Tensor.h" - -using namespace std; -// Since multiple threads can lower different functions, -// we need one for each thread lowering. -thread_local std::vector> domains; - -IterDomain *IterDomain::makeDomain(Tensor *tensor, int dim) -{ - auto d = make_unique(tensor, dim); - auto p = d.get(); - domains.push_back(std::move(d)); - return p; -} - -IterDomain *IterDomain::conjunct(IterDomain *a, IterDomain *b) -{ - auto d = make_unique('*', a, b); - auto p = d.get(); - domains.push_back(std::move(d)); - return p; -} - -IterDomain *IterDomain::conjunct(std::vector &domains) -{ - assert(!domains.empty()); - auto d = domains[0]; - for (unsigned long i = 1; i < domains.size(); i++) - { - d = conjunct(d, domains[i]); - } - return d; -} - -bool IterDomain::equals(IterDomain *that) -{ - auto thisSimplified = this->getSimplified(); - auto thatSimplified = that->getSimplified(); - return thisSimplified == thatSimplified; -} - -IterDomain *IterDomain::getSimplified() -{ - if (getOp() == '*') - { - if (getLeft()->isDense()) - { - return getRight(); - } - else if (getRight()->isDense()) - { - return getLeft(); - } - } - - return this; -} - -std::string IterDomain::str() -{ - if (isLeafNode()) - { - string s = "(" + getTensor()->str() + "," + to_string(getDim()) + ")"; - return s; - } - else - { - assert(getLeft() != nullptr && getRight() != nullptr); - string s = getLeft()->str() + string(1, getOp()) + getRight()->str(); - return s; - } -} - -std::string IterDomain::getFormat() -{ - return getTensor()->getFormat(getDim()); -} - -bool IterDomain::isDense() -{ - return getFormat() == "D"; -} - -std::string BoolExpr::str() -{ - if (isTrue()) - { - return "true"; - } - - auto t = value.first; - assert(t != nullptr); - auto dim = value.second; - return "(" + t->str() + ", " + to_string(dim) + ")"; -} - -void BoolExpr::setTrue() -{ - isConstantTrue = true; -} - -void 
BoolExpr::setFalse() -{ - isConstantFalse = true; -} - -unique_ptr BoolExprManager::trueNode; - -std::vector> BoolExprManager::exprs; - -BoolExpr *BoolExprManager::getTrue() -{ - if (trueNode == nullptr) - { - trueNode = make_unique(); - trueNode->setTrue(); - } - - return trueNode.get(); -} - -BoolExpr *BoolExprManager::makeBoolExpr(BDDElem elem) -{ - auto expr = make_unique(elem); - BoolExpr *ret = expr.get(); - exprs.push_back(std::move(expr)); - return ret; -} -Tensor *IterDomain::getTensor() const -{ - return tensor; -} -void IterDomain::setTensor(Tensor *Tensor) -{ - tensor = Tensor; -} -int IterDomain::getDim() const -{ - return dim; -} -void IterDomain::setDim(int Dim) -{ - dim = Dim; -} -char IterDomain::getOp() const -{ - return op; -} -void IterDomain::setOp(char Op) -{ - op = Op; -} -IterDomain *IterDomain::getLeft() const -{ - return left; -} -void IterDomain::setLeft(IterDomain *Left) -{ - left = Left; -} -IterDomain *IterDomain::getRight() const -{ - return right; -} -void IterDomain::setRight(IterDomain *Right) -{ - right = Right; -} diff --git a/lib/Dialect/IndexTree/Transforms/IterationDomainInference.cpp b/lib/Dialect/IndexTree/Transforms/IterationDomainInference.cpp new file mode 100644 index 00000000..6b55909b --- /dev/null +++ b/lib/Dialect/IndexTree/Transforms/IterationDomainInference.cpp @@ -0,0 +1,130 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/ADT/StringSet.h" + +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" +#include "comet/Dialect/TensorAlgebra/IR/TADialect.h" +#include "comet/Dialect/IndexTree/Passes.h" +#include "comet/Dialect/IndexTree/Patterns.h" + +using namespace mlir; +using namespace mlir::indexTree; + +namespace mlir { + namespace comet{ + #define GEN_PASS_DEF_INDEXTREEDOMAININFERENCE + #include "comet/Dialect/IndexTree/Passes.h.inc" + } +} + +struct IndexTreeDomainInference : 
comet::impl::IndexTreeDomainInferenceBase { + using IndexTreeDomainInferenceBase::IndexTreeDomainInferenceBase; + void runOnOperation() override; +}; + +struct InferIndexDomain : public OpRewritePattern { + InferIndexDomain(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + mlir::LogicalResult + matchAndRewrite(IndexTreeIndicesOp op, mlir::PatternRewriter &builder) const override { + if(op.getDomain()) + return failure(); + + Location loc = op.getLoc(); + auto context = builder.getContext(); + indexTree::DomainType domain_type = indexTree::DomainType::get(context); + + // Map operands to domains + llvm::SmallDenseMap operands_to_domains; + // Set of all compute operands + llvm::SmallPtrSet compute_ops; + for(Operation* tensor_access_op : op->getUsers()) + { + if(!llvm::isa(tensor_access_op)) + continue; + + for(Operation* operand_op : tensor_access_op->getUsers()) + { + if(!llvm::isa(operand_op)) + continue; + + auto tensor_val = llvm::cast(tensor_access_op).getTensor(); + unsigned dim = llvm::cast(tensor_access_op).getDim();; + Value domain = builder.create(loc, + domain_type, + tensor_val, + builder.getUI32IntegerAttr(dim), + tensorAlgebra::TensorFormatEnumAttr::get(context, tensorAlgebra::TensorFormatEnum::UNK), + nullptr); + + operands_to_domains.insert(std::pair(operand_op, domain)); + compute_ops.insert(operand_op->user_begin(), operand_op->user_end()); + break; + } + } + + llvm::SmallVector domains; + Value zero = builder.create(loc, builder.getIndexType(), builder.getIndexAttr(0)); + for(auto compute_op : compute_ops) + { + // Check if compute op needs intersection + auto itComputeOp = cast(compute_op); + auto semiringParts = itComputeOp.getSemiring().split('_'); + + if(!Semiring_intersectOps.contains(semiringParts.second)){ + for(auto operand_op_val : itComputeOp.getRhs()) + { + auto operand_op = operand_op_val.getDefiningOp(); + if(operands_to_domains.find(operand_op) != operands_to_domains.end()) + 
domains.push_back(operands_to_domains[operand_op]); + } + } else { + SmallVector intersection_domains; + for(auto operand_op_val : itComputeOp.getRhs()) + { + auto operand_op = operand_op_val.getDefiningOp(); + if(operands_to_domains.find(operand_op) != operands_to_domains.end()) + intersection_domains.push_back(operands_to_domains[operand_op]); + } + if(intersection_domains.size() > 1) + domains.push_back(builder.create( + loc, domain_type, intersection_domains, nullptr)); + else + domains.push_back(intersection_domains[0]); + } + } + + Value final_domain; + if(domains.size() > 1) { + final_domain = builder.create(loc, + domain_type, domains, nullptr); + } else { + final_domain = domains[0]; + } + + indexTree::IndexNodeType index_node_type = indexTree::IndexNodeType::get(context); + builder.replaceOpWithNewOp( + op, index_node_type, op.getParent(), final_domain); + return success(); + } +}; + +void mlir::indexTree::populateDomainInferencePatterns( + MLIRContext *context, RewritePatternSet &patterns) { + patterns.add(context); +} + +void IndexTreeDomainInference::runOnOperation(){ + mlir::RewritePatternSet domain_inference_patterns(&getContext()); + populateDomainInferencePatterns(&getContext(), domain_inference_patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(domain_inference_patterns)); +} + +/// Apply the compressed workspace transformations on the index tree IR +std::unique_ptr mlir::comet::createIndexTreeDomainInferencePass() +{ + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Dialect/IndexTree/Transforms/SymbolicCompute.cpp b/lib/Dialect/IndexTree/Transforms/SymbolicCompute.cpp new file mode 100644 index 00000000..750aa73a --- /dev/null +++ b/lib/Dialect/IndexTree/Transforms/SymbolicCompute.cpp @@ -0,0 +1,298 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include 
"mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/IndexedMap.h" + +#include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" +#include "comet/Dialect/TensorAlgebra/IR/TADialect.h" +#include "comet/Dialect/IndexTree/Passes.h" + +using namespace mlir; +using namespace mlir::indexTree; +using namespace mlir::tensorAlgebra; + +namespace mlir { + namespace comet{ + #define GEN_PASS_DEF_INDEXTREESYMBOLICCOMPUTEPASS + #include "comet/Dialect/IndexTree/Passes.h.inc" + } +} + +struct CreateSymbolicTree : public OpRewritePattern { + CreateSymbolicTree(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + Value copyDomain(Value domain, mlir::PatternRewriter &rewriter, + Location loc, IRMapping& map, + llvm::SmallDenseMap, Value>& tensor_to_node, + Value parent_node = nullptr) const + { + Value new_domain; + Operation* domain_op = domain.getDefiningOp(); + if(llvm::isa(domain_op)) + { + auto intersection_domain_op = llvm::cast(domain_op); + for(Value subdomain : intersection_domain_op.getDomains()){ + copyDomain(subdomain, rewriter, loc, map, tensor_to_node, parent_node); + } + } + else if(llvm::isa(domain_op)) + { + auto union_domain_op = llvm::cast(domain_op); + for(Value subdomain : union_domain_op.getDomains()){ + copyDomain(subdomain, rewriter, loc, map, tensor_to_node, parent_node); + } + } + + if(llvm::isa(domain_op)) + { + auto sparse_domain_op = llvm::cast(domain_op); + auto tensor = sparse_domain_op.getTensor(); + int32_t dim = sparse_domain_op.getDim(); + Value parent = nullptr; + if(dim > 0){ + // TODO: Determine parent of this op + // Will be needed for 3 dimensional sparse tensor outputs + if(!parent_node) + { + assert(tensor_to_node[std::make_pair(tensor, dim-1)] != nullptr); + parent_node = tensor_to_node[std::make_pair(tensor, dim-1)]; + } + + + auto tensor_access_op = rewriter.create( + loc, + 
TypeRange({rewriter.getIndexType(), rewriter.getIndexType()}), + tensor, + parent_node, + dim-1, + nullptr); + parent = tensor_access_op.getCrd(); + } + new_domain = rewriter.create(loc, + domain_op->getResultTypes(), + sparse_domain_op.getTensor(), + sparse_domain_op.getDimAttr(), + sparse_domain_op.getFormatAttr(), + sparse_domain_op.getPos(), + sparse_domain_op.getCrd(), + sparse_domain_op.getPosSize(), + sparse_domain_op.getCrdSize(), + sparse_domain_op.getDimSize(), + parent); + map.map(sparse_domain_op, new_domain); + } else { + // Clone + new_domain = rewriter.clone(*domain_op, map)->getResult(0); + } + return new_domain; + } + + void createMapping(IndexTreeIndicesOp node, Value domain, llvm::SmallDenseMap, Value>& tensor_to_node) const + { + Operation* domain_op = domain.getDefiningOp(); + if(llvm::isa(domain_op)) + { + auto intersection_domain_op = llvm::cast(domain_op); + for(Value subdomain : intersection_domain_op.getDomains()){ + createMapping(node, subdomain, tensor_to_node); + } + } else if(llvm::isa(domain_op)) + { + auto union_domain_op = llvm::cast(domain_op); + for(Value subdomain : union_domain_op.getDomains()){ + createMapping(node, subdomain, tensor_to_node); + } + } else if(llvm::isa(domain_op)) + { + auto sparse_domain_op = llvm::cast(domain_op); + auto tensor = sparse_domain_op.getTensor(); + int32_t dim = sparse_domain_op.getDim(); + tensor_to_node.insert(std::make_pair( + std::make_pair(tensor, dim), + node.getOutput() + )); + } else if(llvm::isa(domain_op)) + { + auto dense_domain_op = llvm::cast(domain_op); + auto tensors = dense_domain_op.getTensors(); + auto dims = dense_domain_op.getDims(); + unsigned i = 0; + for(auto tensor : tensors) + { + int32_t dim = dims[i].cast().getValue().getSExtValue(); + tensor_to_node.insert(std::make_pair( + std::make_pair(tensor, dim), + node.getOutput() + )); + i += 1; + } + } + } + + mlir::LogicalResult + match(IndexTreeSparseTensorOp op) const override { + for(auto domain : op.getDomains()) + { + 
Operation* domain_op = domain.getDefiningOp(); + if(domain_op->hasTrait()) + { + return success(); + } + } + return failure(); + } + + void + rewrite(IndexTreeSparseTensorOp it_tensor_decl_op, + mlir::PatternRewriter &rewriter) const override { + auto loc = it_tensor_decl_op->getLoc(); + auto context = rewriter.getContext(); + + // Declare Sparse Domains and allocate position vectors for each dimension + llvm::SmallDenseMap symbolic_domains; + auto domain_type = SymbolicDomainType::get(context); + auto index_type = rewriter.getIndexType(); + Value cur_pos_size = nullptr; + unsigned dim = 0; + for(Value domain : it_tensor_decl_op.getDomains()){ + Operation* domain_op = domain.getDefiningOp(); + + if(domain_op->hasTrait()) + { + if(dim == 0) + { + // If this is the first dimension, the previous dimension could be expanded as "1" + cur_pos_size = rewriter.create(loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); + } + + Value num_rows = nullptr; + BoolAttr is_dynamic = rewriter.getBoolAttr(0); + if(cur_pos_size == nullptr){ + is_dynamic = rewriter.getBoolAttr(1); + num_rows = rewriter.create(loc, index_type, rewriter.getIndexAttr(0)); + } else { + num_rows = cur_pos_size; + } + auto concrete_domain = llvm::cast(domain_op); + Value dim_size = concrete_domain.getDimensionSize(); + Value symbolic_domain = rewriter.create(loc, domain_type, dim_size, num_rows, is_dynamic); + symbolic_domains.insert(std::make_pair(domain, symbolic_domain)); + } + + if(llvm::isa(domain_op)) + { + Value dim_size = llvm::cast(domain_op).getDimSize(); + if(dim != 0) + cur_pos_size = rewriter.create(loc, rewriter.getI32Type(), cur_pos_size, dim_size); + else + cur_pos_size = dim_size; + } else + { + cur_pos_size = nullptr; + } + dim += 1; + } + + auto itree_op = rewriter.create(loc, llvm::SmallVector(symbolic_domains.size(), domain_type)); + Region* body = &itree_op.getRegion(); + loc = body->getLoc(); + Block* block = rewriter.createBlock(body); + 
rewriter.setInsertionPointToStart(block); + + indexTree::IndexTreeType tree_type = indexTree::IndexTreeType::get(context); + Value parent = rewriter.create(loc, tree_type); + indexTree::IndexNodeType index_node_type = indexTree::IndexNodeType::get(context); + IRMapping map; + llvm::SmallDenseMap, Value> tensor_to_node; + SmallVector yield_args; + Value prev_dim = parent; + bool is_unique = true; + for (Value domain : it_tensor_decl_op.getDomains()) + { + Operation* domain_op = domain.getDefiningOp(); + if(llvm::isa(domain_op)){ + Value parent_node = nullptr; // By construction, nested domains are direct parents of each other? + auto nested_domain = llvm::cast(domain_op); + for(Value subdomain : nested_domain.getDomains()){ + Value new_domain = copyDomain(subdomain, rewriter, loc, map, tensor_to_node, parent_node); + indexTree::IndexTreeIndicesOp index_node_op = rewriter.create(loc, index_node_type, parent, new_domain); + createMapping(index_node_op, subdomain, tensor_to_node); + parent = index_node_op.getOutput(); + parent_node = index_node_op.getOutput(); + } + is_unique = false; // One reduction index variable means all future inserts are non-unique + } else + { + Value new_domain = copyDomain(domain, rewriter, loc, map, tensor_to_node); + indexTree::IndexTreeIndicesOp index_node_op = rewriter.create(loc, index_node_type, parent, new_domain); + createMapping(index_node_op, domain, tensor_to_node); + parent = index_node_op.getOutput(); + } + + if(domain_op->hasTrait()) + { + Value symbolic_domain = symbolic_domains[domain]; + Value new_symbolic_domain = rewriter.create( + loc, + domain_type, + parent, + symbolic_domain, + rewriter.getBoolAttr(is_unique) + ); + new_symbolic_domain = rewriter.create( + loc, + domain_type, + prev_dim, + new_symbolic_domain, + rewriter.getBoolAttr(!is_unique) + ); + yield_args.push_back(new_symbolic_domain); + } + + prev_dim = parent; + } + rewriter.create(loc, TypeRange(), yield_args); + + rewriter.setInsertionPointAfter(itree_op); + 
SmallVector args; + unsigned i = 0; + for (Value domain : it_tensor_decl_op.getDomains()) + { + if(domain.getDefiningOp()->hasTrait()) + { + args.push_back(itree_op->getResult(i)); + i += 1; + } else + { + args.push_back(domain); + } + } + auto new_tensor = rewriter.create(loc, it_tensor_decl_op->getResultTypes(), args); + rewriter.replaceOp(it_tensor_decl_op, new_tensor->getResults()); + return; + } +}; + +struct IndexTreeSymbolicComputePass : comet::impl::IndexTreeSymbolicComputePassBase { + using IndexTreeSymbolicComputePassBase::IndexTreeSymbolicComputePassBase; + + void runOnOperation() { + mlir::RewritePatternSet sp_output_patterns(&getContext()); + sp_output_patterns.add(&getContext()); + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(sp_output_patterns)); + } +}; + +/// Apply the compressed workspace transformations on the index tree IR +std::unique_ptr mlir::comet::createIndexTreeSymbolicComputePass() +{ + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Dialect/IndexTree/Transforms/WorkspaceTransforms.cpp b/lib/Dialect/IndexTree/Transforms/WorkspaceTransforms.cpp index bde85887..cc6d6a0a 100644 --- a/lib/Dialect/IndexTree/Transforms/WorkspaceTransforms.cpp +++ b/lib/Dialect/IndexTree/Transforms/WorkspaceTransforms.cpp @@ -26,6 +26,7 @@ #include "comet/Dialect/IndexTree/IR/IndexTreeDialect.h" #include "comet/Dialect/IndexTree/Passes.h" +#include "comet/Dialect/IndexTree/Patterns.h" #include "comet/Dialect/TensorAlgebra/IR/TADialect.h" #include "comet/Dialect/Utils/Utils.h" @@ -33,6 +34,8 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + #include "llvm/Support/Debug.h" #include @@ -46,6 +49,7 @@ #include #include #include +#include using namespace mlir; using namespace mlir::bufferization; @@ -65,14 +69,6 @@ using llvm::StringRef; // *********** For debug purpose 
*********// -const bool compressedworkspace = true; - -struct dimInTensor -{ - int dim; - int tensorId; - int dimOrder; -}; /// Apply workspace transformation on the lhs /// Consider CSR first @@ -90,1003 +86,265 @@ struct dimInTensor /// Apply workspace transformations on the ta.tc and tc.elews_mul namespace { - struct WorkspaceTransformsPass - : public PassWrapper> - { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WorkspaceTransformsPass) - void runOnOperation() override; - void WorkspaceTransforms(mlir::func::FuncOp function); - }; - struct IndexTreeWorkspaceTransformationsPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IndexTreeWorkspaceTransformationsPass) void runOnOperation() override; - void CompressedWorkspaceTransforms(mlir::func::FuncOp function); }; - } /// end anonymous namespace. -/// Need a function, dfs traverse the itree -/// get the sparse index that is sparse in the output -std::vector getSparseDimsOutput(std::vector> opFormats, std::vector> opPerms) -{ - std::vector sparseDimsOutput; - assert(opFormats.size() > 0 && "opFormats.size() less than 0\n"); - std::vector outputFormat = opFormats[opFormats.size() - 1]; - std::vector outputPerm = opPerms[opPerms.size() - 1]; - for (unsigned int i = 0; i < outputFormat.size(); i++) - { - if (outputFormat[i].compare("D") != 0) - { /// sparse dim - sparseDimsOutput.push_back(outputPerm[i]); - comet_debug() << "sparse dim in output: " << outputPerm[i] << " with format: " << outputFormat[i] << "\n"; +struct TransformSparseOutput : public OpRewritePattern { + TransformSparseOutput(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0.5) {} + + mlir::LogicalResult + matchAndRewrite(IndexTreeComputeOp compute_op, mlir::PatternRewriter &rewriter) const override { + IndexTreeLHSOperandOp lhs_op = compute_op.getLhs().getDefiningOp(); + Value old_output = lhs_op.getTensor(); + + // Check to see if output is sparse + if(!llvm::isa(old_output.getType())) + return failure(); + + 
// Check to see if there are "redundant" inserts + llvm::SmallDenseMap index_vars; + for(auto pos : lhs_op.getPos()) { + auto index_to_tensor = pos.getDefiningOp(); + if(index_to_tensor){ + index_vars.insert(std::make_pair( + index_to_tensor.getIndex(), + index_to_tensor + )); + } } - } - return sparseDimsOutput; -} - -/// get the sparse index that has sparse format in at least two input tensors -/// which tensor, which dimension. use std::pair represent the information -std::vector getSparseDimsInput(std::vector> opFormats, std::vector> opPerms) -{ - std::vector sparseDimsInput; - - std::vector> inputFormats = {opFormats.begin(), opFormats.end() - 1}; - std::vector> inputPerms = {opPerms.begin(), opPerms.end() - 1}; - - /// Get all dims in input tensors - std::vector allPermsInput = getUnionOf2Dvector(inputPerms); - comet_debug() << " allPermsInput.size(): " << allPermsInput.size() << "\n"; - comet_debug() << "allPermsInput: "; - for (auto n : allPermsInput) - { - comet_debug() << n << " "; - } - comet_debug() << "\n"; - for (unsigned int i = 0; i < allPermsInput.size(); i++) - { - int cur_index = allPermsInput[i]; - comet_debug() << " cur_index: " << cur_index << "\n"; - /// Get the format of cur_index from each input tensor - std::vector cur_formats; - std::vector tensor_ids; - std::vector dim_orders; - for (unsigned int j = 0; j < inputPerms.size(); j++) - { - unsigned int whichFormat = findIndexInVector(inputPerms[j], cur_index); - if (whichFormat < inputPerms[j].size()) - { /// found - std::string format = inputFormats[j][whichFormat]; - cur_formats.push_back(format); - tensor_ids.push_back(j); - dim_orders.push_back(whichFormat); - } + // Find last output dimension + Value parent = compute_op.getParent(); + auto node = parent.getDefiningOp(); + while(index_vars.find(parent) == index_vars.end()) { + parent = node.getParent(); + node = parent.getDefiningOp(); } - comet_debug() << " cur_formats.size(): " << cur_formats.size() << "\n"; - comet_debug() << 
"cur_formats: "; - for (auto n : cur_formats) - { - comet_debug() << n << " "; + // node contains output domain + + // Find output dimensions after reduction variable + // to include in workspace + unsigned workspace_rank = 0; + llvm::SmallVector accesses; + llvm::SmallVector dims; + while(index_vars.find(parent) != index_vars.end()){ + auto access_op = index_vars[parent]; + accesses.push_back(access_op); + dims.push_back(access_op.getDim()); + workspace_rank++; + + parent = node.getParent(); + node = parent.getDefiningOp(); + if(!node){ + return failure(); + } } - comet_debug() << "\n"; - - /// check if there is sparse format in cur_formats vector - std::vector cur_sparse_formats; - std::vector sparse_tensor_ids; - std::vector sparse_dim_orders; - for (unsigned int j = 0; j < cur_formats.size(); j++) + //Match success! + // Parent contains reduction variable + + // Declare the workspace outside of the tree + auto loc = compute_op.getLoc(); + auto tree_op = compute_op->getParentOfType(); + rewriter.setInsertionPoint(tree_op); + Type element_type = llvm::cast(old_output.getType()).getElementType(); + llvm::SmallVector dim_sizes(workspace_rank, ShapedType::kDynamic); + Type workspace_type = WorkspaceType::get(compute_op.getContext(), element_type, dim_sizes); + std::reverse(dims.begin(), dims.end()); + Value workspace = rewriter.create(loc, workspace_type, old_output, rewriter.getI32ArrayAttr(dims)); + + + // Clean the workspace before use + rewriter.setInsertionPoint(node); + Value clean_workspace = rewriter.create(loc, workspace_type, node.getParent(), workspace); + + // Create new compute op + auto context = getContext(); + rewriter.setInsertionPoint(compute_op); + Type index_type = rewriter.getIndexType(); + llvm::SmallVector pos; + llvm::SmallVector crds; + std::reverse(accesses.begin(), accesses.end()); + int32_t dim = 0; + Value prev_dim = nullptr; + for(auto access_op : accesses) { - comet_debug() << " cur_formats[" << j << "]: " << cur_formats[j] << "\n"; - 
if (cur_formats[j].compare("D") != 0) - { /// sparse format - cur_sparse_formats.push_back(cur_formats[j]); - sparse_tensor_ids.push_back(tensor_ids[j]); - sparse_dim_orders.push_back(dim_orders[j]); - comet_debug() << " sparse dim in format: " << cur_index << " with format: " << cur_formats[j] << "\n"; - } + auto new_access_op = rewriter.create( + loc, + TypeRange({index_type, index_type}), + clean_workspace, + access_op.getIndex(), + rewriter.getUI32IntegerAttr(dim), + prev_dim + ); + + pos.push_back(new_access_op.getPos()); + crds.push_back(new_access_op.getCrd()); + prev_dim = new_access_op.getPos(); + dim++; } - if (cur_sparse_formats.size() > 1) - { /// More than one sparse format - struct dimInTensor dim_in_tensor; - dim_in_tensor.dim = cur_index; - dim_in_tensor.tensorId = sparse_tensor_ids[0]; /// Any sparse tensor is ok - dim_in_tensor.dimOrder = sparse_dim_orders[0]; - sparseDimsInput.push_back(dim_in_tensor); + Type operand_type = OperandType::get(context); + Value new_lhs = rewriter.create( + loc, + operand_type, + clean_workspace, + pos, + crds + ); + Value new_workspace = rewriter.create( + loc, + workspace_type, + compute_op.getParent(), + new_lhs, + compute_op.getRhs(), + compute_op.getSemiringAttr() + ); + + pos.clear(); + crds.clear(); + dim = 0; + prev_dim = nullptr; + for(auto access_op : accesses) + { + auto new_access_op = rewriter.create( + loc, + TypeRange({index_type, index_type}), + new_workspace, + access_op.getIndex(), + rewriter.getUI32IntegerAttr(dim), + prev_dim + ); + + pos.push_back(new_access_op.getPos()); + crds.push_back(new_access_op.getCrd()); + prev_dim = new_access_op.getPos(); + dim++; } - } - comet_debug() << "sparseDimsInput: "; - for (auto n : sparseDimsInput) - { - comet_debug() << "(" << n.dim << ", " << n.tensorId << ", " << n.dimOrder << ") "; + Value new_rhs = rewriter.create( + loc, + operand_type, + new_workspace, + pos, + crds + ); + + rewriter.replaceOpWithNewOp( + compute_op, + old_output.getType(), + 
compute_op.getParent(), + compute_op.getLhs(), + ValueRange{new_rhs,}, + "noop_noop" + ); + + return success(); } - comet_debug() << "\n"; - return sparseDimsInput; -} - -/// Split one indicesOp into several one, i.e. each computeOp has its own parent op -/// i -> j -> V=0;V=A;W=V*B ===> i -> j -> V=0; -/// -> j -> V=A; -/// -> j -> W=V*B -void splitIndicesOp(Operation *needSplitNode, Value denseIndicesOp, OpBuilder &builder, Location loc) -{ - while (isa(needSplitNode)) - { - - comet_pdump(needSplitNode); - /// check how many operands, split into many operands. - indexTree::IndexTreeIndicesOp indicesOp = dyn_cast(needSplitNode); - - comet_vdump(indicesOp); - - Operation *indicesOpFirstUsers = *(indicesOp.getOperation()->getResult(0).getUsers().begin()); - comet_pdump(indicesOpFirstUsers); - - builder.setInsertionPoint(indicesOpFirstUsers); - comet_debug() << "\n"; - - if (needSplitNode != denseIndicesOp.getDefiningOp()) - { - ArrayAttr indices = indicesOp.getIndices(); - - comet_debug() << " indicesOp.getOperation()->getNumOperands(): " << indicesOp.getOperation()->getNumOperands() << "\n"; - std::vector operands; - std::vector newIndicesOp; - for (unsigned int i = 0; i < indicesOp.getOperation()->getNumOperands(); i++) - { - operands.push_back(indicesOp.getOperation()->getOperand(i)); - - comet_vdump(indicesOp.getOperation()->getOperand(i)); - comet_vdump(operands[i]); - auto i64Type = builder.getI64Type(); - Value t1 = builder.create(loc, i64Type, operands[i], indices); - - comet_debug() << "New IndexTreeIndicesOp added:\n"; - comet_vdump(t1); - newIndicesOp.push_back(t1); - } - - /// put it here - comet_debug() << " finished calling replacereplaceOperands \n"; - /// This parentIndicesOp is the operation that need to be splitted next time - -#ifdef DEBUG_MODE_WorkspaceTransformsPass - comet_vdump(indicesOp.getOperation()->getResult(0)); - for (auto ppp : indicesOp.getOperation()->getResult(0).getUsers()) - { - - comet_pdump(ppp); - } -#endif - - Operation 
*parentIndicesOp = *(indicesOp.getOperation()->getResult(0).getUsers().begin()); - - comet_pdump(parentIndicesOp); - - replaceOperands(needSplitNode, newIndicesOp); - needSplitNode = parentIndicesOp; +}; - comet_debug() << " plan to erase the following Op\n"; - comet_debug() << " Indices operations:\n"; - comet_vdump(indicesOp); - comet_debug() << " Split Nodes:\n"; - comet_pdump(needSplitNode); - comet_debug() << " Indices op first users:\n"; - comet_pdump(indicesOpFirstUsers); - indicesOp.erase(); - } - else - { - comet_debug() << "\n"; - break; +struct MoveInvariantComputeOp : public OpRewritePattern { + MoveInvariantComputeOp (MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/2) {} + + mlir::LogicalResult + matchAndRewrite(IndexTreeComputeOp compute_op, mlir::PatternRewriter &rewriter) const override { + // Collect all indices used in this compute expression + llvm::SmallDenseSet used_indices; + IndexTreeLHSOperandOp lhs_op = compute_op.getLhs().getDefiningOp(); + for(auto pos : lhs_op.getPos()) { + auto index_to_tensor = pos.getDefiningOp(); + if(!index_to_tensor) + return failure(); + used_indices.insert(index_to_tensor.getIndex()); } - } - comet_debug() << "\n"; -} - -void removeRedundantIndices(std::vector newComputeOps, - std::map indexValueMap, - int denseDimInOutput, - OpBuilder &builder, - Location loc) -{ - - /// Check whether need to remove redundant indices or not - /// Get the - - /// -------Remove redundant indices------------- - /// For C, the 1st dim i is Dense, the second dim j is sparse. 
- /// ---- the index including i and before i is not included - mlir::Value denseIndicesOp = indexValueMap[denseDimInOutput]; - /// The indices after denseIndicesOp need to be splitted - /// start from the computeOp, - /// Finished one level - /// Only one User, because it's a tree structure, the leaf only has one parent - assert(newComputeOps[0].getDefiningOp()->getResult(0).hasOneUse() && " the computeOp has more than one users\n"); - /// Get the only one user - Operation *onlyUser = *(newComputeOps[0].getDefiningOp()->getResult(0).getUsers().begin()); - comet_pdump(onlyUser); - - /// needSplitNode is the parent node of the "denseIndicesOp" - Operation *needSplitNode = onlyUser; - /// iterate until the - /// call splitIndicesOp function to split indicesOp until latest "root" - splitIndicesOp(needSplitNode, denseIndicesOp, builder, loc); - comet_debug() << "\n"; - - /// Remove the indices for each tensor - /// iterate over all itComputeOps, get the indices for each tensor - for (auto n : newComputeOps) - { - /// get allPerms, put all indices id into a vector, - /// iterater up until reach the root noe, if the index of indicesOp is not in the vector - /// remove this one: set the operand of the parent of the indicesOp into current op - - comet_debug() << " current computeOp: \n"; - comet_vdump(n); - ArrayAttr allperms_rhs = dyn_cast(n.getDefiningOp()->getOperand(0).getDefiningOp()).getAllPerms(); - std::vector> allpermsInt_rhs = convertArrayAttrIntTo2DVector(allperms_rhs); - std::vector permsInt = getUnionOf2Dvector(allpermsInt_rhs); - comet_debug() << " print permsInt: "; - for (auto p : permsInt) - { - comet_debug() << p << " "; + + auto rhs_operands = compute_op.getRhs(); + for(Value rhs : rhs_operands) { + IndexTreeOperandOp operand_op = rhs.getDefiningOp(); + for(auto pos : operand_op.getPos()) { + auto index_to_tensor = pos.getDefiningOp(); + if(!index_to_tensor) + return failure(); + used_indices.insert(index_to_tensor.getIndex()); + } } - comet_debug() << 
"\n"; - - mlir::Value computeOp = n; - - /// iterate over the IndexTreeIndicesOp; - mlir::Value computeOpParent; /// computeOpParent is IndexTreeIndicesOp - - comet_vdump(n); - assert(n.getDefiningOp()->getResult(0).hasOneUse() && " indicesOp has more than one user\n"); - Operation *computeOpParentPointer = *(n.getDefiningOp()->getResult(0).getUsers().begin()); - computeOpParent = computeOpParentPointer->getResult(0); - - comet_pdump(computeOpParentPointer); - comet_vdump(computeOpParent); - - while (!isRealRoot(computeOpParent.getDefiningOp())) - { - comet_vdump(computeOpParent); - if (isa(computeOpParent.getDefiningOp())) - { - comet_debug() << " indicesOp's parent can not be computeOp\n"; - } - else if (isa(computeOpParent.getDefiningOp())) - { - comet_debug() << " indicesOp's parent is IndexTreeOp\n"; + // We want to find all of the indices that this compute op is nested under + // and check if they are used in this compute expression. Every time we come + // across an unused index, the index nodes that we have seen so far need to be copied + // to form a new branch of the tree. 
We also keep track of the parent at the fork + llvm::SmallVector seen_indices; + llvm::SmallVector indices_to_copy; + Value parent = compute_op.getParent(); + Value fork = parent; + IndexTreeIndicesOp node = parent.getDefiningOp(); + while(node) { + if(used_indices.find(parent) != used_indices.end()) { + // Used index variable + seen_indices.push_back(parent); + } else { + // Unused index variable + fork = node.getParent(); + indices_to_copy.insert(indices_to_copy.begin(), seen_indices.rbegin(), seen_indices.rend()); + seen_indices.clear(); } - else if (isa(computeOpParent.getDefiningOp())) - { - /// get the indices integer, to see if it is in permsInt - /// if yes, don't remove - /// if no, remove: - indexTree::IndexTreeIndicesOp curIndicesOp = dyn_cast(computeOpParent.getDefiningOp()); - comet_debug() << " \n"; - ArrayAttr idsArrayAttr = curIndicesOp.getIndices(); /// should be 1D vector - std::vector idsVec; - for (auto n : idsArrayAttr) - { - idsVec.push_back(n.cast().getInt()); - } - comet_debug() << " print idsVec: "; - for (auto p : idsVec) - { - comet_debug() << p << " "; - } - comet_debug() << "\n"; - - assert(idsVec.size() == 1 && " indicesOp contain more than 1 index\n"); - bool isNeedRemove = false; - for (auto n : idsVec) - { /// only 1 index actually, because each indicesOp contain one index - if (std::find(permsInt.begin(), permsInt.end(), n) != permsInt.end()) - { - /// found - isNeedRemove = false; - } - else - { /// the index in curIndicesOp is not found in the computeOp indices - isNeedRemove = true; - } - } - - /// if curIndicesOp is the "real root" of the index tree (has only one user) - /// contain more than 1 index - if (curIndicesOp.getOperation()->getNumOperands() > 1 && curIndicesOp.getOperation()->getResult(0).hasOneUse() && isa(*(curIndicesOp.getOperation()->getResult(0).getUsers().begin()))) - { - isNeedRemove = false; - } - comet_debug() << " isNeedRemove = " << isNeedRemove << "\n"; - - if (isNeedRemove) - { - 
assert(curIndicesOp.getOperation()->getResult(0).hasOneUse() && " indicesOp has more than one user\n"); - Operation *curIndicesOpParent = *(curIndicesOp.getOperation()->getResult(0).getUsers().begin()); - comet_vdump(computeOpParent); - comet_pdump(curIndicesOpParent); - - computeOpParent.replaceAllUsesWith(computeOp); /// replace all uses of the indexOp with the new indecesOp - computeOp = computeOpParent; - computeOpParent.getDefiningOp()->erase(); /// erase the previous indecesOp - computeOpParent = curIndicesOpParent->getResult(0); - } - else - { -#ifdef DEBUG_MODE_WorkspaceTransformsPass - comet_vdump(curIndicesOp); - int count = 0; - for (auto p : curIndicesOp.getOperation()->getResult(0).getUsers()) - { - comet_pdump(p); - count++; - } - comet_debug() << " count: " << count << "\n"; -#endif - assert(curIndicesOp.getOperation()->getResult(0).hasOneUse() && " indicesOp has more than one user\n"); - Operation *curIndicesOpParent = *(curIndicesOp.getOperation()->getResult(0).getUsers().begin()); - - comet_pdump(curIndicesOpParent); - - computeOp = computeOpParent; - computeOpParent = curIndicesOpParent->getResult(0); - } - } + parent = node.getParent(); + node = parent.getDefiningOp(); } - } /// end for n -} - -std::vector CompressedWorkspaceOutput(std::vector sparseDimsOutput, - indexTree::IndexTreeComputeOp itComputeOp, - std::vector> opFormats, - std::vector> opPerms, - std::map indexValueMap, - OpBuilder &builder, indexTree::IndexTreeOp op) -{ - Location loc = op.getLoc(); - auto comp_worksp_opt = builder.getBoolAttr(compressedworkspace); - int sparseDimOutput = -1; - int sparseDimOrderInOutput = -1; - int denseDimInOutput = -1; - auto i64Type = builder.getI64Type(); - - for (unsigned int j = 0; j < opFormats[opFormats.size() - 1].size(); j++) - { - /// sparse dimension - if (opFormats[opFormats.size() - 1][j].compare("D") != 0) - { - sparseDimOutput = opPerms[opPerms.size() - 1][j]; - sparseDimOrderInOutput = j; + if(fork == compute_op.getParent()) { + 
return failure(); // Match failed, no indces to move. } - else /// dense dimension - denseDimInOutput = opPerms[opPerms.size() - 1][j]; - } - comet_debug() << " " << sparseDimOutput << "\n"; - - /// 3. Find the ta.itIndices op which represents sparseDimOutput - /// Find its parent ... - Value sparseIndicesOp = indexValueMap[sparseDimOutput]; - - comet_vdump(sparseIndicesOp); - comet_debug() << " sparseDimOrderInOutput: " << sparseDimOrderInOutput << "\n"; - Value sparseDimsPerent = indexValueMap[sparseDimOrderInOutput - 1]; - comet_debug() << " sparseDimsPerent: \n"; - comet_vdump(sparseDimsPerent); - - /// Cij = Aik * Bkj ==> - /// ComputeNode(c1): Wj = 0; - /// ComputeNode(c2): Wj += Aik * Bkj; - /// ComputeNode(c3): Cij = Wj - std::vector tensors; - getTensorsOfComputeOp(itComputeOp.getOperation()->getResult(0), tensors); - - /// 4. create W, j dim size of - /// Value outputItComputeOp = itComputeOp.getOperation()->getOperand(itComputeOp.getOperation()->getNumOperands() - 1).getDefiningOp()->getOperand(0); - /// new version - Value outputItComputeOp = tensors[tensors.size() - 1]; - comet_vdump(outputItComputeOp); - - std::vector w_lbls_value = {outputItComputeOp.getDefiningOp()->getOperand(sparseDimOrderInOutput)}; - - comet_vdump(outputItComputeOp.getDefiningOp()->getOperand(sparseDimOrderInOutput)); - std::string w_format = "Dense"; /// tensor - auto w_type = RankedTensorType::get({mlir::ShapedType::kDynamic}, builder.getF64Type()); - - Operation *itComputeOpFirstUsers = *(itComputeOp.getOperation()->getUsers().begin()); - builder.setInsertionPoint(itComputeOpFirstUsers); /// Insert before itree Op - - mlir::Value w = builder.create(loc, w_type, w_lbls_value, w_format); - comet_vdump(w); - auto w_index_list_type = RankedTensorType::get({mlir::ShapedType::kDynamic}, builder.getIndexType()); /// tensor - mlir::Value w_already_set = builder.create(loc, w_index_list_type, w_lbls_value, w_format); - comet_vdump(w_already_set); - mlir::Value w_index_list = 
builder.create(loc, w_index_list_type, w_lbls_value, w_format); - comet_vdump(w_index_list); - - MemRefType w_index_list_size_type = MemRefType::get({1}, builder.getIndexType()); /// tensor<1xindex> - mlir::Value w_index_list_size_alloc = builder.create(loc, w_index_list_size_type); /// tensor<1xindex> - Value w_index_list_size = builder.create(loc, w_index_list_size_alloc); - - std::vector workspaceTensors = {w, w_already_set, w_index_list, w_index_list_size}; - tensors.push_back(w); /// {A, B, C, W} - - std::vector> formats = {opFormats[0], opFormats[1], opFormats[2], {"D"}}; - std::vector> perms = {opPerms[0], opPerms[1], opPerms[2], {sparseDimOutput}}; - - /// Start building an IndexTreeCompute Operation to represent Wj = 0; - std::vector c1_perms_int_0; - std::vector c1_perms_int_1; - std::vector> c1_perms_int = {c1_perms_int_0, c1_perms_int_1}; - std::vector c1_formats_str_0; - std::vector c1_formats_str_1; - std::vector> c1_formats_str = {c1_formats_str_0, c1_formats_str_1}; - - Value const_index_0 = builder.create(loc, 0); - std::vector c1_rhs = {const_index_0}; - mlir::Value c1_lhs = {w_index_list_size}; - std::string semiringName(itComputeOp.getSemiring().data()); - std::string maskNone = "none"; - std::string maskTypeName(itComputeOp.getMaskType().data()); - auto c1_semiring = builder.getStringAttr(semiringName); - auto c1_maskType = builder.getStringAttr(maskNone); /// masking attribute - /// for c1_rhs - std::vector> c1_rhsop_perms_str = {c1_perms_int_0}; - ArrayAttr c1_rhsop_perms = convert2DVectorToArrayAttrInt(c1_rhsop_perms_str, builder); - std::vector> c1_rhsop_formats_str = {c1_formats_str_0}; - ArrayAttr c1_rhsop_formats = convert2DVectorToArrayAttrStr(c1_rhsop_formats_str, builder); - mlir::Value c1_rhsop = builder.create(loc, mlir::UnrankedTensorType::get(builder.getIndexType()), c1_rhs, c1_rhsop_perms, c1_rhsop_formats); - comet_debug() << "IndexTreeComputeRHS Operation in Output (c1_rhs):\n"; - comet_vdump(c1_rhsop); - - /// for c1_lhs - 
std::vector> c1_lhsop_perms_str = {c1_perms_int_1}; - ArrayAttr c1_lhsop_perms = convert2DVectorToArrayAttrInt(c1_lhsop_perms_str, builder); - std::vector> c1_lhsop_formats_str = {c1_formats_str_1}; - ArrayAttr c1_lhsop_formats = convert2DVectorToArrayAttrStr(c1_lhsop_formats_str, builder); - mlir::Value c1_lhsop = builder.create(loc, mlir::UnrankedTensorType::get(builder.getF64Type()), c1_lhs, c1_lhsop_perms, c1_lhsop_formats); - comet_debug() << "IndexTreeComputeLHS Operation in Output (c1_lhs):\n"; - comet_vdump(c1_lhsop); - - /// for c1 ==> Wj = 0; - mlir::Value c1 = builder.create(loc, builder.getI64Type(), c1_rhsop, c1_lhsop, comp_worksp_opt, c1_semiring, c1_maskType); - comet_debug() << "IndexTreeCompute Operation in Output (c1):\n"; - comet_vdump(c1); - - /// insert c1 to sparseDimsParent - sparseDimsPerent.getDefiningOp()->insertOperands(0, c1); - - /// Start building an IndexTreeCompute Operation to represent Wj += Aik * Bkj; - std::vector c2_tensors = {tensors[0], tensors[1], w}; - std::vector c2_perms_int_0 = opPerms[0]; - std::vector c2_perms_int_1 = opPerms[1]; - std::vector c2_perms_int_2 = {sparseDimOutput}; - std::vector> c2_perms_int = {c2_perms_int_0, c2_perms_int_1, c2_perms_int_2}; - - /// Convert formats string array into StrAttr - std::vector c2_formats_str_0 = opFormats[0]; - std::vector c2_formats_str_1 = opFormats[1]; - std::vector c2_formats_str_2 = {"D"}; - std::vector> c2_formats_str = {c2_formats_str_0, c2_formats_str_1, c2_formats_str_2}; - std::vector c2_rhs; - std::vector> c2_rhsop_formats_str; - std::vector> c2_rhsop_perms_str; - if (tensors.size() > 4) /// masking input is available: tensors = {%op0, %op1, %mask, %out, %W} - { - c2_rhs = {c2_tensors[0], c2_tensors[1], tensors[2]}; /// tensors[2] val is the mask - c2_rhsop_formats_str = {c2_formats_str_0, c2_formats_str_1, opFormats[2]}; /// mask format is same as the output - c2_rhsop_perms_str = {c2_perms_int_0, c2_perms_int_1, opPerms[2]}; /// perms of mask are same as the 
output - } - else /// no masking input is provided: tensors = {%op0, %op1, %out, %W} - { - c2_rhs = {c2_tensors[0], c2_tensors[1]}; - c2_rhsop_formats_str = {c2_formats_str_0, c2_formats_str_1}; - c2_rhsop_perms_str = {c2_perms_int_0, c2_perms_int_1}; - } - std::vector c2_lhs = workspaceTensors; - - auto c2_semiring = builder.getStringAttr(semiringName); - auto c2_maskType = builder.getStringAttr(maskTypeName); /// masking attribute - - /// for c2_rhsop - ArrayAttr c2_rhsop_perms = convert2DVectorToArrayAttrInt(c2_rhsop_perms_str, builder); - ArrayAttr c2_rhsop_formats = convert2DVectorToArrayAttrStr(c2_rhsop_formats_str, builder); - mlir::Value c2_rhsop = builder.create(loc, mlir::UnrankedTensorType::get(builder.getF64Type()), c2_rhs, c2_rhsop_perms, c2_rhsop_formats); - comet_debug() << "IndexTreeComputeRHS Operation in Output (c2_rhs):\n"; - comet_vdump(c2_rhsop); - - /// for c2_lhsop - std::vector> c2_lhsop_perms_str = {c2_perms_int_2}; - ArrayAttr c2_lhsop_perms = convert2DVectorToArrayAttrInt(c2_lhsop_perms_str, builder); - std::vector> c2_lhsop_formats_str = {c2_formats_str_2}; - ArrayAttr c2_lhsop_formats = convert2DVectorToArrayAttrStr(c2_lhsop_formats_str, builder); - mlir::Value c2_lhsop = builder.create(loc, mlir::UnrankedTensorType::get(builder.getF64Type()), c2_lhs, c2_lhsop_perms, c2_lhsop_formats); - comet_debug() << "IndexTreeComputeLHS Operation in Output (c2_lhs):\n"; - comet_vdump(c2_lhsop); - - /// for c2 - mlir::Value c2 = builder.create(loc, i64Type, c2_rhsop, c2_lhsop, comp_worksp_opt, c2_semiring, c2_maskType); - comet_debug() << "IndexTreeCompute Operation in Output (c2):\n"; - comet_vdump(c2); - - /// Start building an IndexTreeCompute Operation to represent Cij = Wj; - std::vector c3_tensors; - if (tensors.size() > 4) /// masking input is available: tensors = {%op0, %op1, %mask, %out, %W} - { - c3_tensors = {tensors[3]}; - } - else /// masking input is NOT available: tensors = {%op0, %op1, %out, %W} - { - c3_tensors = {tensors[2]}; - } - 
std::vector c3_perms_int_0 = {sparseDimOutput}; - std::vector c3_perms_int_1 = opPerms[2]; - std::vector> c3_perms_int = {c3_perms_int_0, c3_perms_int_1}; - - /// Convert formats string array into StrAttr - std::vector c3_formats_str_0 = {"D"}; - std::vector c3_formats_str_1 = opFormats[2]; - std::vector> c3_formats_str = {c3_formats_str_0, c3_formats_str_1}; - - std::vector c3_rhs = workspaceTensors; - mlir::Value c3_lhs = c3_tensors[0]; - auto c3_semiring = builder.getStringAttr(semiringName); - auto c3_maskType = builder.getStringAttr(maskNone); /// masking attribute - - /// for c3_rhs - std::vector> c3_rhsop_perms_str = {c3_perms_int_0}; - ArrayAttr c3_rhsop_perms = convert2DVectorToArrayAttrInt(c3_rhsop_perms_str, builder); - std::vector> c3_rhsop_formats_str = {c3_formats_str_0}; - ArrayAttr c3_rhsop_formats = convert2DVectorToArrayAttrStr(c3_rhsop_formats_str, builder); - mlir::Value c3_rhsop = builder.create(loc, mlir::UnrankedTensorType::get(builder.getF64Type()), c3_rhs, c3_rhsop_perms, c3_rhsop_formats); - comet_debug() << "IndexTreeComputeRHS Operation in Output (c3_rhs):\n"; - comet_vdump(c3_rhsop); - - /// for c3_lhs - std::vector> c3_lhsop_perms_str = {c3_perms_int_1}; - ArrayAttr c3_lhsop_perms = convert2DVectorToArrayAttrInt(c3_lhsop_perms_str, builder); - std::vector> c3_lhsop_formats_str = {c3_formats_str_1}; - ArrayAttr c3_lhsop_formats = convert2DVectorToArrayAttrStr(c3_lhsop_formats_str, builder); - mlir::Value c3_lhsop = builder.create(loc, mlir::UnrankedTensorType::get(builder.getF64Type()), c3_lhs, c3_lhsop_perms, c3_lhsop_formats); - comet_debug() << "IndexTreeComputeLHS Operation in Output (c3_lhs):\n"; - comet_vdump(c3_lhsop); - - /// for c3 ==> Cij = Wj; - mlir::Value c3 = builder.create(loc, i64Type, c3_rhsop, c3_lhsop, comp_worksp_opt, c3_semiring, c3_maskType); - comet_debug() << "IndexTreeCompute Operation in Output (c3):\n"; - comet_vdump(c3); - - std::vector newComputeOps = {c2, c3}; - 
sparseIndicesOp.getDefiningOp()->setOperands(newComputeOps); - - /// remove redundant indices by calling a function - /// in elementwise: not remove - /// in spgemm: remove - /// check if there is redundant index - bool existRedundantIndex = false; - for (auto n : newComputeOps) - { - std::vector> perms; - getPermsOfComputeOp(n, perms); - std::vector allperms = getUnionOf2Dvector(perms); - comet_debug() << " print allperms \n"; - print_vector(allperms); - - std::vector ancestors; - std::vector dfsOps; - dfsRootOpTree(op.getChildren(), dfsOps); - getAncestorsWp(n, ancestors, dfsOps); - comet_debug() << " print ancestors \n"; - print_vector_value(ancestors); - - /// Iterate over every indicesOp - for (auto ancestor : ancestors) + // Success! + IRMapping map; + auto context = rewriter.getContext(); + auto loc = compute_op.getLoc(); + IndexNodeType index_node_type = IndexNodeType::get(context); + parent = fork; + for(auto index : indices_to_copy) { - /// If indicesOp's index is in allperms, no redundant - /// the indicesOp is real root, no redundant - /// Otherwise, redundant - if (isa(ancestor.getDefiningOp())) - { - indexTree::IndexTreeIndicesOp indicesOp = dyn_cast(ancestor.getDefiningOp()); - - ArrayAttr idsArrayAttr = indicesOp.getIndices(); /// should be 1D vector - /// actually only one index for the indicesOp in our implementation - for (auto m : idsArrayAttr) - { - int perm = m.cast().getInt(); - comet_debug() << " perm: " << perm << "\n"; - - if (findIndexInVector(allperms, perm) == allperms.size()) - { /// not exit - comet_debug() << " perm not exist in allperms\n"; - if (!isRealRoot(indicesOp.getOperation())) - { - existRedundantIndex = true; - comet_debug() << " existRedundantIndex: " << existRedundantIndex << "\n"; - } - } - } - } + Value new_index = rewriter.create(loc, index_node_type, parent); + map.map(index, new_index); + parent = new_index; } - } - if (existRedundantIndex) - { - comet_debug() << "There is loop invariant\n"; - 
removeRedundantIndices(newComputeOps, indexValueMap, denseDimInOutput, builder, loc); - } - - return newComputeOps; -} /// end CompressedWorkspaceOutput() - -void CompressedWorkspaceInput(std::vector computeOps, OpBuilder &builder, Location loc) -{ - auto comp_worksp_opt = builder.getBoolAttr(compressedworkspace); - for (auto computeOp : computeOps) - { - /// 1. get the opFormats and opPerms of the computeOp - std::vector> opFormats; - std::vector> opPerms; - std::vector> inputOutputMapping; - getFormatsPermsOfComputeOp(computeOp, opFormats, opPerms, inputOutputMapping); - comet_debug() << " \n"; - for (auto n : opFormats) - { - - print_vector(n); + for(auto pos : lhs_op.getPos()) { + Operation* index_to_tensor = pos.getDefiningOp(); + rewriter.clone(*index_to_tensor, map); } - for (auto n : opPerms) - { - - print_vector(n); + rewriter.clone(*lhs_op.getOperation(), map); + + for(Value rhs : rhs_operands) { + IndexTreeOperandOp operand_op = rhs.getDefiningOp(); + for(auto pos : operand_op.getPos()) { + Operation* index_to_tensor = pos.getDefiningOp(); + rewriter.clone(*index_to_tensor, map); + } + rewriter.clone(*operand_op.getOperation(), map); } - std::vector tensors; - getTensorsOfComputeOp(computeOp, tensors); - comet_debug() << " tensors.size(): " << tensors.size() << "\n"; - std::vector tensors_rhs; - getInputTensorsOfComputeOp(computeOp, tensors_rhs); - comet_debug() << " tensors_rhs.size(): " << tensors_rhs.size() << "\n"; - std::vector tensors_lhs; - getOutputTensorsOfComputeOp(computeOp, tensors_lhs); - comet_debug() << " tensors_lhs.size(): " << tensors_lhs.size() << "\n"; - - indexTree::IndexTreeComputeOp itComputeOp = dyn_cast(computeOp.getDefiningOp()); - std::string semiringName(itComputeOp.getSemiring().data()); - - std::vector sparseDimsOutput = getSparseDimsOutput(opFormats, opPerms); - std::vector sparseDimsInput = getSparseDimsInput(opFormats, opPerms); - comet_debug() << " sparseDimsInput.size(): " << sparseDimsInput.size() << "\n"; - - if 
(sparseDimsInput.size() == 1) - { /// solve only 1 sparseDimsInput - /// No need to apply workspace transformation - comet_debug() << " sparseDimsInput[0]: " << sparseDimsInput[0].dim << ", " << sparseDimsInput[0].tensorId << ", " << sparseDimsInput[0].dimOrder << "\n"; - - /// Wj=Aij*Bij ==> - /// ComputeNode(c1): Vj=0; - /// ComputeNode(c2): Vj=Aij; - /// ComputeNode(c3): Wj=Vj*Bij - - Value sparseInput = tensors_rhs[sparseDimsInput[0].tensorId]; - comet_vdump(sparseInput); - - std::vector v_lbls_value = {sparseInput.getDefiningOp()->getOperand(sparseDimsInput[0].dimOrder)}; - comet_debug() << "Dumping v_lbls_value\n"; - comet_vdump(v_lbls_value[0]); - comet_debug() << "Done\n"; - comet_vdump(sparseInput.getDefiningOp()->getOperand(sparseDimsInput[0].dimOrder)); - std::string v_format = "Dense"; /// tensor - auto v_type = RankedTensorType::get({mlir::ShapedType::kDynamic}, builder.getF64Type()); - - builder.setInsertionPoint(computeOp.getDefiningOp()); - mlir::Value v = builder.create(loc, v_type, v_lbls_value, v_format); - comet_vdump(v); - - /// Start building an IndexTreeCompute Operation to represent Vj=0 - std::vector c1_perms_int_0; - std::vector c1_perms_int_1 = {sparseDimsInput[0].dim}; - std::vector> c1_perms_int = {c1_perms_int_0, c1_perms_int_1}; - - std::vector c1_formats_str_0; - std::vector c1_formats_str_1 = {"D"}; - std::vector> c1_formats_str = {c1_formats_str_0, c1_formats_str_1}; - - auto i64Type = builder.getI64Type(); - Value const_f64_0 = builder.create(loc, builder.getF64Type(), builder.getF64FloatAttr(0.0)); - std::vector c1_rhs = {const_f64_0}; - mlir::Value c1_lhs = {v}; - std::string semiringName(itComputeOp.getSemiring().data()); - std::string maskNone = "none"; - auto c1_semiring = builder.getStringAttr(semiringName); - auto c1_maskType = builder.getStringAttr(maskNone); /// masking attribute - - /// for c1_rhs - std::vector> c1_rhsop_perms_str = {c1_perms_int_0}; - ArrayAttr c1_rhsop_perms = 
convert2DVectorToArrayAttrInt(c1_rhsop_perms_str, builder); - std::vector> c1_rhsop_formats_str = {c1_formats_str_0}; - ArrayAttr c1_rhsop_formats = convert2DVectorToArrayAttrStr(c1_rhsop_formats_str, builder); - mlir::Value c1_rhsop = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), - c1_rhs, - c1_rhsop_perms, - c1_rhsop_formats); - comet_debug() << "IndexTreeComputeRHS Operation in Input (c1_rhs):"; - comet_vdump(c1_rhsop); - - /// for c1_lhs - std::vector> c1_lhsop_perms_str = {c1_perms_int_1}; - ArrayAttr c1_lhsop_perms = convert2DVectorToArrayAttrInt(c1_lhsop_perms_str, builder); - std::vector> c1_lhsop_formats_str = {c1_formats_str_1}; - ArrayAttr c1_lhsop_formats = convert2DVectorToArrayAttrStr(c1_lhsop_formats_str, builder); - mlir::Value c1_lhsop = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), - c1_lhs, - c1_lhsop_perms, - c1_lhsop_formats); - comet_debug() << "IndexTreeComputeLHS Operation in Input (c1_lhs):"; - comet_vdump(c1_lhsop); - - /// for c1 - mlir::Value c1 = builder.create(loc, i64Type, - c1_rhsop, - c1_lhsop, - comp_worksp_opt, - c1_semiring, - c1_maskType); - comet_debug() << "IndexTreeCompute Operation in Input (c1): "; - comet_vdump(c1); - - /// Start building an IndexTreeCompute Operation to represent Vj = Aij - std::vector c2_perms_int_0 = opPerms[sparseDimsInput[0].tensorId]; - std::vector c2_perms_int_1 = {sparseDimsInput[0].dim}; - std::vector> c2_perms_int = {c2_perms_int_0, c2_perms_int_1}; - - std::vector c2_formats_str_0 = opFormats[sparseDimsInput[0].tensorId]; - std::vector c2_formats_str_1 = {"D"}; - std::vector> c2_formats_str = {c2_formats_str_0, c2_formats_str_1}; - - std::vector c2_rhs = {tensors_rhs[sparseDimsInput[0].tensorId]}; - - mlir::Value c2_lhs = {v}; - auto c2_semiring = builder.getStringAttr(semiringName); - auto c2_maskType = builder.getStringAttr(maskNone); /// masking attribute - - /// for c2_rhs - std::vector> c2_rhsop_perms_str = {c2_perms_int_0}; - 
ArrayAttr c2_rhsop_perms = convert2DVectorToArrayAttrInt(c2_rhsop_perms_str, builder); - std::vector> c2_rhsop_formats_str = {c2_formats_str_0}; - ArrayAttr c2_rhsop_formats = convert2DVectorToArrayAttrStr(c2_rhsop_formats_str, builder); - mlir::Value c2_rhsop = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), - c2_rhs, - c2_rhsop_perms, - c2_rhsop_formats); - comet_debug() << "IndexTreeComputeRHS Operation in Input (c2_rhs):"; - comet_vdump(c2_rhsop); - - /// for c2_lhs - std::vector> c2_lhsop_perms_str = {c2_perms_int_1}; - ArrayAttr c2_lhsop_perms = convert2DVectorToArrayAttrInt(c2_lhsop_perms_str, builder); - std::vector> c2_lhsop_formats_str = {c2_formats_str_1}; - ArrayAttr c2_lhsop_formats = convert2DVectorToArrayAttrStr(c2_lhsop_formats_str, builder); - mlir::Value c2_lhsop = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), - c2_lhs, - c2_lhsop_perms, - c2_lhsop_formats); - comet_debug() << "IndexTreeComputeLHS Operation in Input (c2_lhs):"; - comet_vdump(c2_lhsop); - - /// for c2 - mlir::Value c2 = builder.create(loc, i64Type, - c2_rhsop, - c2_lhsop, - comp_worksp_opt, - c2_semiring, - c2_maskType); - comet_debug() << "IndexTreeCompute Operation in Input (c2): "; - comet_vdump(c2); - - /// Start building an IndexTreeCompute Operation to represent Wj=Vj*Bij - std::vector c3_perms_int_0 = {sparseDimsInput[0].dim}; - std::vector c3_perms_int_1 = opPerms[1]; - std::vector c3_perms_int_2 = opPerms[opPerms.size() - 1]; - std::vector> c3_perms_int = {c3_perms_int_0, c3_perms_int_1, c3_perms_int_2}; - - /// Convert formats string array into StrAttr - std::vector c3_formats_str_0 = {"D"}; - std::vector c3_formats_str_1 = opFormats[1]; - std::vector c3_formats_str_2 = opFormats[opFormats.size() - 1]; - - std::vector> c3_formats_str = {c3_formats_str_0, c3_formats_str_1, c3_formats_str_2}; - std::vector c3_rhs = {v, tensors[1]}; - - comet_debug() << " tensors.size(): " << tensors.size() << "\n"; - std::vector c3_lhs 
= tensors_lhs; - - auto c3_semiring = builder.getStringAttr(semiringName); - auto c3_maskType = builder.getStringAttr(maskNone); /// masking attribute - - /// for c3_rhs - std::vector> c3_rhsop_perms_str = {c3_perms_int_0, c3_perms_int_1}; - ArrayAttr c3_rhsop_perms = convert2DVectorToArrayAttrInt(c3_rhsop_perms_str, builder); - std::vector> c3_rhsop_formats_str = {c3_formats_str_0, c3_formats_str_1}; - ArrayAttr c3_rhsop_formats = convert2DVectorToArrayAttrStr(c3_rhsop_formats_str, builder); - mlir::Value c3_rhsop = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), - c3_rhs, - c3_rhsop_perms, - c3_rhsop_formats); - comet_debug() << "IndexTreeComputeRHS Operation in Input (c3_rhs):"; - comet_vdump(c3_rhsop); - - /// for c3_lhs - std::vector> c3_lhsop_perms_str = {c3_perms_int_2}; - ArrayAttr c3_lhsop_perms = convert2DVectorToArrayAttrInt(c3_lhsop_perms_str, builder); - std::vector> c3_lhsop_formats_str = {c3_formats_str_2}; - ArrayAttr c3_lhsop_formats = convert2DVectorToArrayAttrStr(c3_lhsop_formats_str, builder); - mlir::Value c3_lhsop = builder.create(loc, - mlir::UnrankedTensorType::get(builder.getF64Type()), - c3_lhs, - c3_lhsop_perms, - c3_lhsop_formats); - comet_debug() << "IndexTreeComputeLHS Operation in Input (c3_lhs):"; - comet_vdump(c3_lhsop); - - /// for c3 - mlir::Value c3 = builder.create(loc, i64Type, - c3_rhsop, - c3_lhsop, - comp_worksp_opt, - c3_semiring, - c3_maskType); - comet_debug() << "IndexTreeCompute Operation in Input (t3): "; - comet_vdump(c3); - - /// old version for new children ops - std::vector newComputeOps = {c1, c2, c3}; - replaceOperands(itComputeOp.getOperation(), newComputeOps); - - /// Step 2: split j into 3. 
- Operation *needSplitNode = *(newComputeOps[0].getDefiningOp()->getResult(0).getUsers().begin()); - Operation *parentSplitNode = *(needSplitNode->getResult(0).getUsers().begin()); - comet_debug() << " call splitIndicesOp for applying workspace in Input \n"; - comet_pdump(needSplitNode); - splitIndicesOp(needSplitNode, parentSplitNode->getResult(0), builder, loc); - comet_debug() << "\n"; - - } /// end if(sparseDimsInput.size() == 1) + Operation* new_compute_op = rewriter.clone(*compute_op.getOperation(), map); + rewriter.replaceOp(compute_op, new_compute_op->getResults()); + return success(); } -} - -void IndexTreeWorkspaceTransformationsPass::CompressedWorkspaceTransforms(mlir::func::FuncOp funcop) -{ - funcop.walk([](indexTree::IndexTreeOp op) - { - OpBuilder builder(op); - comet_vdump(op); - - Location loc = op.getLoc(); - - /// 1. Find its child, until reach the ta.itCompute op - /// Get first user - Value computeOp = op.getOperation()->getOperand(0); - comet_vdump(computeOp); - - /// Only one child?? - /// Build a map, which index is in which IndexTreeIndicesOp - /// ------ Notice: each index is only in one IndicesOp in original index tree here - /// ------ TODO(gkestor): handle more complicate cases: one index is in more than one IndicesOp - /// For an indexTree, the indices ids are - std::map indexValueMap; - - while (!(isa(computeOp.getDefiningOp()))) - { - if (isa(computeOp.getDefiningOp())) - { - auto indicesop = dyn_cast(computeOp.getDefiningOp()); - ArrayAttr idsArrayAttr = indicesop.getIndices(); - for (auto n : idsArrayAttr) - { - int ids = n.cast().getInt(); - indexValueMap.emplace(ids, computeOp); - } - } - computeOp = computeOp.getDefiningOp()->getOperand(0); /// put here - } - comet_vdump(computeOp); - - /// 2. 
Check if there is sparse dim in the ta.itCompute op, - std::vector> opFormats; - std::vector> opPerms; - std::vector> inputOutputMapping; - getFormatsPermsOfComputeOp(computeOp, opFormats, opPerms, inputOutputMapping); - -#ifdef DEBUG_MODE_WorkspaceTransformsPass - comet_debug() << "Print opFormats:\n"; - for (auto n : opFormats) - { - - print_vector(n); - } -#endif - - indexTree::IndexTreeComputeOp itComputeOp = dyn_cast(computeOp.getDefiningOp()); - - /// Check the input tensors, and the output tensor, to see if it contains sparse dimensions - /// get the dim ids - std::vector sparseDimsOutput = getSparseDimsOutput(opFormats, opPerms); - -#ifdef DEBUG_MODE_WorkspaceTransformsPass - comet_debug() << " Print sparseDimsOutput: "; - for (auto p : sparseDimsOutput) - { - comet_debug() << p << " "; - } - comet_debug() << "\n"; -#endif - - std::vector sparseDimsInput = getSparseDimsInput(opFormats, opPerms); - - if (sparseDimsOutput.size() == 0 && sparseDimsInput.size() == 0) - { - /// No need to apply workspace transformation - comet_debug() << __FILE__ << __LINE__ << " No need to apply workspace transformation\n"; - return; - } - - assert(sparseDimsOutput.size() == 1 && " More than one sparse index in the output, we are expecting to support it in the future\n"); - - std::vector newComputeOps; - /// create three IndexTreeComputeOp op - /// sparse dim in output tensor - if (sparseDimsOutput.size() == 1) - { - newComputeOps = CompressedWorkspaceOutput(sparseDimsOutput, itComputeOp, opFormats, opPerms, indexValueMap, builder, op); - } - /// initially here workspaceOutput content - -#ifdef DEBUG_MODE_WorkspaceTransformsPass - /// Should notice, the itree has been the new itree already after call workspaceOutput - for (auto n : newComputeOps) - { - - comet_vdump(n); - } -#endif - if (sparseDimsInput.size() == 1) - { - comet_vdump(op); - /// Need the newComputeOps - CompressedWorkspaceInput(newComputeOps, builder, loc); - } - - /// Also remove previous IndexTreeComputeOp's 
LHS and RHS. - indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(itComputeOp->getOperand(0).getDefiningOp()); - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(itComputeOp->getOperand(1).getDefiningOp()); - - itComputeOp.erase(); - itComputeOp_rhs.erase(); - itComputeOp_lhs.erase(); }); /// end function traverse - - comet_debug() << __FILE__ << " " << __LINE__ << "CompressedWorkspaceTransforms pass is done\n"; -} +}; void IndexTreeWorkspaceTransformationsPass::runOnOperation() { comet_debug() << __FILE__ << " " << __LINE__ << " starting CompressedWorkspaceTransforms pass \n"; - func::FuncOp function = getOperation(); - /// Traverse the function, only handle ta.itree operation - CompressedWorkspaceTransforms(function); + mlir::RewritePatternSet workspace_transformation_patterns(&getContext()); + + workspace_transformation_patterns.add(&getContext()); + indexTree::populateDomainInferencePatterns(&getContext(), workspace_transformation_patterns); //For new index variables + indexTree::populateDomainConcretizationPatterns(&getContext(), workspace_transformation_patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(workspace_transformation_patterns)); comet_debug() << __FILE__ << " " << __LINE__ << " ending CompressedWorkspaceTransforms pass \n"; } diff --git a/lib/Dialect/TensorAlgebra/CMakeLists.txt b/lib/Dialect/TensorAlgebra/CMakeLists.txt index c5a765de..ef4ce65b 100644 --- a/lib/Dialect/TensorAlgebra/CMakeLists.txt +++ b/lib/Dialect/TensorAlgebra/CMakeLists.txt @@ -15,6 +15,7 @@ add_llvm_library(COMETTensorAlgebraDialect add_dependencies( COMETTensorAlgebraDialect + COMETTensorAlgebraTypesIncGen COMETTensorAlgebraOpsIncGen COMETTensorAlgebraPassIncGen ) diff --git a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp index 6b0d2f54..94d3caf4 100644 --- a/lib/Dialect/TensorAlgebra/IR/TADialect.cpp +++ b/lib/Dialect/TensorAlgebra/IR/TADialect.cpp @@ -27,11 +27,11 @@ 
//===----------------------------------------------------------------------===// #include #include "comet/Dialect/TensorAlgebra/IR/TADialect.h" -#include "comet/Dialect/TensorAlgebra/IR/TATypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -43,93 +43,6 @@ using namespace mlir::tensorAlgebra; /// TADialect //===----------------------------------------------------------------------===// -Type mlir::tensorAlgebra::TADialect::parseType(DialectAsmParser &parser) const -{ - /// Parse the main keyword for the type. - StringRef keyword; - /// for "range" and "sptensor" type - if (parser.parseKeyword(&keyword)) - return Type(); - - MLIRContext *context = getContext(); - - /// Handle 'range' types. - if (keyword == "range") - { - return RangeType::get(context); - } - - /// Parse the element types of the sptensor. - if (keyword == "sptensor") - { - if (parser.parseLess()) - { - return Type(); - } - - SmallVector elementTypes; - do - { - /// Parse the current element type. - llvm::SMLoc typeLoc = parser.getCurrentLocation(); - mlir::Type elementType; - - if (parser.parseType(elementType)) - return nullptr; - - /// Check that the type is either a TensorType or another StructType. - if (!elementType.isa()) - { - parser.emitError(typeLoc, "element type for a struct must either " - "be a TensorType or a StructType, got: ") - << elementType; - return Type(); - } - elementTypes.push_back(elementType); - - /// Parse the optional: `,` - } while (succeeded(parser.parseOptionalComma())); - - /// Parse: `>` - if (parser.parseGreater()) - return Type(); - - return SparseTensorType::get(elementTypes); - } - - parser.emitError(parser.getNameLoc(), - "unknown TensorAlgebra type: " + keyword); - return Type(); -} - -/// RangeType prints as just "range". 
-static void print(RangeType type, DialectAsmPrinter &printer) -{ - printer << "range"; -} - -void mlir::tensorAlgebra::TADialect::printType( - Type type, DialectAsmPrinter &printer) const -{ - if (type.isa()) - { - print(type.cast(), printer); - } - else if (type.isa()) - { - /// Currently the only toy type is a struct type. - SparseTensorType sparseTensorType = type.cast(); - - /// Print the struct type according to the parser format. - printer << "sptensor<"; - llvm::interleaveComma(sparseTensorType.getElementTypes(), printer); - printer << '>'; - } - else - { - llvm_unreachable("Unhandled TensorAlgebra type"); - } -} //===----------------------------------------------------------------------===// /// ConstantOp @@ -339,93 +252,53 @@ mlir::LogicalResult TAReturnOp::verify() /// TA Types //===----------------------------------------------------------------------===// -namespace mlir +// Implements the shaped type interface for the workspace type +ShapedType WorkspaceType::cloneWith(std::optional> shape, Type elementType) const { - namespace tensorAlgebra - { - namespace detail - { - /// This class represents the internal storage of the Toy `SparseTensorType`. - struct SparseTensorTypeStorage : public mlir::TypeStorage - { - /// The `KeyTy` is a required type that provides an interface for the storage - /// instance. This type will be used when uniquing an instance of the type - /// storage. For our struct type, we will unique each instance structurally on - /// the elements that it contains. - using KeyTy = llvm::ArrayRef; - - /// A constructor for the type storage instance. - SparseTensorTypeStorage(llvm::ArrayRef elementTypes) - : elementTypes(elementTypes) {} - - /// Define the comparison function for the key type with the current storage - /// instance. This is used when constructing a new instance to ensure that we - /// haven't already uniqued an instance of the given key. 
- bool operator==(const KeyTy &key) const { return key == elementTypes; } - - /// Define a hash function for the key type. This is used when uniquing - /// instances of the storage, see the `StructType::get` method. - /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type - /// have hash functions available, so we could just omit this entirely. - static llvm::hash_code hashKey(const KeyTy &key) - { - return llvm::hash_value(key); - } - - /// Define a construction function for the key type from a set of parameters. - /// These parameters will be provided when constructing the storage instance - /// itself. - /// Note: This method isn't necessary because KeyTy can be directly - /// constructed with the given parameters. - static KeyTy getKey(llvm::ArrayRef elementTypes) - { - return KeyTy(elementTypes); - } - - /// Define a construction method for creating a new instance of this storage. - /// This method takes an instance of a storage allocator, and an instance of a - /// `KeyTy`. The given allocator must be used for *all* necessary dynamic - /// allocations used to create the type storage and its internal. - static SparseTensorTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - const KeyTy &key) - { - /// Copy the elements from the provided `KeyTy` into the allocator. - llvm::ArrayRef elementTypes = allocator.copyInto(key); - - /// Allocate the storage instance and construct it. - return new (allocator.allocate()) - SparseTensorTypeStorage(elementTypes); - } - - /// The following field contains the element types of the struct. - llvm::ArrayRef elementTypes; - }; - - } /// end namespace detail - } /// end namespace tensoralgebra -} /// end namespace mlir - -/// Create an instance of a `SparseTensorType` with the given element types. There -/// *must* be at least one element type. -SparseTensorType SparseTensorType::get(llvm::ArrayRef elementTypes) + // TODO: This may (?) require converting dimensions? 
Not sure + assert(false && "Workspace tensor cannot not be closed into another type"); + return NULL; +} + +bool WorkspaceType::hasRank() const +{ + return true; +} + +llvm::ArrayRef WorkspaceType::getShape() const +{ + return getDims(); +} + +// Implements the shaped type interface for the sparse tensor type +ShapedType SparseTensorType::cloneWith(std::optional> shape, Type elementType) const { - assert(!elementTypes.empty() && "expected at least 1 element type"); - - /// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance - /// of this type. The first two parameters are the context to unique in and the - /// kind of the type. The parameters after the type kind are forwarded to the - /// storage instance. - mlir::MLIRContext *ctx = elementTypes.front().getContext(); - return Base::get(ctx, elementTypes); + // TODO: This may (?) require converting dimensions? Not sure + assert(false && "Sparse tensor cannot not be closed into another type"); + return NULL; } -/// Returns the element types of this sparse tensor type. -llvm::ArrayRef SparseTensorType::getElementTypes() +bool SparseTensorType::hasRank() const { - /// 'getImpl' returns a pointer to the internal storage instance. 
- return getImpl()->elementTypes; + return true; } +llvm::ArrayRef SparseTensorType::getShape() const +{ + return getDims(); +} + +//===----------------------------------------------------------------------===// +/// TableGen'd type definitions +//===----------------------------------------------------------------------===// +#define GET_TYPEDEF_CLASSES +#include "comet/Dialect/TensorAlgebra/IR/TATypes.cpp.inc" + +//===----------------------------------------------------------------------===// +/// TableGen'd enum definitions +//===----------------------------------------------------------------------===// +#include "comet/Dialect/TensorAlgebra/IR/TAEnums.cpp.inc" + //===----------------------------------------------------------------------===// /// TableGen'd op method definitions //===----------------------------------------------------------------------===// @@ -441,9 +314,18 @@ llvm::ArrayRef SparseTensorType::getElementTypes() /// the point of registration of types and operations for the dialect. 
void TADialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "comet/Dialect/TensorAlgebra/IR/TATypes.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "comet/Dialect/TensorAlgebra/IR/TAAttrs.cpp.inc" + >(); + addOperations< #define GET_OP_LIST #include "comet/Dialect/TensorAlgebra/IR/TAOps.cpp.inc" >(); - addTypes(); } diff --git a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp index b97333e5..b491dbcf 100644 --- a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp +++ b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp @@ -61,231 +61,6 @@ using namespace mlir::indexTree; //===----------------------------------------------------------------------===// namespace { - void mixModeEltWiseMultSparseTensorOutputLowering(Value computeOp, Location loc, - std::vector> rshPerms, - std::vector &dimSizes, - std::vector &tensorload_sizes_vec, - std::vector &array_sizes_vec, - PatternRewriter &rewriter) - { - - IndexType indexType = IndexType::get(computeOp.getContext()); - FloatType f64Type = FloatType::getF64(computeOp.getContext()); - auto dynamicmemTy_1d_index = MemRefType::get({ShapedType::kDynamic}, indexType); /// memref - auto dynamicmemTy_1d_f64 = MemRefType::get({ShapedType::kDynamic}, f64Type); /// memref - - comet_debug() << "mixModeEltWiseMultSparseTensorOutputLowering computeOp\n"; - comet_vdump(computeOp); - - /// elementwise mul op in mix sparse dense case - /// If elementwise, copy sparse input arrays for elementwise mul - int sparse_inputtensor_id = -1; - auto rhsComputeOp = computeOp.getDefiningOp()->getOperand(0).getDefiningOp(); - - auto first_operand = rhsComputeOp->getOperand(0).getDefiningOp(); - auto second_operand = rhsComputeOp->getOperand(1).getDefiningOp(); - comet_debug() << "EltWiseMult Operands:\n"; - comet_pdump(first_operand); - comet_pdump(second_operand); - - if (isa(first_operand)) - { - sparse_inputtensor_id = 
0; - } - else if (isa(second_operand)) - { - sparse_inputtensor_id = 1; - } - else - { - llvm::errs() << "ERROR: SparseTensorConstructOp was not found as one of the operands for itCompute\n"; - } - - comet_debug() << " SparseTensorConstructOp for computeOp: \n"; - comet_pdump(rhsComputeOp->getOperand(sparse_inputtensor_id).getDefiningOp()); - auto sptensor_construct_op = cast(rhsComputeOp->getOperand(sparse_inputtensor_id).getDefiningOp()); - - for (unsigned int i = 0; i < 4 * (rshPerms[sparse_inputtensor_id].size()) + 1; i++) - { - comet_debug() << " in for loop\n"; - Value intput_tensorload_op = cast(sptensor_construct_op.getOperand(i).getDefiningOp()); - Value input_alloc_op = cast(intput_tensorload_op.getDefiningOp()->getOperand(0).getDefiningOp()); - comet_debug() << " AllocOp: "; - comet_vdump(input_alloc_op); - - comet_debug() << " "; - Value input_alloc_op_param = input_alloc_op.getDefiningOp()->getOperand(0); - comet_debug() << " "; - - Value output_alloc_op; - if (i < 4 * (rshPerms[sparse_inputtensor_id].size())) - { - /// Memory allocation for position and coordinate arrays in sparse tensor contractions - output_alloc_op = insertAllocAndInitialize(loc, dynamicmemTy_1d_index, ValueRange{input_alloc_op_param}, rewriter); - } - else - { - /// Memory allocation for value array in sparse tensor contractions - output_alloc_op = insertAllocAndInitialize(loc, dynamicmemTy_1d_f64, ValueRange{input_alloc_op_param}, rewriter); /// Cval array - comet_debug() << " AllocOp: "; - comet_vdump(output_alloc_op); - } - - Value output_tensorload_op = rewriter.create(loc, output_alloc_op); - tensorload_sizes_vec.push_back(output_tensorload_op); - } - comet_debug() << " "; - - /// [0...2d, 2d+1...4d+1, 4d+2...5d+1] - for (unsigned int i = 0; i < 4 * (rshPerms[sparse_inputtensor_id].size()) + 1; i++) - { - int sizes_i = i + 4 * (rshPerms[sparse_inputtensor_id].size()) + 1; - comet_debug() << " "; - comet_pdump(sptensor_construct_op.getOperand(sizes_i).getDefiningOp()); - - 
Value input_load_op = sptensor_construct_op.getOperand(sizes_i); - comet_debug() << "Ops push_back for Sparse Tensor Construct Op for MixedMode elementwise multiplication (array_sizes_vec):\n"; - comet_vdump(input_load_op); - array_sizes_vec.push_back(input_load_op); - } - - for (unsigned int i = 0; i < rshPerms[sparse_inputtensor_id].size(); i++) - { - int sizes_i = i + 2 * (2 * (rshPerms[sparse_inputtensor_id].size()) + 1); - comet_debug() << " "; - comet_pdump(sptensor_construct_op.getOperand(sizes_i).getDefiningOp()); - - Value input_load_op = sptensor_construct_op.getOperand(sizes_i); - comet_debug() << "Ops push_back for Sparse Tensor Construct Op for MixedMode elementwise multiplication (dimSizes):\n"; - comet_vdump(input_load_op); - dimSizes.push_back(input_load_op); - } - } - - template - void pureSparseMultSparseTensorOutputLowering(T op, - Location loc, - std::string sparseOutputFormat, - std::vector &dimSizes, - std::vector &tensorload_sizes_vec, - std::vector &array_sizes_vec, - PatternRewriter &rewriter) - { - comet_debug() << " sparse output is used in itComputeOp op\n"; - comet_debug() << " sparseOutputFormat: " << sparseOutputFormat << "\n"; - - comet_vdump(op); - - IndexType indexType = IndexType::get(op.getContext()); - FloatType f64Type = FloatType::getF64(op.getContext()); - auto dynamicmemTy_1d_index = MemRefType::get({ShapedType::kDynamic}, indexType); /// memref - auto dynamicmemTy_1d_f64 = MemRefType::get({ShapedType::kDynamic}, f64Type); /// memref - - Value cst_index_0 = rewriter.create(loc, IndexType::get(op.getContext()), rewriter.getIndexAttr(0)); - comet_vdump(cst_index_0); - Value cst_index_1 = rewriter.create(loc, IndexType::get(op.getContext()), rewriter.getIndexAttr(1)); - comet_vdump(cst_index_1); - - unsigned int tensor_rank = op.getOperation()->getNumOperands(); - - std::vector array_sizes; - std::vector array_sizes_alloc_vec; - std::vector initial_array_sizes; - - if (sparseOutputFormat.compare("CSR") == 0) - { /// CSR format 
- comet_debug() << " 2D CSR format in sparse output decl op\n"; - /// AllocOp, storeOp, LoadOp - initial_array_sizes.push_back(cst_index_1); - initial_array_sizes.push_back(cst_index_1); - - /// A1tile - initial_array_sizes.push_back(cst_index_0); - initial_array_sizes.push_back(cst_index_0); - - /// The other three size information size.. - /// get the dimension size from operand - /// std::vector dim_sizes; - for (unsigned int i = 0; i < op.getOperation()->getNumOperands(); i++) - { - if (isa(op.getOperation()->getOperand(i).getDefiningOp())) - { - Value indexlabelop = dyn_cast(op.getOperation()->getOperand(i).getDefiningOp()); - dimSizes.push_back(indexlabelop.getDefiningOp()->getOperand(1)); - } - } - /// The dim size is the second parameter of the - Value dim2_posSize = rewriter.create(loc, dimSizes[0], cst_index_1); - comet_debug() << "AddIOp generated for dim2_posSize:\n"; - comet_vdump(dim2_posSize); - initial_array_sizes.push_back(dim2_posSize); - - Value dim2_crdSize = rewriter.create(loc, dimSizes[0], dimSizes[1]); - initial_array_sizes.push_back(dim2_crdSize); - - /// A2tile - initial_array_sizes.push_back(cst_index_0); - initial_array_sizes.push_back(cst_index_0); - - /// Aval - initial_array_sizes.push_back(dim2_crdSize); - comet_debug() << " "; - comet_vdump(dim2_crdSize); - } - else - { - llvm::errs() << __FILE__ << ":" << __LINE__ << "Not supported format\n"; - } - - /// same with transpose case - comet_debug() << " initial_array_sizes.size(): " << initial_array_sizes.size() << "\n"; - comet_debug() << " tensor_rank: " << tensor_rank << "\n"; - std::vector array_alloc_vec; - for (unsigned int i = 0; i < 4 * tensor_rank + 1; i++) - { - Value alloc_sizes; - if (i < 4 * tensor_rank) - { - comet_debug() << " Inserting AllocOp: "; - alloc_sizes = insertAllocAndInitialize(loc, dynamicmemTy_1d_index, ValueRange{initial_array_sizes[i]}, rewriter); - comet_debug() << " AllocOp: "; - comet_vdump(alloc_sizes); - } - else - { - alloc_sizes = 
insertAllocAndInitialize(loc, dynamicmemTy_1d_f64, ValueRange{initial_array_sizes[i]}, rewriter); - comet_debug() << " AllocOp: "; - comet_vdump(alloc_sizes); - } - Value tensorload_sizes = rewriter.create(loc, alloc_sizes); - tensorload_sizes_vec.push_back(tensorload_sizes); - array_alloc_vec.push_back(alloc_sizes); - } - - /// Initialize the sizes of pos/crd/val arrays - array_sizes.push_back(cst_index_1); /// A1pos_size - array_sizes.push_back(cst_index_1); /// A1crd_size - array_sizes.push_back(cst_index_1); /// A1tile_pos_size - array_sizes.push_back(cst_index_0); /// A1tile_crd_size - array_sizes.push_back(cst_index_1); /// A2pos_size - array_sizes.push_back(cst_index_0); /// A2crd_size - array_sizes.push_back(cst_index_0); /// A2tile_pos_size - array_sizes.push_back(cst_index_1); /// A2tile_crd_size - array_sizes.push_back(cst_index_0); /// Aval_size - /// put the array sizes into alloc/store/loadOp - for (auto size : array_sizes) - { - MemRefType memTy_alloc_sizes = MemRefType::get({1}, IndexType::get(op.getContext())); - Value allocop = rewriter.create(loc, memTy_alloc_sizes); - rewriter.create(loc, size, allocop, ValueRange{cst_index_0}); - Value loadop = rewriter.create(loc, allocop, ValueRange{cst_index_0}); - array_sizes_vec.push_back(loadop); - array_sizes_alloc_vec.push_back(allocop); - } - - rewriter.create(loc, dimSizes[0], array_alloc_vec[0], ValueRange{cst_index_0}); - } - void insertReadFileLibCall(int rank_size, MLIRContext *ctx, ModuleOp &module, func::FuncOp function) { comet_debug() << "Inserting insertReadFileLibCall\n"; @@ -457,6 +232,97 @@ namespace } } +Value insertSparseTensorDeclOp(PatternRewriter & rewriter, + MLIRContext* ctx, + Location loc, + unsigned rank_size, + std::vector& tensorload_sizes_vec, + std::vector& array_sizes_vec, + std::vector>& allPerms, + std::vector& dimSizes, + std::string formats_str, + Type ty) + { + comet_debug() << " Get users after "; + /// create sparse tensor construct after lowering each sparse tensor 
output users + comet_debug() << " tensorload_sizes_vec.size(): " << tensorload_sizes_vec.size() << ", rank_size: " << rank_size << "\n"; + /// create sptensor_construct + + std::vector dim_formats = mlir::tensorAlgebra::getFormats(formats_str, rank_size, ctx); + + Value sptensor; + if (rank_size == 2) + { + sptensor = rewriter.create(loc, ty, + ValueRange{ + tensorload_sizes_vec[0], /// A1pos (each dimension consists of pos and crd arrays) + tensorload_sizes_vec[1], /// A1crd + tensorload_sizes_vec[2], /// A1tile_pos + tensorload_sizes_vec[3], /// A1tile_crd + tensorload_sizes_vec[4], /// A2pos + tensorload_sizes_vec[5], /// A2crd + tensorload_sizes_vec[6], /// A2tile_pos + tensorload_sizes_vec[7], /// A2tile_crd + tensorload_sizes_vec[8], /// Aval + array_sizes_vec[0], /// A1pos_size (size of each pos and crd arrays) + array_sizes_vec[1], /// A1crd_size + array_sizes_vec[2], /// A1tile_pos_size + array_sizes_vec[3], /// A1tile_crd_size + array_sizes_vec[4], /// A2pos_size + array_sizes_vec[5], /// A2crd_size + array_sizes_vec[6], /// A2tile_pos_size + array_sizes_vec[7], /// A2tile_crd_size + array_sizes_vec[8], /// Aval_size (size of value array) + dimSizes[0], /// dim1_size(size of each dimension in sparse tensor) + dimSizes[1] /// dim2_size (size of each dimension in sparse tensor) + }, 2, rewriter.getI32ArrayAttr(dim_formats)); + } + else if (rank_size == 3) + { + sptensor = rewriter.create(loc, ty, + ValueRange{ + tensorload_sizes_vec[0], /// A1pos (each dimension consists of pos and crd arrays) + tensorload_sizes_vec[1], /// A1crd + tensorload_sizes_vec[2], /// A1tile_pos + tensorload_sizes_vec[3], /// A1tile_crd + tensorload_sizes_vec[4], /// A2pos + tensorload_sizes_vec[5], /// A2crd + tensorload_sizes_vec[6], /// A2tile_pos + tensorload_sizes_vec[7], /// A2tile_crd + tensorload_sizes_vec[8], /// A3pos + tensorload_sizes_vec[9], /// A3crd + tensorload_sizes_vec[10], /// A3tile_pos + tensorload_sizes_vec[11], /// A3tile_crd + tensorload_sizes_vec[12], /// 
Aval + array_sizes_vec[0], /// A1pos_size (size of each pos and crd arrays) + array_sizes_vec[1], /// A1crd_size + array_sizes_vec[2], /// A1tile_pos_size + array_sizes_vec[3], /// A1tile_crd_size + array_sizes_vec[4], /// A2pos_size + array_sizes_vec[5], /// A2crd_size + array_sizes_vec[6], /// A2tile_pos_size + array_sizes_vec[7], /// A2tile_crd_size + array_sizes_vec[8], /// A3pos_size + array_sizes_vec[9], /// A3crd_size + array_sizes_vec[10], /// A3tile_pos_size + array_sizes_vec[11], /// A3tile_crd_size + array_sizes_vec[12], /// Aval_size (size of value array) + dimSizes[0], /// dim1_size (size of each dimension in sparse tensor) + dimSizes[1], /// dim2_size (size of each dimension in sparse tensor) + dimSizes[2] /// dim3_size + }, 3, rewriter.getI32ArrayAttr(dim_formats)); + } + else + { + llvm::errs() << __FILE__ << ":" << __LINE__ << "ERROR: Not supported format (Tensors of dimensions greater than 3 are currently not supported).\n"; + } + + comet_debug() << "SparseTensorConstructOp generated for sparse output tensor:\n"; + comet_vdump(sptensor); + + return sptensor; + } + /// This a common lowering function used to lower SparseOutputTensorDeclOp and TempSparseOutputTensorDeclOp template void lowerSparseOutputTensorDec(T op, PatternRewriter &rewriter) @@ -493,6 +359,8 @@ namespace comet_debug() << " " << formats_str << " isDense: " << isDense(formats_str, ", ") << "\n"; + Value new_tensor; + /// sparse output if (isDense(formats_str, ", ") == false) { @@ -744,264 +612,46 @@ namespace comet_debug() << " AllocOp: "; comet_vdump(alloc_sizes); } - Value tensorload_sizes = rewriter.create(loc, alloc_sizes); + Value tensorload_sizes = rewriter.create(loc, alloc_sizes, rewriter.getUnitAttr(), rewriter.getUnitAttr()); tensorload_sizes_vec.push_back(tensorload_sizes); } + new_tensor = insertSparseTensorDeclOp(rewriter, op.getContext(), loc, rank_size, tensorload_sizes_vec, array_sizes_vec, allPerms, dimSizes, formats_str, op.getResult().getType()); + break; } - 
else if (isa(u)) + else if (isa(u)) { - comet_debug() << " sparse output is used in itComputeOp op\n"; - - /// Set the insertion point before its user - rewriter.setInsertionPoint(u); - - indexTree::IndexTreeComputeLHSOp lhsOp = cast(u); - comet_debug() << " formats_str: " << formats_str << "\n"; - comet_debug() << " current Op: "; - comet_vdump(lhsOp); - - for (auto uLHS : lhsOp.getOperation()->getUsers()) - { - assert(isa(uLHS) && "User of IndexTreeComputeLHSOp can only be IndexTreeComputeOp"); - - comet_debug() << " lhsOp user: "; - comet_pdump(uLHS); - - auto computeOp = cast(uLHS); - comet_debug() << " Get RHS op: "; - comet_vdump(computeOp); - - std::vector> rhsPerms; - getRHSPermsOfComputeOp(computeOp, rhsPerms); - - std::vector> rhsFormats; - getRHSFormatsOfComputeOp(computeOp, rhsFormats); - - comet_debug() << " rhsPerms: \n"; - for (auto m : rhsPerms) - { - comet_debug() << " \n"; - for (auto n : m) - { - comet_debug() << n << " \n"; - } - comet_debug() << "\n"; - } - - comet_debug() << " rhsFormats: \n"; - for (auto m : rhsFormats) - { - comet_debug() << " \n"; - for (auto n : m) - { - comet_debug() << n << " \n"; - } - comet_debug() << "\n"; - } - - bool isElementwise = checkIsElementwise(rhsPerms); - - comet_debug() << "Checking if it is mixed mode\n"; - bool isMixedMode = checkIsMixedMode(rhsFormats); - - comet_debug() << "IsElementWise: " << isElementwise << " isMixedMode: " << isMixedMode << "\n"; - if (isElementwise && isMixedMode) - { - comet_debug() << "It is an elementwise multiplication in mixed Mode sparse = sparse * dense\n"; - if (isMixedMode) - { - comet_debug() << "It is an mix-mode elementwise multiplication in Mix Mode\n"; - mixModeEltWiseMultSparseTensorOutputLowering(computeOp, - loc, - rhsPerms, - dimSizes, - tensorload_sizes_vec, - array_sizes_vec, rewriter); - } - else - { - comet_debug() << "It is an pure-sparse elementwise multiplication\n"; - pureSparseMultSparseTensorOutputLowering<>(op, - loc, - formats_str, - dimSizes, - 
tensorload_sizes_vec, - array_sizes_vec, - rewriter); - } - } - else - { - if (!isMixedMode) - { - comet_debug() << "It is an pure-sparse multiplication or assigment from dense to sparse (produced after workspace transformations)\n"; - pureSparseMultSparseTensorOutputLowering(op, - loc, - formats_str, - dimSizes, - tensorload_sizes_vec, - array_sizes_vec, - rewriter); - } - else - { - /// TODO(gkestor) Mix-mode sparse computation with sparse output not yet supported such as TTM (tensor times matrix) - /// TODO(gkestor): if the sparsity patterns is known - comet_debug() << "It is an mix mode element-wise multiplication\n"; - mixModeEltWiseMultSparseTensorOutputLowering(computeOp, - loc, - rhsPerms, - dimSizes, - tensorload_sizes_vec, - array_sizes_vec, rewriter); - } - } - } - } - else if (isa(u)) - { - comet_debug() << " Sparse output is used in TensorFillFromFileOp\n"; - auto fillfromfileop = cast(u); - /// Can get filename, from "filename" attribute of fillfromfileop - rewriter.eraseOp(fillfromfileop); - } - else if (isa(u)) - { - comet_debug() << "The tensor is in IndexTreeComputeRHSOp, no action taken\n"; - continue; - } - else if (isa(u)) - { - comet_debug() << "The tensor is in print op, no action taken\n"; - continue; - } - else if (isa(u)) - { - comet_debug() << "The tensor is in sum op, no action taken\n"; - continue; - } - else if (isa(u)) - { - /// TODO(gkestor): LabeledTensorOp is not used in the current design, needs cleaning up. - /// Look at the generated code. 
We should not generate LabeledTensorOp - continue; - } - else - { - comet_pdump(u); - llvm::errs() << __FILE__ << __LINE__ << " tensor is used in the following unsupported op\n"; - } - - comet_debug() << " Get users after "; - /// create sparse tensor construct after lowering each sparse tensor output users - comet_debug() << " tensorload_sizes_vec.size(): " << tensorload_sizes_vec.size() << ", rank_size: " << rank_size << "\n"; - /// create sptensor_construct - SmallVector elementTypes; - for (unsigned int i = 0; i < 4 * rank_size + 1; i++) - { - assert(tensorload_sizes_vec.size() > 0 && "ERROR: Please report this error to the developers!"); - comet_debug() << " " << i << " "; - comet_vdump(tensorload_sizes_vec[i]); - elementTypes.push_back(tensorload_sizes_vec[i].getType()); - } - comet_debug() << "\n "; - /// [0 ... 2*rank_size, 2*rank_size+1 ... 4*rank_size+1, 4*rank_size+2 ... 5*rank_size + 1] - /// 2d+1 + 2d+1 + d => 5d+2 - for (unsigned int i = 0; i < 4 * rank_size + 1; i++) - { - assert(array_sizes_vec.size() > 0 && "ERROR: Please report this error to the developers!"); - comet_debug() << " " << i << " "; - comet_vdump(array_sizes_vec[i]); - elementTypes.push_back(array_sizes_vec[i].getType()); - } - comet_debug() << "\n "; - for (unsigned int i = 0; i < rank_size; i++) - { - assert(dimSizes.size() > 0 && "ERROR: Please report this error to the developers!"); - elementTypes.push_back(dimSizes[i].getType()); - } - comet_debug() << "\n "; - - auto ty = tensorAlgebra::SparseTensorType::get(elementTypes); - - Value sptensor; - if (rank_size == 2) - { - sptensor = rewriter.create(loc, ty, - ValueRange{ - tensorload_sizes_vec[0], /// A1pos (each dimension consists of pos and crd arrays) - tensorload_sizes_vec[1], /// A1crd - tensorload_sizes_vec[2], /// A1tile_pos - tensorload_sizes_vec[3], /// A1tile_crd - tensorload_sizes_vec[4], /// A2pos - tensorload_sizes_vec[5], /// A2crd - tensorload_sizes_vec[6], /// A2tile_pos - tensorload_sizes_vec[7], /// A2tile_crd - 
tensorload_sizes_vec[8], /// Aval - array_sizes_vec[0], /// A1pos_size (size of each pos and crd arrays) - array_sizes_vec[1], /// A1crd_size - array_sizes_vec[2], /// A1tile_pos_size - array_sizes_vec[3], /// A1tile_crd_size - array_sizes_vec[4], /// A2pos_size - array_sizes_vec[5], /// A2crd_size - array_sizes_vec[6], /// A2tile_pos_size - array_sizes_vec[7], /// A2tile_crd_size - array_sizes_vec[8], /// Aval_size (size of value array) - dimSizes[0], /// dim1_size(size of each dimension in sparse tensor) - dimSizes[1] /// dim2_size (size of each dimension in sparse tensor) - }, - 2); - } - else if (rank_size == 3) - { - sptensor = rewriter.create(loc, ty, - ValueRange{ - tensorload_sizes_vec[0], /// A1pos (each dimension consists of pos and crd arrays) - tensorload_sizes_vec[1], /// A1crd - tensorload_sizes_vec[2], /// A1tile_pos - tensorload_sizes_vec[3], /// A1tile_crd - tensorload_sizes_vec[4], /// A2pos - tensorload_sizes_vec[5], /// A2crd - tensorload_sizes_vec[6], /// A2tile_pos - tensorload_sizes_vec[7], /// A2tile_crd - tensorload_sizes_vec[8], /// A3pos - tensorload_sizes_vec[9], /// A3crd - tensorload_sizes_vec[10], /// A3tile_pos - tensorload_sizes_vec[11], /// A3tile_crd - tensorload_sizes_vec[12], /// Aval - array_sizes_vec[0], /// A1pos_size (size of each pos and crd arrays) - array_sizes_vec[1], /// A1crd_size - array_sizes_vec[2], /// A1tile_pos_size - array_sizes_vec[3], /// A1tile_crd_size - array_sizes_vec[4], /// A2pos_size - array_sizes_vec[5], /// A2crd_size - array_sizes_vec[6], /// A2tile_pos_size - array_sizes_vec[7], /// A2tile_crd_size - array_sizes_vec[8], /// A3pos_size - array_sizes_vec[9], /// A3crd_size - array_sizes_vec[10], /// A3tile_pos_size - array_sizes_vec[11], /// A3tile_crd_size - array_sizes_vec[12], /// Aval_size (size of value array) - dimSizes[0], /// dim1_size (size of each dimension in sparse tensor) - dimSizes[1], /// dim2_size (size of each dimension in sparse tensor) - dimSizes[2] /// dim3_size - }, - 3); - } - 
else - { - llvm::errs() << __FILE__ << ":" << __LINE__ << "ERROR: Not supported format (Tensors of dimensions greater than 3 are currently not supported).\n"; + comet_debug() << " Sparse output is used in it.LHSOperandOp\n"; + // Tensor is created as the output of a sparse tensor operation + // For now we defer to the index tree dialect by inserting a tensor decl + // that just contains empty domains. + auto lhs_op = llvm::cast(u); + rank_size = lhs_op.getCrds().size(); + indexTree::DomainType domain_type = indexTree::DomainType::get(op.getContext()); + rewriter.setInsertionPoint(op); + Value empty_domain = rewriter.create(loc, domain_type); + llvm::SmallVector args = llvm::SmallVector(rank_size, empty_domain); + + new_tensor = rewriter.create(loc, op.getResult().getType(), args); + + + // Eventually, there are 2 cases: + // Case 1: We can determine apriori the dimension of the sparse tensor + // This is the case if none of the index variables in the output + // tensor are used in a union or a insersect op. In this case we use + // the sparse tensor decleration of the input in order to determine + // the output tensor. We allocate arrays of the same size and then + // insert a ta.SpTensorDeclOp. + // Case 2: We can't determine the dimension of the sparse tensor. + // This happens in all other cases. Here we insert a tensor + // that is defined with an (at least one) empty domain. In + // the lowering process we can either use the symbolic phase + // to determine the allocations needed, or we can perform the + // allocations during the computational phase + break; } - - comet_debug() << "SparseTensorConstructOp generated for sparse output tensor:\n"; - comet_vdump(sptensor); - - /// create ta.index_label operation. 
- comet_vdump(op); - - op.replaceAllUsesWith(sptensor); - rewriter.replaceOp(op, sptensor); - } /// for (auto u : op.getOperation()->getUsers()) + } + op.replaceAllUsesWith(new_tensor); + rewriter.replaceOp(op, {new_tensor}); } else { /// format == "Dense" @@ -1032,7 +682,7 @@ namespace comet_debug() << " AllocOp: "; comet_vdump(alloc_sizes1); - Value tensorLoad = rewriter.create(loc, alloc_sizes1); + Value tensorLoad = rewriter.create(loc, alloc_sizes1, rewriter.getUnitAttr(), rewriter.getUnitAttr()); comet_vdump(tensorLoad); op.replaceAllUsesWith(tensorLoad); @@ -1108,7 +758,7 @@ namespace cast(init_alloc.getDefiningOp()).setAlignmentAttr(rewriter.getI64IntegerAttr(32)); - Value tensorLoad = rewriter.create(loc, init_alloc); + Value tensorLoad = rewriter.create(loc, init_alloc, rewriter.getUnitAttr(), rewriter.getUnitAttr()); comet_debug() << " TensorLoad:\n"; comet_vdump(tensorLoad); @@ -1237,15 +887,22 @@ namespace } } } - else if (isa(u1)) + else if (isa(u1)) { - comet_debug() << " used in ta.itComputeRHS op\n"; - isOutputTensor = false; + comet_debug() << " used in it.LHSOperand op\n"; + isOutputTensor = true; } - else if (isa(u1)) + else if (isa(u1)) { - comet_debug() << " used in ta.itComputeLHS op\n"; - isOutputTensor = true; + comet_debug() << " used in it.Operand op\n"; + } + else if (isa(u1)) + { + comet_debug() << " used in it.TensorAccess op\n"; + } + else if (isa(u1)) + { + comet_debug() << " used in it.Domain op\n"; } else if (isa(u1)) { @@ -1274,6 +931,11 @@ namespace /// do nothing! comet_debug() << " the tensor has use in LabeledTensorOp and this use will be ignored!\n"; } + else if (isa(u1)) + { + /// do nothing! 
+ comet_debug() << " the tensor has use in AllocWorkspaceOp\n"; + } else { u1->dump(); @@ -1370,6 +1032,8 @@ namespace Value alloc_sizes_cast = rewriter.create(loc, unrankedMemTy_index, alloc_sizes); std::vector dim_format = mlir::tensorAlgebra::getFormatsValue(formats_str, rank_size, rewriter, loc, indexType); + std::vector dim_format_int = mlir::tensorAlgebra::getFormats(formats_str, rank_size, ctx); + auto dim_format_attrs = rewriter.getI32ArrayAttr(dim_format_int); comet_debug() << " Get the dim_format\n"; /// inform the runtime of what env var to use for parsing input file @@ -1554,22 +1218,12 @@ namespace std::vector alloc_tensor_vec; for (unsigned int i = 0; i < sp_decl.getTotalArrayCount(); i++) { - Value tensorLoad = rewriter.create(loc, alloc_sizes_vec[i]); + Value tensorLoad = rewriter.create(loc, alloc_sizes_vec[i], rewriter.getUnitAttr(), rewriter.getUnitAttr()); alloc_tensor_vec.push_back(tensorLoad); } - /// create sptensor_construct - SmallVector elementTypes; - for (unsigned int i = 0; i < sp_decl.getTotalArrayCount(); i++) - { - elementTypes.push_back(alloc_tensor_vec[i].getType()); - } - for (unsigned int i = 0; i < 5 * rank_size + 1; i++) - { - elementTypes.push_back(array_sizes[i].getType()); - } - - auto ty = tensorAlgebra::SparseTensorType::get(elementTypes); + llvm::SmallVector dim_sizes(rank_size, ShapedType::kDynamic); // TODO: Determine sizes!!!! 
+ auto ty = op.getResult().getType(); Value sptensor; if (rank_size == 2) @@ -1579,7 +1233,7 @@ namespace alloc_tensor_vec[4], alloc_tensor_vec[5], /// A2 alloc_tensor_vec[6], alloc_tensor_vec[7], /// A2_tile alloc_tensor_vec[8], array_sizes[0], array_sizes[1], array_sizes[2], array_sizes[3], array_sizes[4], array_sizes[5], array_sizes[6], array_sizes[7], array_sizes[8], array_sizes[9], array_sizes[10]}, - 2); + 2, dim_format_attrs); } else if (rank_size == 3) { @@ -1590,7 +1244,7 @@ namespace alloc_tensor_vec[8], alloc_tensor_vec[9], /// A3 alloc_tensor_vec[10], alloc_tensor_vec[11], /// A3_tile alloc_tensor_vec[12], array_sizes[0], array_sizes[1], array_sizes[2], array_sizes[3], array_sizes[4], array_sizes[5], array_sizes[6], array_sizes[7], array_sizes[8], array_sizes[9], array_sizes[10], array_sizes[11], array_sizes[12], array_sizes[13], array_sizes[14], array_sizes[15], array_sizes[16], array_sizes[17], array_sizes[18]}, - 3); + 3, dim_format_attrs); } else { @@ -1837,6 +1491,8 @@ namespace tensorAlgebra::SparseOutputTensorDeclOp, tensorAlgebra::IndexLabelStaticOp, tensorAlgebra::IndexLabelDynamicOp, + tensorAlgebra::AllocWorkspaceOp, + tensorAlgebra::TensorMultOp, // Should this be dynamically legal to only work with dense tensors? func::CallOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) @@ -1885,6 +1541,8 @@ namespace tensorAlgebra::DenseTensorDeclOp, tensorAlgebra::IndexLabelStaticOp, tensorAlgebra::IndexLabelDynamicOp, + tensorAlgebra::AllocWorkspaceOp, + tensorAlgebra::TensorMultOp, // Should this be dynamically legal to only work with dense tensors? 
func::CallOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) diff --git a/lib/Dialect/Utils/Utils.cpp b/lib/Dialect/Utils/Utils.cpp index 13a96f6b..75104e39 100644 --- a/lib/Dialect/Utils/Utils.cpp +++ b/lib/Dialect/Utils/Utils.cpp @@ -483,14 +483,14 @@ namespace mlir } /// for string delimiter - std::vector stringSplit(std::string s, std::string delimiter) + std::vector stringSplit(llvm::StringRef s, llvm::StringRef delimiter) { /// comet_debug() << "split formats string: " << s << ", deli: "<< delimiter << ".\n"; - std::vector res; + std::vector res; std::string format = ""; - for (unsigned int i = 0; i < s.length(); i++) + for (unsigned int i = 0; i < s.size(); i++) { comet_debug() << "s[" << i << "]: " << s[i] << "\n"; if (s[i] != delimiter[0] && s[i] != delimiter[1]) @@ -509,13 +509,13 @@ namespace mlir res.push_back(format); comet_debug() << "The final format: "; - print_vector(res); + print_vector(res); return res; } - std::vector> getAllFormats(ArrayAttr opFormatsArrayAttr, std::vector> allPerms) + std::vector> getAllFormats(ArrayAttr opFormatsArrayAttr, std::vector> allPerms) { - std::vector> allFormats(allPerms.size()); + std::vector> allFormats(allPerms.size()); /// format with each input matrix: ["CSR", "D", "D"] SpMM for (unsigned int i = 0; i < opFormatsArrayAttr.size(); i++) { @@ -814,7 +814,7 @@ namespace mlir return false; } - std::vector getFormatsValue(std::string formats_str, int rank_size, PatternRewriter &rewriter, Location loc, IndexType indexType) + std::vector getFormatsValue(llvm::StringRef formats_str, int rank_size, PatternRewriter &rewriter, Location loc, IndexType indexType) { Value format_unk = rewriter.create(loc, indexType, rewriter.getIndexAttr(-1)); Value format_dense = rewriter.create(loc, indexType, rewriter.getIndexAttr(0)); @@ -829,42 +829,42 @@ namespace mlir { /// 2D comet_debug() << " 2D\n"; /// Value dim0_format, dim1_format; - if (formats_str.compare(0, 3, "CSR") == 0) + if 
(formats_str.compare("CSR") == 0) { dim_format.push_back(format_dense); dim_format.push_back(format_unk); dim_format.push_back(format_compressed); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 4, "DCSR") == 0) + else if (formats_str.compare("DCSR") == 0) { dim_format.push_back(format_compressed); dim_format.push_back(format_unk); dim_format.push_back(format_compressed); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 3, "COO") == 0) + else if (formats_str.compare("COO") == 0) { /// COO dim_format.push_back(format_compressednonunique); dim_format.push_back(format_unk); dim_format.push_back(format_singleton); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 3, "ELL") == 0) + else if (formats_str.compare("ELL") == 0) { /// ELL dim_format.push_back(format_dense); dim_format.push_back(format_dense); dim_format.push_back(format_singleton); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 4, "BCSR") == 0) + else if (formats_str.compare("BCSR") == 0) { /// BCSR dim_format.push_back(format_dense); dim_format.push_back(format_compressednonunique); dim_format.push_back(format_dense); dim_format.push_back(format_dense); } - else if (formats_str.compare(0, 3, "CSB") == 0) + else if (formats_str.compare("CSB") == 0) { /// CSB dim_format.push_back(format_dense); dim_format.push_back(format_dense); @@ -873,22 +873,22 @@ namespace mlir } else if (formats_str.find("D") != std::string::npos || formats_str.find("CU") != std::string::npos || formats_str.find("CN") != std::string::npos || formats_str.find("S") != std::string::npos) { - std::vector format_vec = stringSplit(formats_str, ", "); + std::vector format_vec = stringSplit(formats_str, ", "); for (auto n : format_vec) { - if (n.compare(0, 1, "D") == 0) + if (n.compare("D") == 0) { dim_format.push_back(format_dense); } - else if (n.compare(0, 2, "CU") == 0) + else if (n.compare("CU") == 0) { dim_format.push_back(format_compressed); } - else if 
(n.compare(0, 2, "CN") == 0) + else if (n.compare("CN") == 0) { dim_format.push_back(format_compressednonunique); } - else if (n.compare(0, 1, "S") == 0) + else if (n.compare("S") == 0) { dim_format.push_back(format_singleton); } @@ -903,7 +903,7 @@ namespace mlir { /// 3D comet_debug() << " 3D\n"; /// Value dim0_format, dim1_format, dim2_format; - if (formats_str.compare(0, 3, "CSF") == 0) + if (formats_str.compare("CSF") == 0) { dim_format.push_back(format_compressed); dim_format.push_back(format_unk); @@ -912,7 +912,7 @@ namespace mlir dim_format.push_back(format_compressed); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 11, "ModeGeneric") == 0) + else if (formats_str.compare("ModeGeneric") == 0) { dim_format.push_back(format_compressednonunique); dim_format.push_back(format_unk); @@ -921,7 +921,7 @@ namespace mlir dim_format.push_back(format_dense); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 3, "COO") == 0) + else if (formats_str.compare("COO") == 0) { /// COO dim_format.push_back(format_compressednonunique); dim_format.push_back(format_unk); @@ -932,25 +932,25 @@ namespace mlir } else if (formats_str.find("D") != std::string::npos || formats_str.find("CU") != std::string::npos || formats_str.find("CN") != std::string::npos || formats_str.find("S") != std::string::npos) { - std::vector format_vec = stringSplit(formats_str, ", "); + std::vector format_vec = stringSplit(formats_str, ", "); comet_debug() << " format_vec.size(): " << format_vec.size() << " \n"; for (auto n : format_vec) { comet_debug() << "Current format attribute: " << n << "---\n"; - if (n.compare(0, 1, "D") == 0) + if (n.compare("D") == 0) { dim_format.push_back(format_dense); } - else if (n.compare(0, 2, "CU") == 0) + else if (n.compare("CU") == 0) { dim_format.push_back(format_compressed); } - else if (n.compare(0, 2, "CN") == 0) + else if (n.compare("CN") == 0) { dim_format.push_back(format_compressednonunique); } - else if (n.compare(0, 1, 
"S") == 0) + else if (n.compare("S") == 0) { dim_format.push_back(format_singleton); } @@ -977,7 +977,7 @@ namespace mlir return dim_format; } - std::vector getFormatsValueInt(std::string formats_str, int rank_size, PatternRewriter &rewriter, Location loc, IntegerType intType) + std::vector getFormatsValueInt(llvm::StringRef formats_str, int rank_size, PatternRewriter &rewriter, Location loc, IntegerType intType) { Value format_unk = rewriter.create(loc, intType, rewriter.getIntegerAttr(intType, -1)); Value format_dense = rewriter.create(loc, intType, rewriter.getIntegerAttr(intType, 0)); @@ -992,42 +992,42 @@ namespace mlir { /// 2D comet_debug() << " 2D\n"; /// Value dim0_format, dim1_format; - if (formats_str.compare(0, 3, "CSR") == 0) + if (formats_str.compare("CSR") == 0) { dim_format.push_back(format_dense); dim_format.push_back(format_unk); dim_format.push_back(format_compressed); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 4, "DCSR") == 0) + else if (formats_str.compare("DCSR") == 0) { dim_format.push_back(format_compressed); dim_format.push_back(format_unk); dim_format.push_back(format_compressed); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 3, "COO") == 0) + else if (formats_str.compare("COO") == 0) { /// COO dim_format.push_back(format_compressednonunique); dim_format.push_back(format_unk); dim_format.push_back(format_singleton); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 3, "ELL") == 0) + else if (formats_str.compare("ELL") == 0) { /// ELL dim_format.push_back(format_dense); dim_format.push_back(format_dense); dim_format.push_back(format_singleton); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 4, "BCSR") == 0) + else if (formats_str.compare("BCSR") == 0) { /// BCSR dim_format.push_back(format_dense); dim_format.push_back(format_compressed); dim_format.push_back(format_dense); dim_format.push_back(format_dense); } - else if (formats_str.compare(0, 3, 
"CSB") == 0) + else if (formats_str.compare("CSB") == 0) { /// CSB dim_format.push_back(format_dense); dim_format.push_back(format_dense); @@ -1036,22 +1036,22 @@ namespace mlir } else if (formats_str.find("D") != std::string::npos || formats_str.find("CU") != std::string::npos || formats_str.find("CN") != std::string::npos || formats_str.find("S") != std::string::npos) { - std::vector format_vec = stringSplit(formats_str, ", "); + std::vector format_vec = stringSplit(formats_str, ", "); for (auto n : format_vec) { - if (n.compare(0, 1, "D") == 0) + if (n.compare("D") == 0) { dim_format.push_back(format_dense); } - else if (n.compare(0, 2, "CU") == 0) + else if (n.compare("CU") == 0) { dim_format.push_back(format_compressed); } - else if (n.compare(0, 2, "CN") == 0) + else if (n.compare("CN") == 0) { dim_format.push_back(format_compressednonunique); } - else if (n.compare(0, 1, "S") == 0) + else if (n.compare("S") == 0) { dim_format.push_back(format_singleton); } @@ -1066,7 +1066,7 @@ namespace mlir { /// 3D comet_debug() << " 3D\n"; /// Value dim0_format, dim1_format, dim2_format; - if (formats_str.compare(0, 3, "CSF") == 0) + if (formats_str.compare("CSF") == 0) { dim_format.push_back(format_compressed); dim_format.push_back(format_unk); @@ -1075,7 +1075,7 @@ namespace mlir dim_format.push_back(format_compressed); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 11, "ModeGeneric") == 0) + else if (formats_str.compare("ModeGeneric") == 0) { dim_format.push_back(format_compressednonunique); dim_format.push_back(format_unk); @@ -1084,7 +1084,7 @@ namespace mlir dim_format.push_back(format_dense); dim_format.push_back(format_unk); } - else if (formats_str.compare(0, 3, "COO") == 0) + else if (formats_str.compare("COO") == 0) { /// COO dim_format.push_back(format_compressednonunique); dim_format.push_back(format_unk); @@ -1095,25 +1095,190 @@ namespace mlir } else if (formats_str.find("D") != std::string::npos || formats_str.find("CU") != 
std::string::npos || formats_str.find("CN") != std::string::npos || formats_str.find("S") != std::string::npos) { - std::vector format_vec = stringSplit(formats_str, ", "); + std::vector format_vec = stringSplit(formats_str, ", "); comet_debug() << " format_vec.size(): " << format_vec.size() << " \n"; /// print_vector(format_vec); for (auto n : format_vec) { comet_debug() << "Current format attribute: " << n << "---\n"; - if (n.compare(0, 1, "D") == 0) + if (n.compare("D") == 0) { dim_format.push_back(format_dense); } - else if (n.compare(0, 2, "CU") == 0) + else if (n.compare("CU") == 0) { dim_format.push_back(format_compressed); } - else if (n.compare(0, 2, "CN") == 0) + else if (n.compare("CN") == 0) { dim_format.push_back(format_compressednonunique); } - else if (n.compare(0, 1, "S") == 0) + else if (n.compare("S") == 0) + { + dim_format.push_back(format_singleton); + } + else + { + llvm::errs() << "Uncorrect format attribute: " << n << "---\n"; + } + comet_debug() << " dim_format.size(): " << dim_format.size() << " \n"; + } + comet_debug() << " formats_str: " << formats_str << ", dim_format.size(): " << dim_format.size() << " \n"; + } + } + else + { + llvm::errs() << "Unsupported formats: " << formats_str << " (tensor dimes: " << rank_size << ") \n"; + } + + comet_debug() << " print dim_format: "; + for (auto n : dim_format) + { + comet_debug() << n << " "; + } + comet_debug() << "\n"; + return dim_format; + } + + // TODO (alokvk2): Not good to have this replicated 3 times. Ideally this is only used for "special" formats (i.e. CSR, COO etc.) + // And this converts it to a vector of TAFormatAttrs. 
+ std::vector getFormats(llvm::StringRef formats_str, int rank_size, MLIRContext* ctx) + { + auto format_unk = (int32_t)TensorFormatEnum::UNK; + auto format_dense = (int32_t)TensorFormatEnum::D; + auto format_compressed = (int32_t)TensorFormatEnum::CU; + auto format_compressednonunique = (int32_t)TensorFormatEnum::CN; + auto format_singleton = (int32_t)TensorFormatEnum::S; + /// read_input_sizes_2D_f64 or read_input_sizes_3D_f64 + comet_debug() << "\n"; + std::vector dim_format; + + if (rank_size == 2) + { /// 2D + comet_debug() << " 2D\n"; + /// Value dim0_format, dim1_format; + if (formats_str.compare("CSR") == 0) + { + dim_format.push_back(format_dense); + dim_format.push_back(format_unk); + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + } + else if (formats_str.compare("DCSR") == 0) + { + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + } + else if (formats_str.compare("COO") == 0) + { /// COO + dim_format.push_back(format_compressednonunique); + dim_format.push_back(format_unk); + dim_format.push_back(format_singleton); + dim_format.push_back(format_unk); + } + else if (formats_str.compare("ELL") == 0) + { /// ELL + dim_format.push_back(format_dense); + dim_format.push_back(format_dense); + dim_format.push_back(format_singleton); + dim_format.push_back(format_unk); + } + else if (formats_str.compare("BCSR") == 0) + { /// BCSR + dim_format.push_back(format_dense); + dim_format.push_back(format_compressed); + dim_format.push_back(format_dense); + dim_format.push_back(format_dense); + } + else if (formats_str.compare("CSB") == 0) + { /// CSB + dim_format.push_back(format_dense); + dim_format.push_back(format_dense); + dim_format.push_back(format_compressed); + dim_format.push_back(format_singleton); + } + else if (formats_str.find("D") != std::string::npos || formats_str.find("CU") != std::string::npos || 
formats_str.find("CN") != std::string::npos || formats_str.find("S") != std::string::npos) + { + std::vector format_vec = stringSplit(formats_str, ", "); + for (auto n : format_vec) + { + if (n.compare("D") == 0) + { + dim_format.push_back(format_dense); + } + else if (n.compare("CU") == 0) + { + dim_format.push_back(format_compressed); + } + else if (n.compare("CN") == 0) + { + dim_format.push_back(format_compressednonunique); + } + else if (n.compare("S") == 0) + { + dim_format.push_back(format_singleton); + } + } + } + else + { + llvm::errs() << "Unsupported formats: " << formats_str << " (tensor dimes: " << rank_size << ") \n"; + } + } + else if (rank_size == 3) + { /// 3D + comet_debug() << " 3D\n"; + /// Value dim0_format, dim1_format, dim2_format; + if (formats_str.compare("CSF") == 0) + { + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + dim_format.push_back(format_compressed); + dim_format.push_back(format_unk); + } + else if (formats_str.compare("ModeGeneric") == 0) + { + dim_format.push_back(format_compressednonunique); + dim_format.push_back(format_unk); + dim_format.push_back(format_singleton); + dim_format.push_back(format_unk); + dim_format.push_back(format_dense); + dim_format.push_back(format_unk); + } + else if (formats_str.compare("COO") == 0) + { /// COO + dim_format.push_back(format_compressednonunique); + dim_format.push_back(format_unk); + dim_format.push_back(format_singleton); + dim_format.push_back(format_unk); + dim_format.push_back(format_singleton); + dim_format.push_back(format_unk); + } + else if (formats_str.find("D") != std::string::npos || formats_str.find("CU") != std::string::npos || formats_str.find("CN") != std::string::npos || formats_str.find("S") != std::string::npos) + { + std::vector format_vec = stringSplit(formats_str, ", "); + comet_debug() << " format_vec.size(): " << format_vec.size() << " \n"; + /// 
print_vector(format_vec); + for (auto n : format_vec) + { + comet_debug() << "Current format attribute: " << n << "---\n"; + if (n.compare("D") == 0) + { + dim_format.push_back(format_dense); + } + else if (n.compare("CU") == 0) + { + dim_format.push_back(format_compressed); + } + else if (n.compare("CN") == 0) + { + dim_format.push_back(format_compressednonunique); + } + else if (n.compare("S") == 0) { dim_format.push_back(format_singleton); } @@ -1176,33 +1341,34 @@ namespace mlir /// parent node can get from getUser() function, only one user since tree structure void dfsRootOpTree(Value tcRootOp, std::vector &ret) { - if (isa(tcRootOp.getDefiningOp())) - { - IndexTreeIndicesOp workspaceop = dyn_cast(tcRootOp.getDefiningOp()); - - comet_debug() << " dfsRootOpTree\n"; - comet_vdump(workspaceop); - - unsigned int sz = workspaceop.getChildren().size(); - - comet_debug() << " " << sz << " "; - ret.push_back(workspaceop); - comet_debug() << " "; - comet_vdump(workspaceop); - - for (unsigned int i = 0; i < sz; i++) - { - Value t = workspaceop.getChildren()[i]; - dfsRootOpTree(t, ret); - } - } - else if (isa(tcRootOp.getDefiningOp())) - { - indexTree::IndexTreeComputeOp leafop = dyn_cast(tcRootOp.getDefiningOp()); - /// comet_debug() << " dfsRootOpTree\n"; - comet_vdump(leafop); - ret.push_back(leafop); - } + return; + // if (isa(tcRootOp.getDefiningOp())) + // { + // IndexTreeIndicesOp workspaceop = dyn_cast(tcRootOp.getDefiningOp()); + + // comet_debug() << " dfsRootOpTree\n"; + // comet_vdump(workspaceop); + + // unsigned int sz = workspaceop.getChildren().size(); + + // comet_debug() << " " << sz << " "; + // ret.push_back(workspaceop); + // comet_debug() << " "; + // comet_vdump(workspaceop); + + // for (unsigned int i = 0; i < sz; i++) + // { + // Value t = workspaceop.getChildren()[i]; + // dfsRootOpTree(t, ret); + // } + // } + // else if (isa(tcRootOp.getDefiningOp())) + // { + // indexTree::IndexTreeComputeOp leafop = dyn_cast(tcRootOp.getDefiningOp()); + // 
/// comet_debug() << " dfsRootOpTree\n"; + // comet_vdump(leafop); + // ret.push_back(leafop); + // } } void getAncestorsWp(Value op, std::vector &ret /* output ancestors*/, std::vector &dfsOps) @@ -1368,99 +1534,105 @@ namespace mlir std::vector> &opPerms, std::vector> &inputOutputMapping) { - indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); - ArrayAttr opFormatsArrayAttr_rhs = itComputeOp_rhs.getAllFormats(); - ArrayAttr opPermsArrayAttr_rhs = itComputeOp_rhs.getAllPerms(); - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); - ArrayAttr opFormatsArrayAttr_lhs = itComputeOp_lhs.getAllFormats(); - ArrayAttr opPermsArrayAttr_lhs = itComputeOp_lhs.getAllPerms(); - assert(opFormatsArrayAttr_rhs.size() == opPermsArrayAttr_rhs.size() && "not equal RHS formats size with perms size\n"); - assert(opFormatsArrayAttr_lhs.size() == opPermsArrayAttr_lhs.size() && "not equal LHS formats size with perms size\n"); - - /// Get output format, vector of vector - /// Convert ArrayAttr into - comet_debug() << "Start printing opFormats_rhs\n"; - std::vector> opFormats_rhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_rhs); - comet_debug() << "End printing opFormats_rhs\n"; - std::vector> opPerms_rhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_rhs); - std::vector> inputMapping = createInputOutputMapping(opPermsArrayAttr_rhs, true); - - comet_debug() << "Start printing opFormats_lhs\n"; - std::vector> opFormats_lhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_lhs); - comet_debug() << "End printing opFormats_lhs\n"; - std::vector> opPerms_lhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_lhs); - std::vector> outputMapping = createInputOutputMapping(opPermsArrayAttr_lhs, false); - - opFormats = opFormats_rhs; - opFormats.insert(opFormats.end(), opFormats_lhs.begin(), opFormats_lhs.end()); - opPerms = opPerms_rhs; - 
opPerms.insert(opPerms.end(), opPerms_lhs.begin(), opPerms_lhs.end()); - inputOutputMapping = inputMapping; - inputOutputMapping.insert(inputOutputMapping.end(), outputMapping.begin(), outputMapping.end()); + return; + // indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); + // ArrayAttr opFormatsArrayAttr_rhs = itComputeOp_rhs.getAllFormats(); + // ArrayAttr opPermsArrayAttr_rhs = itComputeOp_rhs.getAllPerms(); + // indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); + // ArrayAttr opFormatsArrayAttr_lhs = itComputeOp_lhs.getAllFormats(); + // ArrayAttr opPermsArrayAttr_lhs = itComputeOp_lhs.getAllPerms(); + // assert(opFormatsArrayAttr_rhs.size() == opPermsArrayAttr_rhs.size() && "not equal RHS formats size with perms size\n"); + // assert(opFormatsArrayAttr_lhs.size() == opPermsArrayAttr_lhs.size() && "not equal LHS formats size with perms size\n"); + + // /// Get output format, vector of vector + // /// Convert ArrayAttr into + // comet_debug() << "Start printing opFormats_rhs\n"; + // std::vector> opFormats_rhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_rhs); + // comet_debug() << "End printing opFormats_rhs\n"; + // std::vector> opPerms_rhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_rhs); + // std::vector> inputMapping = createInputOutputMapping(opPermsArrayAttr_rhs, true); + + // comet_debug() << "Start printing opFormats_lhs\n"; + // std::vector> opFormats_lhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_lhs); + // comet_debug() << "End printing opFormats_lhs\n"; + // std::vector> opPerms_lhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_lhs); + // std::vector> outputMapping = createInputOutputMapping(opPermsArrayAttr_lhs, false); + + // opFormats = opFormats_rhs; + // opFormats.insert(opFormats.end(), opFormats_lhs.begin(), opFormats_lhs.end()); + // opPerms = opPerms_rhs; + // 
opPerms.insert(opPerms.end(), opPerms_lhs.begin(), opPerms_lhs.end()); + // inputOutputMapping = inputMapping; + // inputOutputMapping.insert(inputOutputMapping.end(), outputMapping.begin(), outputMapping.end()); } /// Get the formats of the itCompute op void getFormatsOfComputeOp(Value computeOp, std::vector> &opFormats) { - indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); - ArrayAttr opFormatsArrayAttr_rhs = itComputeOp_rhs.getAllFormats(); - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); - ArrayAttr opFormatsArrayAttr_lhs = itComputeOp_lhs.getAllFormats(); - - /// Get output format, vector of vector - /// Convert ArrayAttr into - std::vector> opFormats_rhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_rhs); - std::vector> opFormats_lhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_lhs); - - opFormats = opFormats_rhs; - opFormats.insert(opFormats.end(), opFormats_lhs.begin(), opFormats_lhs.end()); + return; + // indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); + // ArrayAttr opFormatsArrayAttr_rhs = itComputeOp_rhs.getAllFormats(); + // indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); + // ArrayAttr opFormatsArrayAttr_lhs = itComputeOp_lhs.getAllFormats(); + + // /// Get output format, vector of vector + // /// Convert ArrayAttr into + // std::vector> opFormats_rhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_rhs); + // std::vector> opFormats_lhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_lhs); + + // opFormats = opFormats_rhs; + // opFormats.insert(opFormats.end(), opFormats_lhs.begin(), opFormats_lhs.end()); } /// Get the rhs formats of the itCompute op void getRHSFormatsOfComputeOp(Value computeOp, std::vector> &opFormats) { - indexTree::IndexTreeComputeRHSOp 
itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); - ArrayAttr opFormatsArrayAttr_rhs = itComputeOp_rhs.getAllFormats(); + return; + // indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); + // ArrayAttr opFormatsArrayAttr_rhs = itComputeOp_rhs.getAllFormats(); - /// Get output format, vector of vector - /// Convert ArrayAttr into - std::vector> opFormats_rhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_rhs); + // /// Get output format, vector of vector + // /// Convert ArrayAttr into + // std::vector> opFormats_rhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_rhs); - opFormats = opFormats_rhs; + // opFormats = opFormats_rhs; } /// Get the LHS formats of the itCompute op void getLHSFormatsOfComputeOp(Value computeOp, std::vector> &opFormats) { - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); - ArrayAttr opFormatsArrayAttr_lhs = itComputeOp_lhs.getAllFormats(); - std::vector> opFormats_lhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_lhs); - opFormats = opFormats_lhs; + return; + // indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); + // ArrayAttr opFormatsArrayAttr_lhs = itComputeOp_lhs.getAllFormats(); + // std::vector> opFormats_lhs = convertArrayAttrStrTo2DVector(opFormatsArrayAttr_lhs); + // opFormats = opFormats_lhs; } /// Get the input tensors of the itCompute op void getInputTensorsOfComputeOp(Value computeOp, std::vector &inputTensors) { + return; /// indexTree::IndexTreeComputeOp itComputeOp = dyn_cast(computeOp.getDefiningOp()); - indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); - comet_debug() << " "; - comet_vdump(itComputeOp_rhs); - for (unsigned int i = 0; i < itComputeOp_rhs.getOperation()->getNumOperands(); i++) - { - 
comet_debug() << " "; - comet_vdump(itComputeOp_rhs.getOperation()->getOperand(i)); - inputTensors.push_back(itComputeOp_rhs.getOperation()->getOperand(i)); - } + // indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); + // comet_debug() << " "; + // comet_vdump(itComputeOp_rhs); + // for (unsigned int i = 0; i < itComputeOp_rhs.getOperation()->getNumOperands(); i++) + // { + // comet_debug() << " "; + // comet_vdump(itComputeOp_rhs.getOperation()->getOperand(i)); + // inputTensors.push_back(itComputeOp_rhs.getOperation()->getOperand(i)); + // } } /// Get the output tensors of the itCompute op void getOutputTensorsOfComputeOp(Value computeOp, std::vector &outputTensors) { - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); - for (unsigned int i = 0; i < itComputeOp_lhs.getOperation()->getNumOperands(); i++) - { - outputTensors.push_back(itComputeOp_lhs.getOperation()->getOperand(i)); - } + return; + // indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); + // for (unsigned int i = 0; i < itComputeOp_lhs.getOperation()->getNumOperands(); i++) + // { + // outputTensors.push_back(itComputeOp_lhs.getOperation()->getOperand(i)); + // } } /// Get indices in current WorkspaceOp cur_op @@ -1471,186 +1643,187 @@ namespace mlir std::vector &ids /* output */, std::vector &formats /* output */) { - /// For each indices, find in each leaf, which tensor, the corresponding format - /// If in all tensors, the formats of the index are D, then D - /// If only one Sparse, then sparse - comet_debug() << " getFormatsInfo:Start Current op\n"; - comet_vdump(cur_op); - comet_debug() << " getFormatsInfo:indices.size(): " << indices.size() << "\n"; - for (unsigned long i = 0; i < indices.size(); i++) - { - comet_debug() << " getFormatsInfo:indices[" << i << "]: " << indices[i] << "\n"; - /// 
Info for each index - std::string format; - Value tensor; - unsigned int id; - bool isSet = false; - - std::vector formats_leafs; - std::vector tensors_leafs; - std::vector ids_leafs; - - for (unsigned long j = 0; j < leafs.size(); j++) - { - /// Info for each index in leaf[j] - comet_debug() << " getFormatsInfo:LeafOp: "; - comet_vdump(leafs[j]); - std::string format_in_leaf; - Value tensor_in_leaf; - unsigned int id_in_leaf; - bool isSetInLeaf = false; - - /// get All perms and formats info - if (indexTree::IndexTreeComputeOp leafop = dyn_cast(leafs[j].getDefiningOp())) - { - comet_debug() << " getFormatsInfo:leafs[" << j << "] is computeOp\n"; - std::vector> allFormats; - std::vector> allPerms; - std::vector> inputOutputMapping; - OpBuilder builder(leafop); - getFormatsPermsOfComputeOp(leafop, allFormats, allPerms, inputOutputMapping); - - comet_debug() << " getFormatsInfo:Allformats allFormats.size(): " << allFormats.size() << "\n"; - for (auto m : allFormats) - { - comet_debug() << " "; - for (auto n : m) - { - comet_debug() << n << " "; - } - comet_debug() << "\n"; - } - - std::vector leafop_inputTensors; - getInputTensorsOfComputeOp(leafop, leafop_inputTensors); - comet_debug() << " getFormatsInfo:leafop_inputTensors.size(): " << leafop_inputTensors.size() << "\n"; - - std::vector leafop_outputTensors; - getOutputTensorsOfComputeOp(leafop, leafop_outputTensors); - comet_debug() << " getFormatsInfo:leafop_outputTensors.size(): " << leafop_outputTensors.size() << "\n"; - - std::vector leafop_tensors = leafop_inputTensors; - leafop_tensors.insert(leafop_tensors.end(), leafop_outputTensors.begin(), leafop_outputTensors.end()); -#ifdef DEBUG_MODE_UTILS - comet_debug() << " getFormatsInfo:leafop_tensors.size(): " << leafop_tensors.size() << "\n"; - for (auto n : leafop_tensors) - { - comet_debug() << " "; - comet_vdump(n); - } -#endif - /// Check if this index is in this leaf's perms - - std::vector formats_local; - std::vector tensors_local; - std::vector 
ids_local; - std::vector rhs_vs_lhs; - /// This leafOp contain multiple tensors. - comet_debug() << " getFormatsInfo:allPerms.size()" << allPerms.size() << "\n"; - for (unsigned long k = 0; k < allPerms.size(); k++) - { - comet_debug() << " getFormatsInfo:allPerms[" << k << "].size(): " << allPerms[k].size() << ", print allPerms[" << k << "]: "; - print_vector(allPerms[k]); - comet_debug() << " getFormatsInfo:indices[" << i << "]: " << indices[i] << "\n"; - unsigned int idx = findIndexInVector(allPerms[k], indices[i]); - comet_debug() << " getFormatsInfo:idx: " << idx << ", allPerms[" << k << "].size(): " << allPerms[k].size() << "\n"; - if (idx < allPerms[k].size()) - { /// In tensor k - comet_debug() << " getFormatsInfo:AddingLocalFormat[" << k << "][" << idx << "]: " << allFormats[k][idx] << " "; - comet_vdump(leafop_tensors[k]); - formats_local.push_back(allFormats[k][idx]); - tensors_local.push_back(leafop_tensors[k]); - ids_local.push_back(idx); - rhs_vs_lhs.push_back(inputOutputMapping[k][idx]); - } - } - - comet_debug() << " getFormatsInfo:formats_local.size(): " << formats_local.size() << " \n"; - for (unsigned long k = 0; k < formats_local.size(); k++) - { - comet_debug() << " getFormatsInfo:formats_local[k]:" << formats_local[k] << " " << ids_local[k] << " "; - comet_vdump(tensors_local[k]); - } - - /// analyze _local arrays, to get final formats, tensors, idx - if (formats_local.size() > 0) - { - isSetInLeaf = true; - format_in_leaf = formats_local[0]; - tensor_in_leaf = tensors_local[0]; - id_in_leaf = ids_local[0]; - - for (unsigned long k = 1; k < formats_local.size(); k++) - { - if (format_in_leaf.compare(0, 1, "D") == 0 && formats_local[k].compare(0, 1, "D") != 0 && rhs_vs_lhs[k]) - /// if the next format in the local format is not dense and not output - /// rhs_vs_lhs determines if the format comes from input (lhs) or output (rhs) - /// C[i,j] = A[i,k] * B[k, j] -> i is in both input A and output C - /// -> j is in both input B and output C - /// 
-> k is in both inputs A and B - /// index format information stores in formats_local - { - format_in_leaf = formats_local[k]; - tensor_in_leaf = tensors_local[k]; - id_in_leaf = ids_local[k]; - break; /// Get the first sparse case - } - } - } - - } /// if(indexTree::IndexTreeComputeOp leafop - - if (isSetInLeaf) - { - comet_debug() << " getFormatsInfo:isSetInLeaf: " << isSetInLeaf << ", format_in_leaf: " << format_in_leaf << ", id_in_leaf: " << id_in_leaf << ", tensor: "; - comet_vdump(tensor_in_leaf); - formats_leafs.push_back(format_in_leaf); - tensors_leafs.push_back(tensor_in_leaf); - ids_leafs.push_back(id_in_leaf); - } - - } /// for(auto j = 0; j < leafs.size(); j++){ - - comet_debug() << " getFormatsInfo:formats_leafs.size(): " << formats_leafs.size() << "\n"; - for (unsigned long k = 0; k < formats_leafs.size(); k++) - { - comet_debug() << " getFormatsInfo:formats_leafs[k]:" << formats_leafs[k] << "\n"; - } - - /// analyze the _leafs info to get the current index format, tensor, id information - for (unsigned long j = 0; j < formats_leafs.size(); j++) - { - if (j == 0) - { - format = formats_leafs[j]; - tensor = tensors_leafs[j]; - id = ids_leafs[j]; - isSet = true; - } - else - { - if (formats_leafs[j].compare(0, 1, "D") != 0) - { /// not D - format = formats_leafs[j]; - tensor = tensors_leafs[j]; - id = ids_leafs[j]; - isSet = true; - break; /// Get the first sparse case - } - } - } - - if (isSet) - { - comet_debug() << " getFormatsInfo:EndFormat: " << format << ", id: " << id << ", tensor: "; - comet_vdump(tensor); - - formats.push_back(format); - tensors.push_back(tensor); - ids.push_back(id); - } - - } /// for(auto i = 0; i < indices.size(); i++){ + return; +// /// For each indices, find in each leaf, which tensor, the corresponding format +// /// If in all tensors, the formats of the index are D, then D +// /// If only one Sparse, then sparse +// comet_debug() << " getFormatsInfo:Start Current op\n"; +// comet_vdump(cur_op); +// comet_debug() << " 
getFormatsInfo:indices.size(): " << indices.size() << "\n"; +// for (unsigned long i = 0; i < indices.size(); i++) +// { +// comet_debug() << " getFormatsInfo:indices[" << i << "]: " << indices[i] << "\n"; +// /// Info for each index +// std::string format; +// Value tensor; +// unsigned int id; +// bool isSet = false; + +// std::vector formats_leafs; +// std::vector tensors_leafs; +// std::vector ids_leafs; + +// for (unsigned long j = 0; j < leafs.size(); j++) +// { +// /// Info for each index in leaf[j] +// comet_debug() << " getFormatsInfo:LeafOp: "; +// comet_vdump(leafs[j]); +// std::string format_in_leaf; +// Value tensor_in_leaf; +// unsigned int id_in_leaf; +// bool isSetInLeaf = false; + +// /// get All perms and formats info +// if (indexTree::IndexTreeComputeOp leafop = dyn_cast(leafs[j].getDefiningOp())) +// { +// comet_debug() << " getFormatsInfo:leafs[" << j << "] is computeOp\n"; +// std::vector> allFormats; +// std::vector> allPerms; +// std::vector> inputOutputMapping; +// OpBuilder builder(leafop); +// getFormatsPermsOfComputeOp(leafop, allFormats, allPerms, inputOutputMapping); + +// comet_debug() << " getFormatsInfo:Allformats allFormats.size(): " << allFormats.size() << "\n"; +// for (auto m : allFormats) +// { +// comet_debug() << " "; +// for (auto n : m) +// { +// comet_debug() << n << " "; +// } +// comet_debug() << "\n"; +// } + +// std::vector leafop_inputTensors; +// getInputTensorsOfComputeOp(leafop, leafop_inputTensors); +// comet_debug() << " getFormatsInfo:leafop_inputTensors.size(): " << leafop_inputTensors.size() << "\n"; + +// std::vector leafop_outputTensors; +// getOutputTensorsOfComputeOp(leafop, leafop_outputTensors); +// comet_debug() << " getFormatsInfo:leafop_outputTensors.size(): " << leafop_outputTensors.size() << "\n"; + +// std::vector leafop_tensors = leafop_inputTensors; +// leafop_tensors.insert(leafop_tensors.end(), leafop_outputTensors.begin(), leafop_outputTensors.end()); +// #ifdef DEBUG_MODE_UTILS +// 
comet_debug() << " getFormatsInfo:leafop_tensors.size(): " << leafop_tensors.size() << "\n"; +// for (auto n : leafop_tensors) +// { +// comet_debug() << " "; +// comet_vdump(n); +// } +// #endif +// /// Check if this index is in this leaf's perms + +// std::vector formats_local; +// std::vector tensors_local; +// std::vector ids_local; +// std::vector rhs_vs_lhs; +// /// This leafOp contain multiple tensors. +// comet_debug() << " getFormatsInfo:allPerms.size()" << allPerms.size() << "\n"; +// for (unsigned long k = 0; k < allPerms.size(); k++) +// { +// comet_debug() << " getFormatsInfo:allPerms[" << k << "].size(): " << allPerms[k].size() << ", print allPerms[" << k << "]: "; +// print_vector(allPerms[k]); +// comet_debug() << " getFormatsInfo:indices[" << i << "]: " << indices[i] << "\n"; +// unsigned int idx = findIndexInVector(allPerms[k], indices[i]); +// comet_debug() << " getFormatsInfo:idx: " << idx << ", allPerms[" << k << "].size(): " << allPerms[k].size() << "\n"; +// if (idx < allPerms[k].size()) +// { /// In tensor k +// comet_debug() << " getFormatsInfo:AddingLocalFormat[" << k << "][" << idx << "]: " << allFormats[k][idx] << " "; +// comet_vdump(leafop_tensors[k]); +// formats_local.push_back(allFormats[k][idx]); +// tensors_local.push_back(leafop_tensors[k]); +// ids_local.push_back(idx); +// rhs_vs_lhs.push_back(inputOutputMapping[k][idx]); +// } +// } + +// comet_debug() << " getFormatsInfo:formats_local.size(): " << formats_local.size() << " \n"; +// for (unsigned long k = 0; k < formats_local.size(); k++) +// { +// comet_debug() << " getFormatsInfo:formats_local[k]:" << formats_local[k] << " " << ids_local[k] << " "; +// comet_vdump(tensors_local[k]); +// } + +// /// analyze _local arrays, to get final formats, tensors, idx +// if (formats_local.size() > 0) +// { +// isSetInLeaf = true; +// format_in_leaf = formats_local[0]; +// tensor_in_leaf = tensors_local[0]; +// id_in_leaf = ids_local[0]; + +// for (unsigned long k = 1; k < 
formats_local.size(); k++) +// { +// if (format_in_leaf.compare(0, 1, "D") == 0 && formats_local[k].compare(0, 1, "D") != 0 && rhs_vs_lhs[k]) +// /// if the next format in the local format is not dense and not output +// /// rhs_vs_lhs determines if the format comes from input (lhs) or output (rhs) +// /// C[i,j] = A[i,k] * B[k, j] -> i is in both input A and output C +// /// -> j is in both input B and output C +// /// -> k is in both inputs A and B +// /// index format information stores in formats_local +// { +// format_in_leaf = formats_local[k]; +// tensor_in_leaf = tensors_local[k]; +// id_in_leaf = ids_local[k]; +// break; /// Get the first sparse case +// } +// } +// } + +// } /// if(indexTree::IndexTreeComputeOp leafop + +// if (isSetInLeaf) +// { +// comet_debug() << " getFormatsInfo:isSetInLeaf: " << isSetInLeaf << ", format_in_leaf: " << format_in_leaf << ", id_in_leaf: " << id_in_leaf << ", tensor: "; +// comet_vdump(tensor_in_leaf); +// formats_leafs.push_back(format_in_leaf); +// tensors_leafs.push_back(tensor_in_leaf); +// ids_leafs.push_back(id_in_leaf); +// } + +// } /// for(auto j = 0; j < leafs.size(); j++){ + +// comet_debug() << " getFormatsInfo:formats_leafs.size(): " << formats_leafs.size() << "\n"; +// for (unsigned long k = 0; k < formats_leafs.size(); k++) +// { +// comet_debug() << " getFormatsInfo:formats_leafs[k]:" << formats_leafs[k] << "\n"; +// } + +// /// analyze the _leafs info to get the current index format, tensor, id information +// for (unsigned long j = 0; j < formats_leafs.size(); j++) +// { +// if (j == 0) +// { +// format = formats_leafs[j]; +// tensor = tensors_leafs[j]; +// id = ids_leafs[j]; +// isSet = true; +// } +// else +// { +// if (formats_leafs[j].compare(0, 1, "D") != 0) +// { /// not D +// format = formats_leafs[j]; +// tensor = tensors_leafs[j]; +// id = ids_leafs[j]; +// isSet = true; +// break; /// Get the first sparse case +// } +// } +// } + +// if (isSet) +// { +// comet_debug() << " 
getFormatsInfo:EndFormat: " << format << ", id: " << id << ", tensor: "; +// comet_vdump(tensor); + +// formats.push_back(format); +// tensors.push_back(tensor); +// ids.push_back(id); +// } + +// } /// for(auto i = 0; i < indices.size(); i++){ } /// Find leaves of tcRootOp in the Index Tree (dfsOp). @@ -1664,36 +1837,37 @@ namespace mlir std::vector &dfsOps, std::vector &ret /* output leaves */) { - std::vector> allAncestors(dfsOps.size()); - for (unsigned int i = 0; i < dfsOps.size(); i++) - { - if (IndexTreeComputeOp cur_op = dyn_cast(dfsOps[i].getDefiningOp())) - { - getAncestorsWp(dfsOps[i], allAncestors[i] /* output ancestors */, dfsOps); - comet_debug() << " print allAncestors[" << i << "]: "; - print_vector_value(allAncestors[i]); - } - } - - /// Each wp op in which tensors - if (IndexTreeIndicesOp cur_op = dyn_cast(tcRootOp.getDefiningOp())) - { - comet_debug() << " "; - comet_vdump(tcRootOp); - for (unsigned int j = 0; j < dfsOps.size(); j++) - { - auto idx = findIndexInVector_Value(allAncestors[j], tcRootOp); - if (idx < allAncestors[j].size()) - { - if (indexTree::IndexTreeComputeOp cur_op = dyn_cast(dfsOps[j].getDefiningOp())) - { - ret.push_back(dfsOps[j]); - comet_debug() << " "; - comet_vdump(dfsOps[j]); - } - } - } - } + return; + // std::vector> allAncestors(dfsOps.size()); + // for (unsigned int i = 0; i < dfsOps.size(); i++) + // { + // if (IndexTreeComputeOp cur_op = dyn_cast(dfsOps[i].getDefiningOp())) + // { + // getAncestorsWp(dfsOps[i], allAncestors[i] /* output ancestors */, dfsOps); + // comet_debug() << " print allAncestors[" << i << "]: "; + // print_vector_value(allAncestors[i]); + // } + // } + + // /// Each wp op in which tensors + // if (IndexTreeIndicesOp cur_op = dyn_cast(tcRootOp.getDefiningOp())) + // { + // comet_debug() << " "; + // comet_vdump(tcRootOp); + // for (unsigned int j = 0; j < dfsOps.size(); j++) + // { + // auto idx = findIndexInVector_Value(allAncestors[j], tcRootOp); + // if (idx < allAncestors[j].size()) + // { 
+ // if (indexTree::IndexTreeComputeOp cur_op = dyn_cast(dfsOps[j].getDefiningOp())) + // { + // ret.push_back(dfsOps[j]); + // comet_debug() << " "; + // comet_vdump(dfsOps[j]); + // } + // } + // } + // } } /// new version for new children ops @@ -1736,61 +1910,65 @@ namespace mlir /// Get the output tensors of the itCompute op void getTensorsOfComputeOp(Value computeOp, std::vector &tensors) { - indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); - comet_debug() << " "; - comet_vdump(itComputeOp_rhs); - for (unsigned int i = 0; i < itComputeOp_rhs.getOperation()->getNumOperands(); i++) - { - comet_debug() << " "; - comet_vdump(itComputeOp_rhs.getOperation()->getOperand(i)); - tensors.push_back(itComputeOp_rhs.getOperation()->getOperand(i)); - } - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); - for (unsigned int i = 0; i < itComputeOp_lhs.getOperation()->getNumOperands(); i++) - { - tensors.push_back(itComputeOp_lhs.getOperation()->getOperand(i)); - } + return; + // indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); + // comet_debug() << " "; + // comet_vdump(itComputeOp_rhs); + // for (unsigned int i = 0; i < itComputeOp_rhs.getOperation()->getNumOperands(); i++) + // { + // comet_debug() << " "; + // comet_vdump(itComputeOp_rhs.getOperation()->getOperand(i)); + // tensors.push_back(itComputeOp_rhs.getOperation()->getOperand(i)); + // } + // indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); + // for (unsigned int i = 0; i < itComputeOp_lhs.getOperation()->getNumOperands(); i++) + // { + // tensors.push_back(itComputeOp_lhs.getOperation()->getOperand(i)); + // } } /// Get the perms and formats of the itCompute op void getRHSPermsOfComputeOp(Value computeOp, std::vector> &opPerms) { - 
indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); - ArrayAttr opPermsArrayAttr_rhs = itComputeOp_rhs.getAllPerms(); - /// Get output format, vector of vector - /// Convert ArrayAttr into - std::vector> opPerms_rhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_rhs); - opPerms = opPerms_rhs; + return; + // indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); + // ArrayAttr opPermsArrayAttr_rhs = itComputeOp_rhs.getAllPerms(); + // /// Get output format, vector of vector + // /// Convert ArrayAttr into + // std::vector> opPerms_rhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_rhs); + // opPerms = opPerms_rhs; } /// Get the perms and formats of the itCompute op void getLHSPermsOfComputeOp(Value computeOp, std::vector> &opPerms) { - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); - ArrayAttr opPermsArrayAttr_lhs = itComputeOp_lhs.getAllPerms(); + return; + // indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); + // ArrayAttr opPermsArrayAttr_lhs = itComputeOp_lhs.getAllPerms(); - /// Get output format, vector of vector - /// Convert ArrayAttr into - std::vector> opPerms_lhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_lhs); + // /// Get output format, vector of vector + // /// Convert ArrayAttr into + // std::vector> opPerms_lhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_lhs); - opPerms = opPerms_lhs; + // opPerms = opPerms_lhs; } /// Get the perms and formats of the itCompute op void getPermsOfComputeOp(Value computeOp, std::vector> &opPerms) { - indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); - ArrayAttr opPermsArrayAttr_rhs = itComputeOp_rhs.getAllPerms(); - indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = 
dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); - ArrayAttr opPermsArrayAttr_lhs = itComputeOp_lhs.getAllPerms(); - - /// Get output format, vector of vector - /// Convert ArrayAttr into - std::vector> opPerms_rhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_rhs); - std::vector> opPerms_lhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_lhs); - - opPerms = opPerms_rhs; - opPerms.insert(opPerms.end(), opPerms_lhs.begin(), opPerms_lhs.end()); + return; + // indexTree::IndexTreeComputeRHSOp itComputeOp_rhs = dyn_cast(computeOp.getDefiningOp()->getOperand(0).getDefiningOp()); + // ArrayAttr opPermsArrayAttr_rhs = itComputeOp_rhs.getAllPerms(); + // indexTree::IndexTreeComputeLHSOp itComputeOp_lhs = dyn_cast(computeOp.getDefiningOp()->getOperand(1).getDefiningOp()); + // ArrayAttr opPermsArrayAttr_lhs = itComputeOp_lhs.getAllPerms(); + + // /// Get output format, vector of vector + // /// Convert ArrayAttr into + // std::vector> opPerms_rhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_rhs); + // std::vector> opPerms_lhs = convertArrayAttrIntTo2DVector(opPermsArrayAttr_lhs); + + // opPerms = opPerms_rhs; + // opPerms.insert(opPerms.end(), opPerms_lhs.begin(), opPerms_lhs.end()); } double loopCostHeuristic(const std::vector &loopOrder, size_t dim_,