From f5ecab880015040d0f3bb36b13c9a475275febf3 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 25 Jun 2026 09:52:13 -0700 Subject: [PATCH] split-preprocessing: use new preprocessing dialect infrastructure This change modifies split-preprocessing to use the new preprocessing dialect for its output, and updates the pipelines to use preprocessing-to-memref. It changes the `split-preprocessing` flag from a numeric batching factor to a boolean enable/disable. It also adds a small amount of fixes to the lattigo lowering (which otherwise supports bufferization and memrefs). Fixes #2960 PiperOrigin-RevId: 938040235 --- .../Conversions/PreprocessingToLattigo/BUILD | 1 + .../PreprocessingToLattigo.cpp | 5 + .../Conversions/PreprocessingToMemref/BUILD | 1 + .../PreprocessingToMemref.cpp | 5 + .../Conversions/PreprocessingToOpenfhe/BUILD | 1 + .../PreprocessingToOpenfhe.cpp | 5 + .../Preprocessing/IR/PreprocessingOps.td | 6 +- lib/Dialect/Preprocessing/Transforms/BUILD | 1 + .../Transforms/ValidatePreprocessing.cpp | 49 +- .../ArithmeticPipelineRegistration.cpp | 16 +- .../ArithmeticPipelineRegistration.h | 10 +- lib/Pipelines/BUILD | 3 + lib/Target/Lattigo/BUILD | 1 - lib/Target/Lattigo/LattigoEmitter.cpp | 25 +- lib/Transforms/SplitPreprocessing/BUILD | 8 +- .../SplitPreprocessing/SplitPreprocessing.cpp | 622 ++++++++++-------- .../SplitPreprocessing/SplitPreprocessing.td | 6 +- .../Transforms/validate_errors.mlir | 15 +- .../secret_to_bgv/hamming_distance_1024.mlir | 12 +- .../secret_to_ckks/hamming_distance_1024.mlir | 12 +- tests/Examples/lattigo/ckks/batchnorm1d/BUILD | 2 +- .../lattigo/ckks/bicyclic_matmul/BUILD | 2 +- .../lattigo/ckks/conv1d_dilated/BUILD | 2 +- .../lattigo/ckks/conv2d_dilated/BUILD | 2 +- .../lattigo/ckks/matvec_512x784/BUILD | 3 +- tests/Examples/lattigo/ckks/mnist/BUILD | 4 +- tests/Examples/lattigo/ckks/pooling/BUILD | 2 +- tests/Examples/lattigo/ckks/pooling1d/BUILD | 2 +- .../Examples/lattigo/ckks/preprocessing/BUILD | 2 +- .../lattigo/ckks/preprocessing/matvec_test.go | 4 +- .../openfhe/ckks/batchnorm_sigmoid/BUILD | 2 +- .../openfhe/ckks/halevi_shoup_matvec/BUILD | 2 +- .../Examples/openfhe/ckks/preprocessing/BUILD | 2 +- .../matvec_non_power_two_crash.mlir | 3 +- .../Transforms/split_preprocessing/args.mlir | 33 +- .../split_preprocessing/double_use.mlir | 4 +- .../split_preprocessing/if_else_encode.mlir | 15 +- .../iv_dependent_encode.mlir | 12 +- .../split_preprocessing/linalg.mlir | 8 +- .../split_preprocessing/loop_encode.mlir | 18 +- .../split_preprocessing/many_plaintexts.mlir | 9 +- .../split_preprocessing/matvec_crash.mlir | 4 +- .../split_preprocessing/options.mlir | 11 +- .../pure_preproc_affine_loop.mlir | 41 ++ .../pure_preproc_loop.mlir | 12 +- .../split_preprocessing/region.mlir | 6 +- .../split_preprocessing.mlir | 12 +- 47 files changed, 613 insertions(+), 410 deletions(-) create mode 100644 tests/Transforms/split_preprocessing/pure_preproc_affine_loop.mlir diff --git a/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/BUILD b/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/BUILD index ae35baf57d..7993dcfce5 100644 --- a/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/BUILD +++ b/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/BUILD @@ -16,6 +16,7 @@ cc_library( "@heir//lib/Dialect/Lattigo/IR:Dialect", "@heir//lib/Dialect/Preprocessing/Conversions:Util", "@heir//lib/Dialect/Preprocessing/IR:Dialect", + "@heir//lib/Utils", "@heir//lib/Utils:ConversionUtils", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/PreprocessingToLattigo.cpp b/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/PreprocessingToLattigo.cpp index c3e636482e..1134a1f69c 100644 --- a/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/PreprocessingToLattigo.cpp +++ b/lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/PreprocessingToLattigo.cpp @@ -8,6 +8,7 @@ #include "lib/Dialect/Preprocessing/Conversions/Util.h" #include "lib/Dialect/Preprocessing/IR/PreprocessingDialect.h" #include "lib/Utils/ConversionUtils.h" +#include "lib/Utils/Utils.h" #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,6 +32,10 @@ struct PreprocessingToLattigo void runOnOperation() override { ModuleOp module = getOperation(); + if (!containsDialects(module)) { + return; + } + PreprocessingStorageLayoutAnalysis analysis(module); if (!analysis.isValid()) { signalPassFailure(); diff --git a/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/BUILD b/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/BUILD index 9ceebca286..01cb2c1e1f 100644 --- a/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/BUILD +++ b/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/BUILD @@ -15,6 +15,7 @@ cc_library( "@heir//lib/Analysis/PreprocessingStorageLayoutAnalysis", "@heir//lib/Dialect/Preprocessing/Conversions:Util", "@heir//lib/Dialect/Preprocessing/IR:Dialect", + "@heir//lib/Utils", "@heir//lib/Utils:ConversionUtils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineAnalysis", diff --git a/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/PreprocessingToMemref.cpp b/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/PreprocessingToMemref.cpp index 189be55450..7fda61d0f6 100644 --- a/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/PreprocessingToMemref.cpp +++ b/lib/Dialect/Preprocessing/Conversions/PreprocessingToMemref/PreprocessingToMemref.cpp @@ -13,6 +13,7 @@ #include "lib/Dialect/Preprocessing/IR/PreprocessingOps.h" #include "lib/Dialect/Preprocessing/IR/PreprocessingTypes.h" #include "lib/Utils/ConversionUtils.h" +#include "lib/Utils/Utils.h" #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project @@ -214,6 +215,10 @@ struct PreprocessingToMemref void runOnOperation() override { ModuleOp module = getOperation(); + if (!containsDialects(module)) { + return; + } + PreprocessingStorageLayoutAnalysis analysis(module); if (!analysis.isValid()) { signalPassFailure(); diff --git a/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/BUILD b/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/BUILD index 9605d40608..e4d6596be7 100644 --- a/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/BUILD +++ b/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/BUILD @@ -16,6 +16,7 @@ cc_library( "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Dialect/Preprocessing/Conversions:Util", "@heir//lib/Dialect/Preprocessing/IR:Dialect", + "@heir//lib/Utils", "@heir//lib/Utils:ConversionUtils", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/PreprocessingToOpenfhe.cpp b/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/PreprocessingToOpenfhe.cpp index 85ae5d431f..6ab215d194 100644 --- a/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/PreprocessingToOpenfhe.cpp +++ b/lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/PreprocessingToOpenfhe.cpp @@ -8,6 +8,7 @@ #include "lib/Dialect/Preprocessing/Conversions/Util.h" #include "lib/Dialect/Preprocessing/IR/PreprocessingDialect.h" #include "lib/Utils/ConversionUtils.h" +#include "lib/Utils/Utils.h" #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,6 +32,10 @@ struct PreprocessingToOpenfhe void runOnOperation() override { ModuleOp module = getOperation(); + if (!containsDialects(module)) { + return; + } + PreprocessingStorageLayoutAnalysis analysis(module); if (!analysis.isValid()) { signalPassFailure(); diff --git a/lib/Dialect/Preprocessing/IR/PreprocessingOps.td b/lib/Dialect/Preprocessing/IR/PreprocessingOps.td index 92cc999b57..ec410fe621 100644 --- a/lib/Dialect/Preprocessing/IR/PreprocessingOps.td +++ b/lib/Dialect/Preprocessing/IR/PreprocessingOps.td @@ -14,7 +14,7 @@ class Preprocessing_Op traits = []> : let cppNamespace = "::mlir::heir::preprocessing"; } -def Preprocessing_EmptyOp : Preprocessing_Op<"empty"> { +def Preprocessing_EmptyOp : Preprocessing_Op<"empty", [MemoryEffectsOpInterface]> { let summary = "Allocate an empty preprocessing storage"; let description = [{ This op creates a new `preprocessing.storage` value with an undetermined size. @@ -23,7 +23,7 @@ def Preprocessing_EmptyOp : Preprocessing_Op<"empty"> { let assemblyFormat = "attr-dict `:` type($storage)"; } -def Preprocessing_StoreOp : Preprocessing_Op<"store"> { +def Preprocessing_StoreOp : Preprocessing_Op<"store", [MemoryEffectsOpInterface]> { let summary = "Store a value to preprocessing storage at given indices and site_id"; let description = [{ This op stores a value into a `preprocessing.storage`, while tracking: @@ -49,7 +49,7 @@ def Preprocessing_StoreOp : Preprocessing_Op<"store"> { let assemblyFormat = "$value `,` $storage `[` $indices `]` `site` $site_id `<` $element_type `>` attr-dict `:` type($value) `,` type($storage)"; } -def Preprocessing_LoadOp : Preprocessing_Op<"load"> { +def Preprocessing_LoadOp : Preprocessing_Op<"load", [MemoryEffectsOpInterface]> { let summary = "Load a value from preprocessing storage at given indices and site_id"; let description = [{ This op loads a value from a `preprocessing.storage`, while tracking: diff --git a/lib/Dialect/Preprocessing/Transforms/BUILD b/lib/Dialect/Preprocessing/Transforms/BUILD index bf024c1975..6b6b31b838 100644 --- a/lib/Dialect/Preprocessing/Transforms/BUILD +++ b/lib/Dialect/Preprocessing/Transforms/BUILD @@ -23,6 +23,7 @@ cc_library( ":pass_inc_gen", "@heir//lib/Dialect/Preprocessing/IR:Dialect", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", diff --git a/lib/Dialect/Preprocessing/Transforms/ValidatePreprocessing.cpp b/lib/Dialect/Preprocessing/Transforms/ValidatePreprocessing.cpp index 25a446cc84..b6c2c230c9 100644 --- a/lib/Dialect/Preprocessing/Transforms/ValidatePreprocessing.cpp +++ b/lib/Dialect/Preprocessing/Transforms/ValidatePreprocessing.cpp @@ -6,12 +6,13 @@ #include "lib/Dialect/Preprocessing/IR/PreprocessingOps.h" #include "lib/Dialect/Preprocessing/IR/PreprocessingTypes.h" -#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project -#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project -#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project -#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project +#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace heir { @@ -25,30 +26,36 @@ struct ValidatePreprocessing using ValidatePreprocessingBase::ValidatePreprocessingBase; void runOnOperation() override { - SmallVector emptyOps; + Operation* module = getOperation(); + + bool hasMultipleEmpties = false; + module->walk([&](func::FuncOp funcOp) { + SmallVector emptyOps(funcOp.getOps()); + if (emptyOps.size() > 1) { + for (size_t i = 1; i < emptyOps.size(); ++i) { + auto diag = + emptyOps[i]->emitOpError() + << "more than one preprocessing.empty allocation in function"; + diag.attachNote(emptyOps[0]->getLoc()) << "previous allocation here"; + } + hasMultipleEmpties = true; + } + }); + if (hasMultipleEmpties) { + signalPassFailure(); + } + DenseMap> storesBySite; DenseMap> loadsBySite; - getOperation()->walk([&](Operation* op) { - if (auto emptyOp = dyn_cast(op)) { - emptyOps.push_back(emptyOp); - } else if (auto storeOp = dyn_cast(op)) { + module->walk([&](Operation* op) { + if (auto storeOp = dyn_cast(op)) { storesBySite[storeOp.getSiteId()].push_back(storeOp); } else if (auto loadOp = dyn_cast(op)) { loadsBySite[loadOp.getSiteId()].push_back(loadOp); } }); - if (emptyOps.size() > 1) { - for (size_t i = 1; i < emptyOps.size(); ++i) { - auto diag = emptyOps[i]->emitOpError() - << "more than one preprocessing.empty allocation in module"; - diag.attachNote(emptyOps[0]->getLoc()) << "previous allocation here"; - } - signalPassFailure(); - // Do NOT return, continue to site pairing checks - } - std::set siteIds; for (const auto& [siteId, stores] : storesBySite) { siteIds.insert(siteId); diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index d8e9669b47..ed170ca489 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -16,6 +16,9 @@ #include "lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.h" #include "lib/Dialect/Openfhe/Transforms/CountAddAndKeySwitch.h" #include "lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.h" +#include "lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/PreprocessingToLattigo.h" +#include "lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/PreprocessingToOpenfhe.h" +#include "lib/Dialect/Preprocessing/Transforms/ValidatePreprocessing.h" #include "lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h" #include "lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.h" #include "lib/Dialect/Secret/Conversions/SecretToModArith/SecretToModArith.h" @@ -142,6 +145,8 @@ void cleanupAfterLowerAssignLayout(OpPassManager& pm) { pm.addPass(createCSEPass()); pm.addPass(createRemoveDeadValuesPass()); pm.addPass(createSymbolDCEPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); } // Implement layout conversions as shift networks @@ -423,9 +428,10 @@ void mlirToRLWEPipeline(OpPassManager& pm, pm.addPass(lwe::createImplementTrivialEncryptionAsAddition()); // Add a __preprocessed helper for offline pre-packing of plaintexts - auto splitPreprocessingOptions = SplitPreprocessingOptions{}; - splitPreprocessingOptions.maxReturnValues = options.splitPreprocessing; - pm.addPass(createSplitPreprocessing(splitPreprocessingOptions)); + if (options.enableSplitPreprocessing) { + pm.addPass(createSplitPreprocessing()); + pm.addPass(preprocessing::createValidatePreprocessing()); + } ElementwiseToAffineOptions elementwiseOptions; elementwiseOptions.convertDialects = {"ckks", "bgv", "lwe"}; @@ -471,6 +477,7 @@ BackendPipelineBuilder toOpenFhePipelineBuilder() { // Convert LWE (and scheme-specific CKKS/BGV ops) to OpenFHE pm.addPass(lwe::createLWEToOpenfhe()); + pm.addPass(preprocessing::createPreprocessingToOpenfhe()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); @@ -513,6 +520,7 @@ BackendPipelineBuilder toLattigoPipelineBuilder() { // Convert LWE (and scheme-specific BGV ops) to Lattigo pm.addPass(lwe::createLWEToLattigo()); + pm.addPass(preprocessing::createPreprocessingToLattigo()); // Convert Alloc Ops to InPlace Ops pm.addPass(lattigo::createAllocToInPlace()); @@ -567,7 +575,7 @@ void torchLinalgToCkksBuilder(OpPassManager& manager, suboptions.ckksBootstrapWaterline = options.ckksBootstrapWaterline; suboptions.scalingModBits = options.scalingModBits; suboptions.firstModBits = options.firstModBits; - suboptions.splitPreprocessing = options.splitPreprocessing; + suboptions.enableSplitPreprocessing = options.enableSplitPreprocessing; suboptions.experimentalDisableLoopUnroll = options.experimentalDisableLoopUnroll; suboptions.usePublicKey = options.usePublicKey; diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.h b/lib/Pipelines/ArithmeticPipelineRegistration.h index b26ceaf6f8..a97d3d80de 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.h +++ b/lib/Pipelines/ArithmeticPipelineRegistration.h @@ -103,11 +103,11 @@ struct MlirToRLWEPipelineOptions : public LoopOptions { llvm::cl::desc("File name to import execution result from (c.f. --secret-" "import-execution-result)"), llvm::cl::init("")}; - PassOptions::Option splitPreprocessing{ - *this, "split-preprocessing", - llvm::cl::desc("Split preprocessing into separate function with N return " - "values (default to no split)"), - llvm::cl::init(16)}; + PassOptions::Option enableSplitPreprocessing{ + *this, "enable-split-preprocessing", + llvm::cl::desc( + "Split server-side plaintext preprocessing into a separate function"), + llvm::cl::init(true)}; }; struct PlaintextBackendOptions diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index 291455c515..54aa4d25f1 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -115,6 +115,9 @@ cc_library( "@heir//lib/Dialect/Openfhe/Transforms:ConfigureCryptoContext", "@heir//lib/Dialect/Openfhe/Transforms:CountAddAndKeySwitch", "@heir//lib/Dialect/Openfhe/Transforms:FastRotationPrecompute", + "@heir//lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo", + "@heir//lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe", + "@heir//lib/Dialect/Preprocessing/Transforms", "@heir//lib/Dialect/Secret/Conversions/SecretToBGV", "@heir//lib/Dialect/Secret/Conversions/SecretToCGGI", "@heir//lib/Dialect/Secret/Conversions/SecretToCKKS", diff --git a/lib/Target/Lattigo/BUILD b/lib/Target/Lattigo/BUILD index e241f3bcdf..28bcdcdbd1 100644 --- a/lib/Target/Lattigo/BUILD +++ b/lib/Target/Lattigo/BUILD @@ -29,7 +29,6 @@ cc_library( "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index 090bc5033f..2d311bc2cd 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -200,11 +200,15 @@ LogicalResult LattigoEmitter::printOperation(func::FuncOp funcOp) { // name and arg std::string funcName = funcOp.getName().str(); if (funcOp->hasAttr(kClientPackFuncAttrName)) { - if (!funcName.empty() && funcFilter && funcFilter(funcOp) && - std::islower(funcName[0])) { - // Upper case this only if we are emitting it in a filtered file so it - // needs exporting. - funcName[0] = std::toupper(funcName[0]); + if (!funcName.empty() && funcFilter && funcFilter(funcOp)) { + while (funcName[0] == '_') { + funcName.erase(0, 1); + } + if (std::islower(funcName[0])) { + // Upper case this only if we are emitting it in a filtered file so it + // needs exporting. + funcName[0] = std::toupper(funcName[0]); + } } } os << "func " << funcName << "("; @@ -286,11 +290,16 @@ LogicalResult LattigoEmitter::printOperation(func::CallOp op) { auto moduleOp = op->getParentOfType(); auto calleeOp = moduleOp.lookupSymbol(callee); std::string calleeName = canonicalizeDebugPort(callee).str(); - if (calleeOp && funcFilter && !funcFilter(calleeOp)) { - if (calleeOp->hasAttr(kClientPackFuncAttrName)) { - if (!calleeName.empty() && std::islower(calleeName[0])) { + if (calleeOp && calleeOp->hasAttr(kClientPackFuncAttrName)) { + if (funcFilter && !calleeName.empty()) { + while (calleeName[0] == '_') { + calleeName.erase(0, 1); + } + if (std::islower(calleeName[0])) { calleeName[0] = std::toupper(calleeName[0]); } + } + if (funcFilter && !funcFilter(calleeOp)) { calleeName = packageName + "_utils." + calleeName; extraImportsUsed = true; } diff --git a/lib/Transforms/SplitPreprocessing/BUILD b/lib/Transforms/SplitPreprocessing/BUILD index 36d8f4d618..746a65afc1 100644 --- a/lib/Transforms/SplitPreprocessing/BUILD +++ b/lib/Transforms/SplitPreprocessing/BUILD @@ -12,21 +12,19 @@ cc_library( hdrs = ["SplitPreprocessing.h"], deps = [ ":pass_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", "@heir//lib/Dialect:ModuleAttributes", "@heir//lib/Dialect/LWE/IR:Dialect", - "@heir//lib/Dialect/Secret/IR:SecretPatterns", - "@heir//lib/Dialect/TensorExt/IR:Dialect", - "@heir//lib/Utils", + "@heir//lib/Dialect/Preprocessing/IR:Dialect", "@heir//lib/Utils:AttributeUtils", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/lib/Transforms/SplitPreprocessing/SplitPreprocessing.cpp b/lib/Transforms/SplitPreprocessing/SplitPreprocessing.cpp index 2e40883554..7ecbc2785e 100644 --- a/lib/Transforms/SplitPreprocessing/SplitPreprocessing.cpp +++ b/lib/Transforms/SplitPreprocessing/SplitPreprocessing.cpp @@ -1,18 +1,17 @@ #include "lib/Transforms/SplitPreprocessing/SplitPreprocessing.h" +#include #include -#include +#include -#include "lib/Dialect/LWE/IR/LWEOps.h" +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Dialect/ModuleAttributes.h" +#include "lib/Dialect/Preprocessing/IR/PreprocessingOps.h" +#include "lib/Dialect/Preprocessing/IR/PreprocessingTypes.h" +#include "lib/Utils/AttributeUtils.h" #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/STLFunctionalExtras.h" // from @llvm-project #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project -#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project -#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -21,18 +20,24 @@ #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Region.h" // from @llvm-project #include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project #define DEBUG_TYPE "split-preprocessing" @@ -46,146 +51,121 @@ using func::FuncOp; namespace { -int ceilDiv(int a, int b) { return (a + b - 1) / b; } +// Walk the IR to collect the set of operations that need to be cloned, upstream +// of a set of Encode ops, including recursing into regions and parents. +static SetVector computeDependenciesToClone( + ArrayRef encodeOps, FuncOp containingFunc) { + assert(containingFunc && "Boundary containingFunc must not be null"); + SetVector opsToClone; + SmallVector worklist; -// The following slice analysis code is copied verbatim from -// https://github.com/llvm/circt/blob/main/lib/Dialect/SV/Transforms/SVExtractTestCode.cpp -// except where noted. + // Dedupe when adding to the worklist. + auto pushAsNeeded = [&](Operation* op) { + if (op && op != containingFunc.getOperation() && opsToClone.insert(op)) { + worklist.push_back(op); + } + }; -// Reimplemented from SliceAnalysis to use a worklist rather than -// recursion and non-insert ordered set. Implement this as a DFS and not a BFS -// so that the order is stable across changes to intermediary operations. (It -// is then necessary to use the _operands_ as a worklist and not the -// _operations_.) -static void getBackwardSliceSimple( - Operation* rootOp, SetVector& backwardSlice, - llvm::function_ref filter) { - SmallVector worklist(rootOp->getOperands()); + // Seed worklist + for (Operation* op : encodeOps) { + pushAsNeeded(op); + } while (!worklist.empty()) { - Value operand = worklist.pop_back_val(); - Operation* definingOp = operand.getDefiningOp(); - - if (!definingOp || - definingOp->hasTrait()) - continue; - - // Evaluate whether we should keep this def. - // This is useful in particular to implement scoping; i.e. return the - // transitive backwardSlice in the current scope. - if (filter && !filter(definingOp)) continue; - - if (definingOp) { - if (!backwardSlice.contains(definingOp)) - for (auto newOperand : llvm::reverse(definingOp->getOperands())) - worklist.push_back(newOperand); - } else if (auto blockArg = dyn_cast(operand)) { - Block* block = blockArg.getOwner(); - Operation* parentOp = block->getParentOp(); - // Determine whether we want to recurse backward into the other - // blocks of parentOp, which are not technically backward unless they - // flow into us. For now, just bail. - assert(parentOp->getNumRegions() == 1 && - parentOp->getRegion(0).getBlocks().size() == 1); - if (!backwardSlice.contains(parentOp)) - for (auto newOperand : llvm::reverse(parentOp->getOperands())) - worklist.push_back(newOperand); - } else { - llvm_unreachable("No definingOp and not a block argument."); - } + Operation* current = worklist.pop_back_val(); - backwardSlice.insert(definingOp); - } -} + // Preserve enclosing control flow (up to FuncOp) + pushAsNeeded(current->getParentOp()); -// Compute the ops defining the blocks a set of ops are in. -static void blockSlice(SetVector& ops, - SetVector& blocks) { - for (auto op : ops) { - while (!isa(op->getParentOp())) { - op = op->getParentOp(); - blocks.insert(op); + // If the current operation contains regions, push their terminators + // to propagate to the containing region. + for (Region& region : current->getRegions()) { + for (Block& block : region) { + if (Operation* terminator = block.getTerminator()) { + pushAsNeeded(terminator); + } + } } - // Differing from SV implementation: Add ops within regions to the worklist - // to ensure that the dataflow to operands used within regions is captured. - op->walk([&](Operation* op) { blocks.insert(op); }); - } -} -static void computeSlice(SetVector& roots, - SetVector& results, - llvm::function_ref filter) { - for (auto* op : roots) getBackwardSliceSimple(op, results, filter); -} + // Trace backward through operands + for (Value operand : current->getOperands()) { + if (Operation* definingOp = operand.getDefiningOp()) { + pushAsNeeded(definingOp); + } else if (auto blockArg = dyn_cast(operand)) { + pushAsNeeded(blockArg.getOwner()->getParentOp()); + } + } + } -// Return a backward slice started from `roots` until dataflow reaches to an -// operations for which `filter` returns false. -static SetVector getBackwardSlice( - SetVector& roots, llvm::function_ref filter) { - SetVector results; - computeSlice(roots, results, filter); - - // Get Blocks - SetVector blocks; - blockSlice(roots, blocks); - blockSlice(results, blocks); - - // Make sure dataflow to block args (if conds, etc) and ops within regions are - // included - computeSlice(blocks, results, filter); - // Differing from SV implementation: don't insert the operations within the - // regions (since their parent op is already in the set). - results.insert(roots.begin(), roots.end()); - return results; + return opsToClone; } struct PreprocessingAnalysis { - // Groups of encode ops that will be returned together. - SmallVector> encodeOps; - // Use a vector to preserve insertion order for stable function signatures. + SmallVector encodeOps; SetVector inputs; SetVector opsToClone; }; +// Encode ops can be elementwise-mappable, so if we're encoding a tensor we need +// to extract and store each element separately. This is unsupported for now +// because the storage memref analysis assumes each encode op corresponds to a +// single plaintext, but a tensor<1x!pt> is OK. +static bool isAllowedPlaintextType(Type type) { + if (isa(type)) return true; + if (auto shapedTy = dyn_cast(type)) { + auto shape = shapedTy.getShape(); + return shape.size() == 1 && shape[0] == 1; + } + return false; +} + struct SplitPreprocessingPass : impl::SplitPreprocessingBase { using SplitPreprocessingBase::SplitPreprocessingBase; void runOnOperation() override { Operation* root = getOperation(); + + // Annotate each encode op with a stable site id + int32_t encodeId = 0; + root->walk([&](PlaintextEncodeOpInterface op) { + op->setAttr( + "split_preprocessing_site_id", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), encodeId++)); + }); + root->walk([&](func::FuncOp op) { if (op.isDeclaration() || isClientHelper(op)) { return; } convertFunc(op); }); + + clearAttrs(root, "split_preprocessing_site_id"); } void convertFunc(FuncOp funcOp) { - if (maxReturnValues == 0) return; - - PreprocessingAnalysis analysis = - analyzePreprocessing(funcOp, maxReturnValues); + PreprocessingAnalysis analysis = analyzePreprocessing(funcOp); if (analysis.encodeOps.empty()) { - LLVM_DEBUG({ - llvm::dbgs() << "Failed to split preprocessing for " << funcOp.getName() - << "\n"; - }); return; } FailureOr maybePreprocessingFuncOp = createPreprocessingFunction(funcOp, analysis); if (failed(maybePreprocessingFuncOp)) { - LLVM_DEBUG({ - llvm::dbgs() << "Failed to create preprocessing function for " - << funcOp.getName() << "\n"; - }); + signalPassFailure(); return; } FuncOp preprocessingFuncOp = maybePreprocessingFuncOp.value(); - FuncOp preprocessedFuncOp = + + FailureOr maybePreprocessedFuncOp = createPreprocessedFunction(funcOp, preprocessingFuncOp, analysis); + if (failed(maybePreprocessedFuncOp)) { + preprocessingFuncOp.erase(); + signalPassFailure(); + return; + } + FuncOp preprocessedFuncOp = maybePreprocessedFuncOp.value(); ModuleOp moduleOp = funcOp->getParentOfType(); moduleOp.insert(funcOp.getOperation(), preprocessingFuncOp); @@ -193,12 +173,18 @@ struct SplitPreprocessingPass updateOriginalFunc(funcOp, preprocessingFuncOp, preprocessedFuncOp, analysis); + + // Remove dead values to clean up the created/updated functions + OpPassManager pipeline("func.func"); + pipeline.addPass(createRemoveDeadValuesPass()); + (void)runPipeline(pipeline, preprocessingFuncOp); + (void)runPipeline(pipeline, preprocessedFuncOp); + (void)runPipeline(pipeline, funcOp); } void updateOriginalFunc(FuncOp funcOp, FuncOp preprocessingFuncOp, FuncOp preprocessedFuncOp, const PreprocessingAnalysis& analysis) { - // Add calls to the preprocessing and preprocessed functions. OpBuilder builder(funcOp.getContext()); builder.setInsertionPointToStart(&funcOp.getBody().front()); @@ -206,14 +192,9 @@ struct SplitPreprocessingPass func::CallOp::create(builder, funcOp.getLoc(), preprocessingFuncOp, llvm::to_vector(analysis.inputs)); - SmallVector preprocessedArgs; - for (auto arg : funcOp.getArguments()) { - if (isa(getElementTypeOrSelf(arg.getType()))) { - preprocessedArgs.push_back(arg); - } - } - preprocessedArgs.append(preprocessingCall.getResults().begin(), - preprocessingCall.getResults().end()); + SmallVector preprocessedArgs(funcOp.getArguments().begin(), + funcOp.getArguments().end()); + preprocessedArgs.push_back(preprocessingCall.getResult(0)); auto preprocessedCall = func::CallOp::create( builder, funcOp.getLoc(), preprocessedFuncOp, preprocessedArgs); @@ -221,8 +202,6 @@ struct SplitPreprocessingPass Operation* originalTerminator = originalEntry->getTerminator(); originalTerminator->setOperands(preprocessedCall.getResults()); - // At this point all operations should have been moved and we can remove all - // ops but the calls. DenseSet opsToKeep = {preprocessingCall, preprocessedCall, originalTerminator}; for (Operation& op : llvm::make_early_inc_range( @@ -231,32 +210,84 @@ struct SplitPreprocessingPass } } + // Recursively clone an op and all ops in nested regions that are in the + // opsToClone filter. + static Operation* recursiveCloneOpWithFilter( + Operation* srcOp, IRMapping& mapper, + const SetVector& opsToClone, OpBuilder& builder) { + // 1. Clone the operation shell without its nested regions + Operation* targetOp = builder.cloneWithoutRegions(*srcOp, mapper); + + // 2. Recursively populate all blocks and regions + for (auto [srcRegion, targetRegion] : + llvm::zip(srcOp->getRegions(), targetOp->getRegions())) { + for (Block& srcBlock : srcRegion) { + // Append a new block to the target region + Block* targetBlock = builder.createBlock(&targetRegion); + // Map block arguments + for (BlockArgument arg : srcBlock.getArguments()) { + targetBlock->addArgument(arg.getType(), arg.getLoc()); + mapper.map(arg, targetBlock->getArguments().back()); + } + + // Selectively clone all operations in this block that belong to our + // live slice + for (Operation& childOp : srcBlock) { + if (opsToClone.contains(&childOp)) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(targetBlock); + recursiveCloneOpWithFilter(&childOp, mapper, opsToClone, builder); + } + } + } + } + return targetOp; + } + + // Get the set of induction variable indices in a loop nest surrounding an + // encode op. + SmallVector getContextualLoopIndices(Operation* encodeOp, + Operation* maxParent) { + SmallVector indices; + Operation* parent = encodeOp->getParentOp(); + while (parent && parent != maxParent) { + if (auto loopLikeOp = dyn_cast(parent)) { + auto inductionVar = loopLikeOp.getSingleInductionVar(); + if (inductionVar.has_value()) indices.push_back(*inductionVar); + } + parent = parent->getParentOp(); + } + std::reverse(indices.begin(), indices.end()); + return indices; + } + FailureOr createPreprocessingFunction( FuncOp op, const PreprocessingAnalysis& analysis) { MLIRContext* context = op.getContext(); OpBuilder builder(context); SmallVector newInputs; + newInputs.reserve(analysis.inputs.size()); for (const auto& input : analysis.inputs) { newInputs.push_back(input.getType()); } - SmallVector newResults; - for (const auto& batchOfEncodeOps : analysis.encodeOps) { - if (batchOfEncodeOps.empty()) { - llvm::outs() << "batch of encode ops is empty\n"; - } - Type type = batchOfEncodeOps[0]->getResult(0).getType(); - int batchSize = batchOfEncodeOps.size(); - newResults.push_back(RankedTensorType::get({batchSize}, type)); + // Create a preprocessing.storage type with the set plaintext types + DenseSet encodeTypesDeduped; + for (Operation* input : analysis.encodeOps) { + encodeTypesDeduped.insert( + getElementTypeOrSelf(input->getResult(0).getType())); } + SmallVector encodeTypes(encodeTypesDeduped.begin(), + encodeTypesDeduped.end()); + auto storageTy = + preprocessing::PreprocessingStorageType::get(context, encodeTypes); - auto funcType = FunctionType::get(context, newInputs, newResults); + // Create the new func and annotate it appropriately + auto funcType = FunctionType::get(context, newInputs, {storageTy}); auto funcName = op.getName().str() + "__preprocessing"; auto funcOp = FuncOp::create(op.getLoc(), funcName, funcType); funcOp.setVisibility(op.getVisibility()); - // Add a special attribute to the new function so that - // later passes can reference the original function. funcOp->setAttr( kClientPackFuncAttrName, builder.getDictionaryAttr({ @@ -264,55 +295,134 @@ struct SplitPreprocessingPass builder.getStringAttr(op.getName())), })); + // Set up the operation cloning infra: map the analysis-identified inputs to + // the new func's block arguments IRMapping map; Block* entryBlock = funcOp.addEntryBlock(); for (auto [idx, input] : llvm::enumerate(analysis.inputs)) { map.map(input, entryBlock->getArgument(idx)); } - // Insert the preprocessing ops into the new function. builder.setInsertionPointToEnd(entryBlock); - for (auto& op : op.getOps()) { - if (!analysis.opsToClone.contains(&op)) { - continue; + auto emptyOp = + preprocessing::EmptyOp::create(builder, funcOp.getLoc(), storageTy); + Value storage = emptyOp.getStorage(); + + // When cloning a loop to the preprocessing function, the loop may have a + // ciphertext iter_arg, whose initializer and other upstream ops we don't + // want to clone. However, the op still needs a valid SSA value for its + // initializer. In this case, we materialize an op of the right type + // via UnrealizedConversionCast (which can create an op of any type from + // nothing), and later allow remove-dead-values to clean up the iter_arg, + // since it will be naturally unused. + for (Operation* opToClone : analysis.opsToClone) { + for (Value operand : opToClone->getOperands()) { + if (!map.contains(operand) && + !analysis.opsToClone.contains(operand.getDefiningOp())) { + if (isa(operand) && + analysis.opsToClone.contains( + cast(operand).getOwner()->getParentOp())) { + continue; + } + Value dummy = + UnrealizedConversionCastOp::create( + builder, funcOp.getLoc(), operand.getType(), ValueRange{}) + .getResult(0); + map.map(operand, dummy); + } } - builder.clone(op, map); } - SmallVector results; - for (const auto& batchOfEncodeOps : analysis.encodeOps) { - SmallVector batchResults; - for (const auto& encodeOp : batchOfEncodeOps) { - batchResults.push_back(map.lookup(encodeOp->getResult(0))); + + // Clone the ops into the preprocessing func. Note that we could instead + // clone the entire func and prune ops we don't want to clone, but this is + // more efficient. + for (auto [srcRegion, targetRegion] : + llvm::zip(op->getRegions(), funcOp->getRegions())) { + for (Block& srcBlock : srcRegion) { + Block* targetBlock; + if (&srcBlock == &srcRegion.front()) { + targetBlock = &targetRegion.front(); + } else { + targetBlock = builder.createBlock(&targetRegion); + for (BlockArgument arg : srcBlock.getArguments()) { + targetBlock->addArgument(arg.getType(), arg.getLoc()); + map.map(arg, targetBlock->getArguments().back()); + } + } + + for (Operation& childOp : srcBlock) { + if (analysis.opsToClone.contains(&childOp)) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(targetBlock); + recursiveCloneOpWithFilter(&childOp, map, analysis.opsToClone, + builder); + } + } } - results.push_back(tensor::FromElementsOp::create(builder, funcOp.getLoc(), - batchResults)); } - func::ReturnOp::create(builder, funcOp.getLoc(), results); + // Insert StoreOps after all cloned EncodeOps + SmallVector clonedEncodes; + funcOp.walk([&](PlaintextEncodeOpInterface clonedOp) { + clonedEncodes.push_back(clonedOp); + }); + + for (Operation* clonedEncode : clonedEncodes) { + auto siteIdAttr = clonedEncode->getAttrOfType( + "split_preprocessing_site_id"); + assert( + siteIdAttr && + "Expected split_preprocessing_site_id attribute on cloned encode op"); + int32_t siteId = siteIdAttr.getInt(); + + SmallVector indices = + getContextualLoopIndices(clonedEncode, funcOp.getOperation()); + builder.setInsertionPointAfter(clonedEncode); + Value valToStore = clonedEncode->getResult(0); + Type elemStorageTy = valToStore.getType(); + + if (!isAllowedPlaintextType(elemStorageTy)) { + Location loc = clonedEncode->getLoc(); + funcOp.erase(); + return mlir::emitError(loc) + << "'lwe.rlwe_encode' op with result type " << elemStorageTy + << " unsupported in split-preprocessing."; + } + + // After isAllowedPlaintextType, only a tensor<1x!pt> is allowed here, so + // extract it and store it. + if (auto tensorTy = dyn_cast(elemStorageTy)) { + elemStorageTy = tensorTy.getElementType(); + Value zero = + arith::ConstantIndexOp::create(builder, clonedEncode->getLoc(), 0); + valToStore = tensor::ExtractOp::create(builder, clonedEncode->getLoc(), + valToStore, ValueRange{zero}); + } + + preprocessing::StoreOp::create( + builder, clonedEncode->getLoc(), valToStore, storage, indices, + builder.getI32IntegerAttr(siteId), TypeAttr::get(elemStorageTy)); + } + + builder.setInsertionPointToEnd(&funcOp.getRegion().back()); + func::ReturnOp::create(builder, funcOp.getLoc(), ValueRange{storage}); return funcOp; } - FuncOp createPreprocessedFunction(FuncOp op, FuncOp preprocessingFuncOp, - const PreprocessingAnalysis& analysis) { + FailureOr createPreprocessedFunction( + FuncOp op, FuncOp preprocessingFuncOp, + const PreprocessingAnalysis& analysis) { MLIRContext* context = op.getContext(); OpBuilder builder(context); - SmallVector inputTypes; - for (auto argType : op.getArgumentTypes()) { - if (isa(getElementTypeOrSelf(argType))) { - inputTypes.push_back(argType); - } - } - for (auto preprocessingResult : - preprocessingFuncOp.getFunctionType().getResults()) { - inputTypes.push_back(preprocessingResult); - } + // Add the new preprocessing.storage type as a function argument + SmallVector inputTypes(op.getArgumentTypes().begin(), + op.getArgumentTypes().end()); + inputTypes.push_back(preprocessingFuncOp.getResultTypes()[0]); auto funcType = FunctionType::get(context, inputTypes, op.getResultTypes()); auto funcName = op.getName().str() + "__preprocessed"; auto funcOp = FuncOp::create(op.getLoc(), funcName, funcType); funcOp.setVisibility(op.getVisibility()); - // Add a special attribute to the new function so that - // later passes can reference the original function. funcOp->setAttr( kClientPreprocessedFuncAttrName, builder.getDictionaryAttr({ @@ -320,154 +430,116 @@ struct SplitPreprocessingPass builder.getStringAttr(op.getName())), })); - // Build the main function body in the preprocessed function. IRMapping map; Block* entryBlock = funcOp.addEntryBlock(); - int index = 0; - for (auto arg : op.getArguments()) { - if (isa(getElementTypeOrSelf(arg.getType()))) { - map.map(arg, entryBlock->getArgument(index++)); - } + for (auto [idx, arg] : llvm::enumerate(op.getArguments())) { + map.map(arg, entryBlock->getArgument(idx)); } - // Map each of the encode ops to a tensor extract of a block argument. + Value storageArg = entryBlock->getArguments().back(); + builder.setInsertionPointToEnd(entryBlock); - for (const auto& batchOfEncodeOps : analysis.encodeOps) { - auto plaintextTensorArg = entryBlock->getArgument(index++); - for (auto [i, encodeOp] : llvm::enumerate(batchOfEncodeOps)) { - LLVM_DEBUG({ - llvm::dbgs() << "mapping encode op: " << *encodeOp - << " to index: " << i << "of block argument at index " - << plaintextTensorArg.getArgNumber() << "\n"; - }); - auto idx = arith::ConstantIndexOp::create(builder, funcOp.getLoc(), i); - map.map(encodeOp->getResult(0), - tensor::ExtractOp::create(builder, funcOp.getLoc(), - encodeOp->getResult(0).getType(), - plaintextTensorArg, {idx})); + for (auto& toClone : op.getOps()) { + Operation* clonedOp = builder.clone(toClone, map); + for (auto [idx, res] : llvm::enumerate(toClone.getResults())) { + map.map(res, clonedOp->getResult(idx)); } } - for (auto& toClone : op.getOps()) { - if (analysis.opsToClone.contains(&toClone) && - !toClone.hasTrait()) { - // Perhaps this is brittle? The for loop intends to copy all ops that - // are not exclusively used for pre-processing. Otherwise we would need - // to iterate over all uses of the ops. - continue; + SmallVector clonedEncodes; + funcOp.walk([&](PlaintextEncodeOpInterface clonedOp) { + clonedEncodes.push_back(clonedOp); + }); + + for (Operation* clonedEncode : clonedEncodes) { + auto siteIdAttr = clonedEncode->getAttrOfType( + "split_preprocessing_site_id"); + assert( + siteIdAttr && + "Expected split_preprocessing_site_id attribute on cloned encode op"); + int32_t siteId = siteIdAttr.getInt(); + + SmallVector indices = + getContextualLoopIndices(clonedEncode, funcOp.getOperation()); + builder.setInsertionPointAfter(clonedEncode); + Type resultTy = clonedEncode->getResult(0).getType(); + + if (!isAllowedPlaintextType(resultTy)) { + Location loc = clonedEncode->getLoc(); + funcOp.erase(); + return mlir::emitError(loc) + << "'lwe.rlwe_encode' op with result type " << resultTy + << " unsupported in split-preprocessing."; + } + + Type loadTy = getElementTypeOrSelf(resultTy); + auto loadOp = preprocessing::LoadOp::create( + builder, clonedEncode->getLoc(), loadTy, storageArg, indices, + builder.getI32IntegerAttr(siteId), TypeAttr::get(loadTy)); + Value loadedVal = loadOp.getResult(); + + // After isAllowedPlaintextType, only a tensor<1x!pt> is supported, so + // reconstruct it from the loaded value. + if (isa(resultTy)) { + loadedVal = tensor::FromElementsOp::create( + builder, clonedEncode->getLoc(), resultTy, ValueRange{loadedVal}); } - builder.clone(toClone, map); + + clonedEncode->getResult(0).replaceAllUsesWith(loadedVal); + clonedEncode->erase(); } + return funcOp; } - PreprocessingAnalysis analyzePreprocessing(FuncOp funcOp, - int maxReturnValues) { + PreprocessingAnalysis analyzePreprocessing(FuncOp funcOp) { PreprocessingAnalysis analysis; - SetVector encodeOps; - for (auto encodeOp : funcOp.getBody().getOps()) { - encodeOps.insert(encodeOp); - } - if (encodeOps.empty()) { - return analysis; - } - - // Group into batches to ensure that the number of new return values doesn't - // exceed the limit. - int maxBucketSize = 0; - DenseMap> encodeOpsByType; - for (const auto& encodeOp : encodeOps) { - Type type = encodeOp->getResult(0).getType(); - if (!encodeOpsByType.contains(type)) { - encodeOpsByType[type] = {}; - } - encodeOpsByType[type].push_back(encodeOp); - if (encodeOpsByType[type].size() > maxBucketSize) { - maxBucketSize = encodeOpsByType[type].size(); + funcOp.walk([&](Operation* op) { + if (isa(op)) { + analysis.encodeOps.push_back(op); } - } - // If there are more types than the max return values, then we have to bail - // since we can't group values with different types into a tensor. - if (encodeOpsByType.size() > maxReturnValues) { - LLVM_DEBUG({ - llvm::dbgs() << "Too many types of encode ops in " << funcOp.getName() - << "\n"; - }); + }); + if (analysis.encodeOps.empty()) { return analysis; } - int numRemainingBuckets = maxReturnValues - encodeOpsByType.size(); - DenseMap numBucketsByType; - for (const auto& [type, encodeOps] : encodeOpsByType) { - numBucketsByType[type] = 1; - } - while (numRemainingBuckets > 0 && maxBucketSize != 1) { - // Add a bucket to the first largest bucket. - int largestBucketSize = 0; - Type maxBucketType; - for (const auto& [type, numOps] : encodeOpsByType) { - int buckets = numBucketsByType[type]; - auto bucketSize = ceilDiv(numOps.size(), buckets); - if (bucketSize > largestBucketSize) { - maxBucketType = type; - largestBucketSize = bucketSize; - } - } - numBucketsByType[maxBucketType]++; - numRemainingBuckets -= 1; - maxBucketSize = ceilDiv(encodeOpsByType[maxBucketType].size(), - numBucketsByType[maxBucketType]); - } + analysis.opsToClone = + computeDependenciesToClone(analysis.encodeOps, funcOp); - for (const auto& [type, encodeOps] : encodeOpsByType) { - int bucketSize = ceilDiv(encodeOps.size(), numBucketsByType[type]); - // The bucket size is the ceiling division, so it's the size of the - // largest bucket in the group. So eagerly group as many bucketSize - // buckets as possible to minimize the number of buckets. - int numBuckets = ceilDiv(encodeOps.size(), bucketSize); - LLVM_DEBUG({ - llvm::dbgs() << "adding " << numBuckets << " buckets for type " << type - << " with " << encodeOps.size() << " encode ops\n"; - }); - for (int i = 0; i < numBuckets; i++) { - int start = i * bucketSize; - int end = start + bucketSize; - SetVector batchOps; - for (int j = start; j < end && j < encodeOps.size(); j++) { - batchOps.insert(encodeOps[j]); + // Gather any required block arguments for inputs. These are all expected to + // be args of the func.func, but we filter out Ciphertext types. + for (auto* op : analysis.opsToClone) { + for (auto arg : op->getOperands()) { + auto argOp = arg.getDefiningOp(); + if (analysis.opsToClone.count(argOp) || + isa(getElementTypeOrSelf(arg.getType()))) { + continue; } - analysis.encodeOps.push_back(batchOps); - } - } - analysis.opsToClone = getBackwardSlice(encodeOps, [&](Operation* op) { - return op->getParentRegion() == &funcOp.getRegion(); - }); + if (auto blockArg = dyn_cast(arg)) { + Operation* parentOp = blockArg.getOwner()->getParentOp(); + if (analysis.opsToClone.contains(parentOp)) { + continue; + } + + if (parentOp != funcOp.getOperation()) { + parentOp->emitWarning() + << "split-preprocessing identified a block argument input " + << " that should be cloned into the preprocessing function, " + "but it was not a block argument of the parent func " + << funcOp.getName() << ". The input was " << blockArg; + } + } - // Gather any required block arguments for inputs. - for (auto* op : analysis.opsToClone) { - for (auto arg : op->getOperands()) { - auto argOp = arg.getDefiningOp(); // may be null - if (!analysis.opsToClone.count(argOp)) analysis.inputs.insert(arg); + analysis.inputs.insert(arg); } } - LLVM_DEBUG({ - llvm::dbgs() << "Adding inputs for preprocessing:\n"; - for (auto input : analysis.inputs) { - llvm::dbgs() << "\t - " << input << "\n"; - } - }); - return analysis; } }; } // namespace -std::unique_ptr createSplitPreprocessingPass() { - return std::make_unique(); -} - } // namespace heir } // namespace mlir diff --git a/lib/Transforms/SplitPreprocessing/SplitPreprocessing.td b/lib/Transforms/SplitPreprocessing/SplitPreprocessing.td index a105dc4abd..eddb95928d 100644 --- a/lib/Transforms/SplitPreprocessing/SplitPreprocessing.td +++ b/lib/Transforms/SplitPreprocessing/SplitPreprocessing.td @@ -25,9 +25,9 @@ def SplitPreprocessing : Pass<"split-preprocessing", "mlir::ModuleOp"> { (* example filepath=tests/Transforms/split_preprocessing/split_preprocessing.mlir *) }]; - let options = [ - Option<"maxReturnValues", "max-return-values", "int", /*default=*/"8", "Use " - "this to restrict the maximum return values of the preprocessing function.">, + let dependentDialects = [ + "mlir::heir::preprocessing::PreprocessingDialect", + "mlir::func::FuncDialect" ]; } diff --git a/tests/Dialect/Preprocessing/Transforms/validate_errors.mlir b/tests/Dialect/Preprocessing/Transforms/validate_errors.mlir index e19820b6ff..472370c7dd 100644 --- a/tests/Dialect/Preprocessing/Transforms/validate_errors.mlir +++ b/tests/Dialect/Preprocessing/Transforms/validate_errors.mlir @@ -3,7 +3,7 @@ func.func @multiple_allocations() { // expected-note@+1 {{previous allocation here}} %storage1 = preprocessing.empty : !preprocessing.storage - // expected-error@+1 {{more than one preprocessing.empty allocation in module}} + // expected-error@+1 {{more than one preprocessing.empty allocation in function}} %storage2 = preprocessing.empty : !preprocessing.storage return } @@ -121,3 +121,16 @@ func.func @mismatched_storage_type(%arg0: i32, %arg1: index, %storage1: !preproc %res = preprocessing.load %storage2[%arg1] site 0 : !preprocessing.storage, i32 return } + +// ----- + +// Multiple allocations are allowed if they are in different functions. +func.func @multiple_allocations_ok_1() { + %storage = preprocessing.empty : !preprocessing.storage + return +} + +func.func @multiple_allocations_ok_2() { + %storage = preprocessing.empty : !preprocessing.storage + return +} diff --git a/tests/Dialect/Secret/Conversions/secret_to_bgv/hamming_distance_1024.mlir b/tests/Dialect/Secret/Conversions/secret_to_bgv/hamming_distance_1024.mlir index c0592dcf51..ea52b3170c 100644 --- a/tests/Dialect/Secret/Conversions/secret_to_bgv/hamming_distance_1024.mlir +++ b/tests/Dialect/Secret/Conversions/secret_to_bgv/hamming_distance_1024.mlir @@ -1,14 +1,20 @@ -// RUN: heir-opt --annotate-module="backend=lattigo" --mlir-to-bgv=split-preprocessing=0 %s | FileCheck %s +// RUN: heir-opt --annotate-module="backend=lattigo" --mlir-to-bgv=enable-split-preprocessing %s | FileCheck %s -// CHECK: @hamming +// CHECK: @hamming__preprocessing +// CHECK: lwe.rlwe_encode + +// CHECK: @hamming__preprocessed // CHECK: bgv.sub // CHECK: bgv.mul // CHECK: bgv.relinearize // CHECK-COUNT-10: bgv.rotate_cols // CHECK: bgv.modulus_switch -// CHECK: lwe.rlwe_encode // CHECK: bgv.mul_plain // CHECK: bgv.modulus_switch + +// CHECK: @hamming +// CHECK: call @hamming__preprocessing +// CHECK: call @hamming__preprocessed // CHECK: return func.func @hamming(%arg0: !secret.secret>, %arg1: !secret.secret>) -> !secret.secret { diff --git a/tests/Dialect/Secret/Conversions/secret_to_ckks/hamming_distance_1024.mlir b/tests/Dialect/Secret/Conversions/secret_to_ckks/hamming_distance_1024.mlir index 7ef890a52b..b7bc210848 100644 --- a/tests/Dialect/Secret/Conversions/secret_to_ckks/hamming_distance_1024.mlir +++ b/tests/Dialect/Secret/Conversions/secret_to_ckks/hamming_distance_1024.mlir @@ -1,14 +1,20 @@ -// RUN: heir-opt --annotate-module="backend=lattigo" --mlir-to-ckks=split-preprocessing=0 %s | FileCheck %s +// RUN: heir-opt --annotate-module="backend=lattigo" --mlir-to-ckks=enable-split-preprocessing %s | FileCheck %s -// CHECK: @hamming +// CHECK: @hamming__preprocessing +// CHECK: lwe.rlwe_encode + +// CHECK: @hamming__preprocessed // CHECK: ckks.sub // CHECK: ckks.mul // CHECK: ckks.relinearize // CHECK-COUNT-10: ckks.rotate // CHECK: ckks.rescale -// CHECK: lwe.rlwe_encode // CHECK: ckks.mul_plain // CHECK: ckks.rescale + +// CHECK: @hamming +// CHECK: call @hamming__preprocessing +// CHECK: call @hamming__preprocessed // CHECK: return func.func @hamming(%arg0: !secret.secret>, %arg1: !secret.secret>) -> !secret.secret { diff --git a/tests/Examples/lattigo/ckks/batchnorm1d/BUILD b/tests/Examples/lattigo/ckks/batchnorm1d/BUILD index 22a7bd1390..500c7148c8 100644 --- a/tests/Examples/lattigo/ckks/batchnorm1d/BUILD +++ b/tests/Examples/lattigo/ckks/batchnorm1d/BUILD @@ -9,7 +9,7 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", "--linalg-canonicalizations", - "--torch-linalg-to-ckks=ciphertext-degree=2048 split-preprocessing=1", + "--torch-linalg-to-ckks=ciphertext-degree=2048", "--scheme-to-lattigo", ], mlir_src = "@heir//tests/Examples/common:batchnorm1d.mlir", diff --git a/tests/Examples/lattigo/ckks/bicyclic_matmul/BUILD b/tests/Examples/lattigo/ckks/bicyclic_matmul/BUILD index fadb1ab273..e8a0ed5fbd 100644 --- a/tests/Examples/lattigo/ckks/bicyclic_matmul/BUILD +++ b/tests/Examples/lattigo/ckks/bicyclic_matmul/BUILD @@ -7,7 +7,7 @@ heir_lattigo_lib( go_library_name = "bicyclicmatmul", heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", - "--mlir-to-ckks=ciphertext-degree=2048 split-preprocessing=0", + "--mlir-to-ckks=ciphertext-degree=2048 enable-split-preprocessing=false", "--scheme-to-lattigo", ], mlir_src = "@heir//tests/Examples/common:bicyclic_matmul.mlir", diff --git a/tests/Examples/lattigo/ckks/conv1d_dilated/BUILD b/tests/Examples/lattigo/ckks/conv1d_dilated/BUILD index 38aa6d6061..ccda7c5223 100644 --- a/tests/Examples/lattigo/ckks/conv1d_dilated/BUILD +++ b/tests/Examples/lattigo/ckks/conv1d_dilated/BUILD @@ -8,7 +8,7 @@ heir_lattigo_lib( go_library_name = "conv1d_dilated", heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", - "--torch-linalg-to-ckks=ciphertext-degree=1024 scaling-mod-bits=45 first-mod-bits=60 split-preprocessing=1", + "--torch-linalg-to-ckks=ciphertext-degree=1024 scaling-mod-bits=45 first-mod-bits=60", "--scheme-to-lattigo", ], mlir_src = "conv1d_dilated.mlir", diff --git a/tests/Examples/lattigo/ckks/conv2d_dilated/BUILD b/tests/Examples/lattigo/ckks/conv2d_dilated/BUILD index 6e88f8441c..a1f8913072 100644 --- a/tests/Examples/lattigo/ckks/conv2d_dilated/BUILD +++ b/tests/Examples/lattigo/ckks/conv2d_dilated/BUILD @@ -8,7 +8,7 @@ heir_lattigo_lib( go_library_name = "conv2d_dilated", heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", - "--torch-linalg-to-ckks=ciphertext-degree=1024 scaling-mod-bits=45 first-mod-bits=60 split-preprocessing=1", + "--torch-linalg-to-ckks=ciphertext-degree=1024 scaling-mod-bits=45 first-mod-bits=60", "--scheme-to-lattigo", ], mlir_src = "conv2d_dilated.mlir", diff --git a/tests/Examples/lattigo/ckks/matvec_512x784/BUILD b/tests/Examples/lattigo/ckks/matvec_512x784/BUILD index 0055d774a0..8b6f5e440e 100644 --- a/tests/Examples/lattigo/ckks/matvec_512x784/BUILD +++ b/tests/Examples/lattigo/ckks/matvec_512x784/BUILD @@ -8,8 +8,7 @@ heir_lattigo_lib( go_library_name = "matvec512x784", heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", - # TODO(#2960): remove split-preprocessing=0 once split-preprocessing handles loops - "--mlir-to-ckks=ciphertext-degree=1024 modulus-switch-after-mul=true experimental-disable-loop-unroll=true level-budget=40 first-mod-bits=55 split-preprocessing=0", + "--mlir-to-ckks=ciphertext-degree=1024 modulus-switch-after-mul=true experimental-disable-loop-unroll=true level-budget=40 first-mod-bits=55", "--scheme-to-lattigo", ], mlir_src = "@heir//tests/Examples/common:matvec_512x784.mlir", diff --git a/tests/Examples/lattigo/ckks/mnist/BUILD b/tests/Examples/lattigo/ckks/mnist/BUILD index a583512ae6..8a32578738 100644 --- a/tests/Examples/lattigo/ckks/mnist/BUILD +++ b/tests/Examples/lattigo/ckks/mnist/BUILD @@ -8,8 +8,8 @@ heir_lattigo_lib( go_library_name = "mnist", heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", - # TODO(#2960): remove split-preprocessing=0 once split-preprocessing handles loops - "--torch-linalg-to-ckks=ciphertext-degree=1024 modulus-switch-after-mul=true experimental-disable-loop-unroll=true level-budget=40 first-mod-bits=55 split-preprocessing=0", + # TODO(#2960): remove enable-split-preprocessing=false once split-preprocessing handles loops + "--torch-linalg-to-ckks=ciphertext-degree=1024 modulus-switch-after-mul=true experimental-disable-loop-unroll=true level-budget=40 first-mod-bits=55 enable-split-preprocessing=false", "--scheme-to-lattigo", ], mlir_src = "@heir//tests/Examples/common/mnist:mnist.mlir", diff --git a/tests/Examples/lattigo/ckks/pooling/BUILD b/tests/Examples/lattigo/ckks/pooling/BUILD index a098efbbd2..9c95a458e1 100644 --- a/tests/Examples/lattigo/ckks/pooling/BUILD +++ b/tests/Examples/lattigo/ckks/pooling/BUILD @@ -8,7 +8,7 @@ heir_lattigo_lib( go_library_name = "pooling", heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", - "--torch-linalg-to-ckks=ciphertext-degree=4096 scaling-mod-bits=55 first-mod-bits=60 split-preprocessing=1", + "--torch-linalg-to-ckks=ciphertext-degree=4096 scaling-mod-bits=55 first-mod-bits=60", "--scheme-to-lattigo", ], mlir_src = "pooling.mlir", diff --git a/tests/Examples/lattigo/ckks/pooling1d/BUILD b/tests/Examples/lattigo/ckks/pooling1d/BUILD index 4c20ee2d91..7f66faac9b 100644 --- a/tests/Examples/lattigo/ckks/pooling1d/BUILD +++ b/tests/Examples/lattigo/ckks/pooling1d/BUILD @@ -8,7 +8,7 @@ heir_lattigo_lib( go_library_name = "pooling1d", heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", - "--torch-linalg-to-ckks=ciphertext-degree=4096 scaling-mod-bits=55 first-mod-bits=60 split-preprocessing=1", + "--torch-linalg-to-ckks=ciphertext-degree=4096 scaling-mod-bits=55 first-mod-bits=60", "--scheme-to-lattigo", ], mlir_src = "pooling1d.mlir", diff --git a/tests/Examples/lattigo/ckks/preprocessing/BUILD b/tests/Examples/lattigo/ckks/preprocessing/BUILD index 5650dcf785..aecad78582 100644 --- a/tests/Examples/lattigo/ckks/preprocessing/BUILD +++ b/tests/Examples/lattigo/ckks/preprocessing/BUILD @@ -8,7 +8,7 @@ heir_lattigo_lib( go_library_name = "matvec", heir_opt_flags = [ "--annotate-module=scheme=ckks backend=lattigo", - "--torch-linalg-to-ckks=split-preprocessing=2", + "--torch-linalg-to-ckks", "--scheme-to-lattigo", ], mlir_src = "matvec.mlir", diff --git a/tests/Examples/lattigo/ckks/preprocessing/matvec_test.go b/tests/Examples/lattigo/ckks/preprocessing/matvec_test.go index fad308a8fa..58dca08494 100644 --- a/tests/Examples/lattigo/ckks/preprocessing/matvec_test.go +++ b/tests/Examples/lattigo/ckks/preprocessing/matvec_test.go @@ -34,11 +34,11 @@ func TestMatvecSplit(t *testing.T) { ct0 := matvec__encrypt__arg0(evaluator, params, ecd, enc, arg0) // Call preprocessing separately - v2, v3 := matvec_utils.Matvec__preprocessing(params, ecd, arg0) + storage := matvec_utils.Matvec__preprocessing(params, ecd, arg0) startTime := time.Now() // Call preprocessed function separately - resultCt := matvec__preprocessed(evaluator, params, ecd, ct0, v2, v3) + resultCt := matvec__preprocessed(evaluator, params, ecd, ct0, arg0, storage) duration := time.Since(startTime) fmt.Printf("matvec__preprocessed call took: %v\n", duration) diff --git a/tests/Examples/openfhe/ckks/batchnorm_sigmoid/BUILD b/tests/Examples/openfhe/ckks/batchnorm_sigmoid/BUILD index 723e03f146..855e191bd9 100644 --- a/tests/Examples/openfhe/ckks/batchnorm_sigmoid/BUILD +++ b/tests/Examples/openfhe/ckks/batchnorm_sigmoid/BUILD @@ -8,7 +8,7 @@ openfhe_end_to_end_test( heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=ckks", "--activation-canonicalizations", - "--mlir-to-ckks=ciphertext-degree=1024 split-preprocessing=0", + "--mlir-to-ckks=ciphertext-degree=1024 enable-split-preprocessing=false", "--scheme-to-openfhe", ], heir_translate_flags = [], diff --git a/tests/Examples/openfhe/ckks/halevi_shoup_matvec/BUILD b/tests/Examples/openfhe/ckks/halevi_shoup_matvec/BUILD index 2c4e4a4a7e..6804790194 100644 --- a/tests/Examples/openfhe/ckks/halevi_shoup_matvec/BUILD +++ b/tests/Examples/openfhe/ckks/halevi_shoup_matvec/BUILD @@ -27,7 +27,7 @@ openfhe_interpreter_test( heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=ckks", # Split preprocessing is not supported with the openfhe interpreter backend - "--torch-linalg-to-ckks=ciphertext-degree=8192 split-preprocessing=0", + "--torch-linalg-to-ckks=ciphertext-degree=8192 enable-split-preprocessing=false", "--scheme-to-openfhe", "--inline", ], diff --git a/tests/Examples/openfhe/ckks/preprocessing/BUILD b/tests/Examples/openfhe/ckks/preprocessing/BUILD index d160713e6c..659bd214a8 100644 --- a/tests/Examples/openfhe/ckks/preprocessing/BUILD +++ b/tests/Examples/openfhe/ckks/preprocessing/BUILD @@ -10,7 +10,7 @@ openfhe_lib( generated_lib_header = "mnist_openfhe_lib.inc.h", heir_opt_flags = [ "--annotate-module=backend=openfhe scheme=ckks", - "--torch-linalg-to-ckks=ciphertext-degree=1024 split-preprocessing=8", + "--torch-linalg-to-ckks=ciphertext-degree=1024", "--scheme-to-openfhe", ], mlir_src = "@heir//tests/Examples/common/mnist:mnist.mlir", diff --git a/tests/Regression/matvec_non_power_two_crash.mlir b/tests/Regression/matvec_non_power_two_crash.mlir index e405aa4836..3e51c66aeb 100644 --- a/tests/Regression/matvec_non_power_two_crash.mlir +++ b/tests/Regression/matvec_non_power_two_crash.mlir @@ -1,6 +1,5 @@ // RUN: heir-opt %s --annotate-module="backend=lattigo scheme=ckks" --mlir-to-ckks="ciphertext-degree=1024 modulus-switch-after-mul=true experimental-disable-loop-unroll=true level-budget=40 first-mod-bits=55" --scheme-to-lattigo -// TODO(#2960): remove XFAIL and update test appropriately -// XFAIL: * + module { func.func @matvec(%arg0 : tensor<33xf32> {secret.secret}, %matrix: tensor<32x33xf32>) -> tensor<32xf32> { diff --git a/tests/Transforms/split_preprocessing/args.mlir b/tests/Transforms/split_preprocessing/args.mlir index 1b9077604e..6135d67f39 100644 --- a/tests/Transforms/split_preprocessing/args.mlir +++ b/tests/Transforms/split_preprocessing/args.mlir @@ -3,14 +3,16 @@ // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L1:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @hoist_arg__preprocessing(%[[arg0:.*]]: tensor<1024xf32>) -> tensor<1x![[pt]]> +// CHECK: func.func @hoist_arg__preprocessing(%[[arg0:.*]]: tensor<1024xf32>) -> !preprocessing.storage // CHECK-SAME: client.pack_func = {func_name = "hoist_arg"} +// CHECK: func.func @hoist_arg__preprocessed(%[[CT:.*]]: ![[ct_L1]], %[[ARG0:.*]]: tensor<1024xf32>, %[[STORAGE:.*]]: !preprocessing.storage) + // CHECK: func.func @hoist_arg( // CHECK-SAME: %[[CT:.*]]: ![[ct_L1]], // CHECK-SAME: %[[ARG0:.*]]: tensor<1024xf32>) -// CHECK-NEXT: %[[PT:.*]] = call @hoist_arg__preprocessing(%[[ARG0]]) -// CHECK-NEXT: %[[CALL:.*]] = call @hoist_arg__preprocessed(%[[CT]], %[[PT]]) +// CHECK-NEXT: %[[STORAGE:.*]] = call @hoist_arg__preprocessing(%[[ARG0]]) +// CHECK-NEXT: %[[CALL:.*]] = call @hoist_arg__preprocessed(%[[CT]], %[[ARG0]], %[[STORAGE]]) // CHECK-NEXT: return %[[CALL]] // CHECK-NEXT: } @@ -41,19 +43,19 @@ func.func @hoist_arg(%ct: !ct_L1, %c1: tensor<1024xf32>) -> (!ct_L1) { // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L1:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @hoist_arg_and_constant__preprocessing(%[[arg0:.*]]: tensor<1024xf32>) -> (tensor<1x![[pt]]>, tensor<1x![[pt]]>) +// CHECK: func.func @hoist_arg_and_constant__preprocessing(%[[arg0:.*]]: tensor<1024xf32>) -> !preprocessing.storage // CHECK-SAME: client.pack_func = {func_name = "hoist_arg_and_constant"} -// CHECK: func.func @hoist_arg_and_constant__preprocessed(%[[arg0:.*]]: ![[ct_L1]], -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x![[pt]]>, -// CHECK-SAME: %[[ARG1:.*]]: tensor<1x![[pt]]>) -> ![[ct_L1]] +// CHECK: func.func @hoist_arg_and_constant__preprocessed(%[[CT:.*]]: ![[ct_L1]], +// CHECK-SAME: %[[ARG0:.*]]: tensor<1024xf32>, +// CHECK-SAME: %[[STORAGE:.*]]: !preprocessing.storage) -> ![[ct_L1]] // CHECK-SAME: client.preprocessed_func = {func_name = "hoist_arg_and_constant"} // CHECK: func.func @hoist_arg_and_constant // CHECK-SAME: (%[[CT:.*]]: ![[ct_L1]], // CHECK-SAME: %[[ARG0:.*]]: tensor<1024xf32>) -// CHECK-NEXT: %[[PT1:.*]]:2 = call @hoist_arg_and_constant__preprocessing(%[[ARG0]]) -// CHECK-NEXT: %[[CALL:.*]] = call @hoist_arg_and_constant__preprocessed(%[[CT]], %[[PT1]]#0, %[[PT1]]#1) +// CHECK-NEXT: %[[STORAGE:.*]] = call @hoist_arg_and_constant__preprocessing(%[[ARG0]]) +// CHECK-NEXT: %[[CALL:.*]] = call @hoist_arg_and_constant__preprocessed(%[[CT]], %[[ARG0]], %[[STORAGE]]) // CHECK-NEXT: return %[[CALL]] // CHECK-NEXT: } @@ -87,19 +89,20 @@ func.func @hoist_arg_and_constant(%ct: !ct_L1, %c1: tensor<1024xf32>) -> (!ct_L1 // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L1:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @hoist_with_computation__preprocessing(%[[arg0:.*]]: tensor<1x1024xf32>) -> tensor<1x![[pt]]> +// CHECK: func.func @hoist_with_computation__preprocessing(%[[arg0:.*]]: tensor<1x1024xf32>) -> !preprocessing.storage // CHECK-SAME: client.pack_func = {func_name = "hoist_with_computation"} -// CHECK-NEXT: tensor.extract_slice +// CHECK: tensor.extract_slice -// CHECK: func.func @hoist_with_computation__preprocessed(%[[arg0:.*]]: ![[ct_L1]], -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x![[pt]]>) +// CHECK: func.func @hoist_with_computation__preprocessed(%[[CT:.*]]: ![[ct_L1]], +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1024xf32>, +// CHECK-SAME: %[[STORAGE:.*]]: !preprocessing.storage) // CHECK-SAME: client.preprocessed_func = {func_name = "hoist_with_computation"} // CHECK: func.func @hoist_with_computation // CHECK-SAME: (%[[CT:.*]]: ![[ct_L1]], // CHECK-SAME: %[[ARG0:.*]]: tensor<1x1024xf32>) -// CHECK-NEXT: %[[PT:.*]] = call @hoist_with_computation__preprocessing(%[[ARG0]]) -// CHECK-NEXT: %[[CALL:.*]] = call @hoist_with_computation__preprocessed(%[[CT]], %[[PT]]) +// CHECK-NEXT: %[[STORAGE:.*]] = call @hoist_with_computation__preprocessing(%[[ARG0]]) +// CHECK-NEXT: %[[CALL:.*]] = call @hoist_with_computation__preprocessed(%[[CT]], %[[ARG0]], %[[STORAGE]]) // CHECK-NEXT: return %[[CALL]] // CHECK-NEXT: } diff --git a/tests/Transforms/split_preprocessing/double_use.mlir b/tests/Transforms/split_preprocessing/double_use.mlir index fc42d0b561..7e5e343fd4 100644 --- a/tests/Transforms/split_preprocessing/double_use.mlir +++ b/tests/Transforms/split_preprocessing/double_use.mlir @@ -4,10 +4,10 @@ // is duplicated. // CHECK: func.func @constant__preprocessing -// CHECK-NEXT: arith.constant 0 : index +// CHECK: arith.constant 0 : index // CHECK: func.func @constant__preprocessed -// CHECK-NEXT: arith.constant 0 : index +// CHECK: arith.constant 0 : index !Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> !Z36028797018652673_i64 = !mod_arith.int<36028797018652673 : i64> diff --git a/tests/Transforms/split_preprocessing/if_else_encode.mlir b/tests/Transforms/split_preprocessing/if_else_encode.mlir index 822149fe22..a409faa5ec 100644 --- a/tests/Transforms/split_preprocessing/if_else_encode.mlir +++ b/tests/Transforms/split_preprocessing/if_else_encode.mlir @@ -1,8 +1,13 @@ -// RUN: heir-opt %s --split-preprocessing='max-return-values=16' | FileCheck %s -// TODO(#2960): remove XFAIL and update test appropriately -// XFAIL: * - -// CHECK: func.func @if_else_encode__preprocessing +// RUN: heir-opt %s --split-preprocessing | FileCheck %s +// CHECK: func.func @if_else_encode__preprocessing( +// CHECK: %[[STORAGE:.*]] = preprocessing.empty +// CHECK: scf.if +// CHECK: %[[PT1:.*]] = lwe.rlwe_encode +// CHECK: preprocessing.store %[[PT1]], %[[STORAGE]][] site 0 +// CHECK: else +// CHECK: %[[PT2:.*]] = lwe.rlwe_encode +// CHECK: preprocessing.store %[[PT2]], %[[STORAGE]][] site 1 +// CHECK: return %[[STORAGE]] !Z35184371138561_i64 = !mod_arith.int<35184371138561 : i64> !Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> diff --git a/tests/Transforms/split_preprocessing/iv_dependent_encode.mlir b/tests/Transforms/split_preprocessing/iv_dependent_encode.mlir index 6553da122e..9d4dfc2132 100644 --- a/tests/Transforms/split_preprocessing/iv_dependent_encode.mlir +++ b/tests/Transforms/split_preprocessing/iv_dependent_encode.mlir @@ -1,8 +1,10 @@ -// RUN: heir-opt %s --split-preprocessing='max-return-values=16' | FileCheck %s -// TODO(#2960): remove XFAIL and update test appropriately -// XFAIL: * - -// CHECK: func.func @zip_matvec__preprocessing +// RUN: heir-opt %s --split-preprocessing | FileCheck %s +// CHECK: func.func @zip_matvec__preprocessing() -> !preprocessing.storage +// CHECK: %[[STORAGE:.*]] = preprocessing.empty +// CHECK: scf.for +// CHECK: %[[PT:.*]] = lwe.rlwe_encode +// CHECK: preprocessing.store %[[PT]], %[[STORAGE]][%{{.*}}] site 0 : !pt, +// CHECK: return %[[STORAGE]] !Z35184371138561_i64 = !mod_arith.int<35184371138561 : i64> !Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> diff --git a/tests/Transforms/split_preprocessing/linalg.mlir b/tests/Transforms/split_preprocessing/linalg.mlir index a3a2ae9356..991eaef153 100644 --- a/tests/Transforms/split_preprocessing/linalg.mlir +++ b/tests/Transforms/split_preprocessing/linalg.mlir @@ -6,18 +6,18 @@ // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L1:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @linalg__preprocessing() -> tensor<1x![[pt]]> +// CHECK: func.func @linalg__preprocessing() -> !preprocessing.storage // CHECK-SAME: client.pack_func = {func_name = "linalg"} // CHECK: linalg.broadcast // CHECK: lwe.rlwe_encode -// CHECK: func.func @linalg__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[arg0:.*]]: tensor<1x![[pt]]>) -> ![[ct_L1]] +// CHECK: func.func @linalg__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[STORAGE:.*]]: !preprocessing.storage) -> ![[ct_L1]] // CHECK-SAME: client.preprocessed_func = {func_name = "linalg"} // CHECK: func.func @linalg // CHECK-SAME: (%[[CT:.*]]: ![[ct_L1]]) -// CHECK-NEXT: %[[PT:.*]] = call @linalg__preprocessing() -// CHECK-NEXT: %[[CALL:.*]] = call @linalg__preprocessed(%[[CT]], %[[PT]]) : (![[ct_L1]], tensor<1x![[pt]]>) -> ![[ct_L1]] +// CHECK-NEXT: %[[STORAGE:.*]] = call @linalg__preprocessing() +// CHECK-NEXT: %[[CALL:.*]] = call @linalg__preprocessed(%[[CT]], %[[STORAGE]]) // CHECK-NEXT: return %[[CALL]] // CHECK-NEXT: } diff --git a/tests/Transforms/split_preprocessing/loop_encode.mlir b/tests/Transforms/split_preprocessing/loop_encode.mlir index bfee1f791c..261d1d2afb 100644 --- a/tests/Transforms/split_preprocessing/loop_encode.mlir +++ b/tests/Transforms/split_preprocessing/loop_encode.mlir @@ -1,8 +1,20 @@ -// RUN: heir-opt %s --split-preprocessing='max-return-values=16' | FileCheck %s -// TODO(#2960): remove XFAIL and update test appropriately -// XFAIL: * +// RUN: heir-opt %s --split-preprocessing | FileCheck %s // CHECK: func.func @matvec_loop_encode__preprocessing +// CHECK: %[[storage:.*]] = preprocessing.empty +// CHECK: scf.for %[[iv:.*]] = %c0 to %c4 step %c1 { +// CHECK: %[[pt:.*]] = lwe.rlwe_encode +// CHECK: preprocessing.store %[[pt]], %[[storage]][%[[iv]]] site 0 +// CHECK: } +// CHECK: return %[[storage]] : !preprocessing.storage + +// CHECK: func.func @matvec_loop_encode__preprocessed(%[[arg0:.*]]: tensor<1x!ct_L2>, %[[storageArg:.*]]: !preprocessing.storage) -> tensor<1x!ct_L2> +// CHECK: scf.for %[[iv:.*]] = %c0 to %c4 step %c1 iter_args(%[[arg2:.*]] = %[[arg0]]) -> (tensor<1x!ct_L2>) { +// CHECK: %[[loaded:.*]] = preprocessing.load %[[storageArg]][%[[iv]]] site 0 +// CHECK: %[[from_elements:.*]] = tensor.from_elements %[[loaded]] : tensor<1x!pt> +// CHECK: %[[add:.*]] = ckks.add_plain %[[from_elements]], %[[arg2]] +// CHECK: scf.yield %[[add]] +// CHECK: } !Z35184371138561_i64 = !mod_arith.int<35184371138561 : i64> !Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> diff --git a/tests/Transforms/split_preprocessing/many_plaintexts.mlir b/tests/Transforms/split_preprocessing/many_plaintexts.mlir index 0a41a1dca0..82273cdfca 100644 --- a/tests/Transforms/split_preprocessing/many_plaintexts.mlir +++ b/tests/Transforms/split_preprocessing/many_plaintexts.mlir @@ -6,18 +6,19 @@ // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L1:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @multiple__preprocessing(%[[arg0:.*]]: tensor<2x1024xf32>) -> (tensor<1x![[pt]]>, tensor<1x![[pt]]>) +// CHECK: func.func @multiple__preprocessing(%[[arg0:.*]]: tensor<2x1024xf32>) -> !preprocessing.storage // CHECK-SAME: client.pack_func = {func_name = "multiple"} +// CHECK: %[[STORAGE:.*]] = preprocessing.empty // CHECK-COUNT-2: lwe.rlwe_encode -// CHECK: func.func @multiple__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[arg0:.*]]: tensor<1x![[pt]]>, %[[arg1:.*]]: tensor<1x![[pt]]>) -> ![[ct_L1]] +// CHECK: func.func @multiple__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[arg0:.*]]: tensor<2x1024xf32>, %[[STORAGE:.*]]: !preprocessing.storage) -> ![[ct_L1]] // CHECK-SAME: client.preprocessed_func = {func_name = "multiple"} // CHECK: func.func @multiple // CHECK-SAME: (%[[CT:.*]]: ![[ct_L1]], // CHECK-SAME: %[[ARG0:.*]]: tensor<2x1024xf32>) -// CHECK-NEXT: %[[PT:.*]]:2 = call @multiple__preprocessing(%[[ARG0]]) -// CHECK-NEXT: %[[CALL:.*]] = call @multiple__preprocessed(%[[CT]], %[[PT]]#0, %[[PT]]#1) : (![[ct_L1]], tensor<1x![[pt]]>, tensor<1x![[pt]]>) -> ![[ct_L1]] +// CHECK-NEXT: %[[STORAGE:.*]] = call @multiple__preprocessing(%[[ARG0]]) +// CHECK-NEXT: %[[CALL:.*]] = call @multiple__preprocessed(%[[CT]], %[[ARG0]], %[[STORAGE]]) // CHECK-NEXT: return %[[CALL]] // CHECK-NEXT: } diff --git a/tests/Transforms/split_preprocessing/matvec_crash.mlir b/tests/Transforms/split_preprocessing/matvec_crash.mlir index 34c81eb79c..0efc9ca723 100644 --- a/tests/Transforms/split_preprocessing/matvec_crash.mlir +++ b/tests/Transforms/split_preprocessing/matvec_crash.mlir @@ -1,6 +1,4 @@ -// RUN: heir-opt %s --split-preprocessing='max-return-values=16' -// TODO(#2960): remove XFAIL and update test appropriately -// XFAIL: * +// RUN: heir-opt %s --split-preprocessing !Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> !Z36028797018652673_i64 = !mod_arith.int<36028797018652673 : i64> diff --git a/tests/Transforms/split_preprocessing/options.mlir b/tests/Transforms/split_preprocessing/options.mlir index 480b653efe..82273cdfca 100644 --- a/tests/Transforms/split_preprocessing/options.mlir +++ b/tests/Transforms/split_preprocessing/options.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --split-preprocessing=max-return-values=1 %s | FileCheck %s +// RUN: heir-opt --split-preprocessing %s | FileCheck %s // Tests that a collection of plaintexts created from a single block arg or // constant is returned as a single tensor. @@ -6,18 +6,19 @@ // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L1:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @multiple__preprocessing(%[[arg0:.*]]: tensor<2x1024xf32>) -> tensor<2x![[pt]]> +// CHECK: func.func @multiple__preprocessing(%[[arg0:.*]]: tensor<2x1024xf32>) -> !preprocessing.storage // CHECK-SAME: client.pack_func = {func_name = "multiple"} +// CHECK: %[[STORAGE:.*]] = preprocessing.empty // CHECK-COUNT-2: lwe.rlwe_encode -// CHECK: func.func @multiple__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[arg0:.*]]: tensor<2x![[pt]]>) -> ![[ct_L1]] +// CHECK: func.func @multiple__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[arg0:.*]]: tensor<2x1024xf32>, %[[STORAGE:.*]]: !preprocessing.storage) -> ![[ct_L1]] // CHECK-SAME: client.preprocessed_func = {func_name = "multiple"} // CHECK: func.func @multiple // CHECK-SAME: (%[[CT:.*]]: ![[ct_L1]], // CHECK-SAME: %[[ARG0:.*]]: tensor<2x1024xf32>) -// CHECK-NEXT: %[[PT:.*]] = call @multiple__preprocessing(%[[ARG0]]) -// CHECK-NEXT: %[[CALL:.*]] = call @multiple__preprocessed(%[[CT]], %[[PT]]) : (![[ct_L1]], tensor<2x![[pt]]>) -> ![[ct_L1]] +// CHECK-NEXT: %[[STORAGE:.*]] = call @multiple__preprocessing(%[[ARG0]]) +// CHECK-NEXT: %[[CALL:.*]] = call @multiple__preprocessed(%[[CT]], %[[ARG0]], %[[STORAGE]]) // CHECK-NEXT: return %[[CALL]] // CHECK-NEXT: } diff --git a/tests/Transforms/split_preprocessing/pure_preproc_affine_loop.mlir b/tests/Transforms/split_preprocessing/pure_preproc_affine_loop.mlir new file mode 100644 index 0000000000..ec9253efea --- /dev/null +++ b/tests/Transforms/split_preprocessing/pure_preproc_affine_loop.mlir @@ -0,0 +1,41 @@ +// RUN: heir-opt %s --split-preprocessing | FileCheck %s + +!Z35184371138561_i64 = !mod_arith.int<35184371138561 : i64> +!Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> +!Z36028797017456641_i64 = !mod_arith.int<36028797017456641 : i64> +#inverse_canonical_encoding = #lwe.inverse_canonical_encoding +#key = #lwe.key<> +#ring_f64_1_x1024 = #polynomial.ring> +!rns_L2 = !rns.rns +!pt = !lwe.lwe_plaintext> +#ring_rns_L2_1_x1024 = #polynomial.ring> +#ciphertext_space_L2 = #lwe.ciphertext_space +!ct_L2 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #lwe.modulus_chain, current = 2>> + +// CHECK: func.func @pure_preproc_affine_loop__preprocessing() -> !preprocessing.storage +// CHECK: %[[STORAGE:.*]] = preprocessing.empty +// CHECK: affine.for %[[I:.*]] = 0 to 4 +// CHECK: %[[PT:.*]] = lwe.rlwe_encode +// CHECK: preprocessing.store %[[PT]], %[[STORAGE]][%[[I]]] site 0 : !pt, +// CHECK: return %[[STORAGE]] + +module attributes {backend.openfhe, ckks.schemeParam = #ckks.scheme_param, scheme.ckks} { + func.func @pure_preproc_affine_loop(%arg0: tensor<1x!ct_L2>) -> tensor<1x!ct_L2> { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<1.0> : tensor<4x1024xf32> + %empty = tensor.empty() : tensor<4x!pt> + + %0 = affine.for %arg1 = 0 to 4 iter_args(%arg2 = %empty) -> (tensor<4x!pt>) { + %extracted_slice = tensor.extract_slice %cst[%arg1, 0] [1, 1024] [1, 1] : tensor<4x1024xf32> to tensor<1024xf32> + %pt = lwe.rlwe_encode %extracted_slice {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %inserted = tensor.insert %pt into %arg2[%arg1] : tensor<4x!pt> + affine.yield %inserted : tensor<4x!pt> + } + + %extracted_0 = tensor.extract %0[%c0] : tensor<4x!pt> + %from_elements = tensor.from_elements %extracted_0 : tensor<1x!pt> + %1 = ckks.add_plain %from_elements, %arg0 : (tensor<1x!pt>, tensor<1x!ct_L2>) -> tensor<1x!ct_L2> + + return %1 : tensor<1x!ct_L2> + } +} diff --git a/tests/Transforms/split_preprocessing/pure_preproc_loop.mlir b/tests/Transforms/split_preprocessing/pure_preproc_loop.mlir index 4c4149a66f..ead01ba0eb 100644 --- a/tests/Transforms/split_preprocessing/pure_preproc_loop.mlir +++ b/tests/Transforms/split_preprocessing/pure_preproc_loop.mlir @@ -1,8 +1,10 @@ -// RUN: heir-opt %s --split-preprocessing='max-return-values=16' | FileCheck %s -// TODO(#2960): remove XFAIL and update test appropriately -// XFAIL: * - -// CHECK: func.func @pure_preproc_loop__preprocessing +// RUN: heir-opt %s --split-preprocessing | FileCheck %s +// CHECK: func.func @pure_preproc_loop__preprocessing() -> !preprocessing.storage +// CHECK: %[[STORAGE:.*]] = preprocessing.empty +// CHECK: scf.for +// CHECK: %[[PT:.*]] = lwe.rlwe_encode +// CHECK: preprocessing.store %[[PT]], %[[STORAGE]][%{{.*}}] site 0 : !pt, +// CHECK: return %[[STORAGE]] !Z35184371138561_i64 = !mod_arith.int<35184371138561 : i64> !Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> diff --git a/tests/Transforms/split_preprocessing/region.mlir b/tests/Transforms/split_preprocessing/region.mlir index 48f570017f..82994cd08a 100644 --- a/tests/Transforms/split_preprocessing/region.mlir +++ b/tests/Transforms/split_preprocessing/region.mlir @@ -6,15 +6,15 @@ // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L2:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @region__preprocessing() -> tensor<1x![[pt]]> -// CHECK: %[[cst:.*]] arith.constant dense_resource +// CHECK: func.func @region__preprocessing() -> !preprocessing.storage +// CHECK: %[[cst:.*]] = arith.constant dense_resource // CHECK: affine.for // CHECK: call @_assign_layout // CHECK: return // CHECK: func.func @region__preprocessed( // CHECK-SAME: %[[arg0:.*]]: tensor<1x![[ct_L2]]>, -// CHECK-SAME: %[[pt0:.*]]: tensor<1x![[pt]]>) +// CHECK-SAME: %[[STORAGE:.*]]: !preprocessing.storage) !Z35184371138561_i64 = !mod_arith.int<35184371138561 : i64> !Z35184372121601_i64 = !mod_arith.int<35184372121601 : i64> diff --git a/tests/Transforms/split_preprocessing/split_preprocessing.mlir b/tests/Transforms/split_preprocessing/split_preprocessing.mlir index 1812593377..270eb4fe08 100644 --- a/tests/Transforms/split_preprocessing/split_preprocessing.mlir +++ b/tests/Transforms/split_preprocessing/split_preprocessing.mlir @@ -3,19 +3,19 @@ // CHECK-DAG: ![[pt:.*]] = !lwe.lwe_plaintext // CHECK-DAG: ![[ct_L1:.*]] = !lwe.lwe_ciphertext -// CHECK: func.func @hoist_one_assign__preprocessing() -> tensor<1x![[pt]]> +// CHECK: func.func @hoist_one_assign__preprocessing() -> !preprocessing.storage // CHECK-SAME: client.pack_func = {func_name = "hoist_one_assign"} -// CHECK: func.func @hoist_one_assign__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[arg0:.*]]: tensor<1x![[pt]]>) -> ![[ct_L1]] +// CHECK: func.func @hoist_one_assign__preprocessed(%[[ct:.*]]: ![[ct_L1]], %[[arg0:.*]]: !preprocessing.storage) -> ![[ct_L1]] // CHECK-SAME: client.preprocessed_func = {func_name = "hoist_one_assign"} -// CHECK: %[[extracted:.*]] = tensor.extract %[[arg0]] -// CHECK: %[[CT_0:.*]] = ckks.add_plain %ct, %[[extracted]] +// CHECK: %[[LOAD:.*]] = preprocessing.load %[[arg0]][] site 0 +// CHECK: %[[CT_0:.*]] = ckks.add_plain %ct, %[[LOAD]] // CHECK: return %[[CT_0]] : ![[ct_L1]] // CHECK: func.func @hoist_one_assign // CHECK-SAME: (%[[CT:.*]]: ![[ct_L1]] -// CHECK-NEXT: %[[PT:.*]] = call @hoist_one_assign__preprocessing -// CHECK-NEXT: %[[CALL:.*]] = call @hoist_one_assign__preprocessed(%[[CT]], %[[PT]]) +// CHECK-NEXT: %[[STORAGE:.*]] = call @hoist_one_assign__preprocessing() +// CHECK-NEXT: %[[CALL:.*]] = call @hoist_one_assign__preprocessed(%[[CT]], %[[STORAGE]]) // CHECK-NEXT: return %[[CALL]] // CHECK-NEXT: }