From 8eae1c6d94cfe9965ced857057088c015713bf44 Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Tue, 10 Feb 2026 14:03:10 -0800 Subject: [PATCH] fix: fixes lattigo in place transform by ensuring that storage values keep level state invariant regression test for in-place issue #2635 PiperOrigin-RevId: 868309124 --- lib/Analysis/LevelAnalysis/LevelAnalysis.cpp | 35 ++---- lib/Dialect/HEIRInterfaces.td | 35 ++++++ lib/Dialect/Lattigo/IR/LattigoBGVOps.td | 8 +- lib/Dialect/Lattigo/IR/LattigoCKKSOps.td | 6 +- lib/Dialect/Lattigo/IR/LattigoOps.cpp | 4 + lib/Dialect/Lattigo/IR/LattigoRLWEOps.td | 6 +- .../Lattigo/Transforms/AllocToInPlace.cpp | 76 ++++++++++--- lib/Dialect/Lattigo/Transforms/BUILD | 4 +- lib/Dialect/Mgmt/IR/MgmtOps.cpp | 3 + lib/Dialect/Mgmt/IR/MgmtOps.td | 9 +- .../ArithmeticPipelineRegistration.cpp | 3 +- lib/Target/Lattigo/LattigoEmitter.cpp | 15 ++- lib/Utils/AllocToInPlaceUtils.h | 29 ++++- lib/Utils/BUILD | 2 + .../Transforms/alloc_to_in_place_levels.mlir | 23 ++++ .../alloc_to_inplace_dot_product.mlir | 10 +- .../alloc_to_inplace_multi_func.mlir | 19 +++- .../Dialect/Mgmt/Transforms/level_reduce.mlir | 13 +++ tests/Emitter/Lattigo/emit_lattigo.mlir | 16 +++ tests/Examples/lattigo/ckks/in_place/BUILD | 30 +++++ .../lattigo/ckks/in_place/in_place.mlir | 46 ++++++++ .../lattigo/ckks/in_place/in_place_test.go | 105 ++++++++++++++++++ 22 files changed, 432 insertions(+), 65 deletions(-) create mode 100644 tests/Dialect/Lattigo/Transforms/alloc_to_in_place_levels.mlir create mode 100644 tests/Dialect/Mgmt/Transforms/level_reduce.mlir create mode 100644 tests/Examples/lattigo/ckks/in_place/BUILD create mode 100644 tests/Examples/lattigo/ckks/in_place/in_place.mlir create mode 100644 tests/Examples/lattigo/ckks/in_place/in_place_test.go diff --git a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp index 7c80672939..75633099eb 100644 --- a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp +++ b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp @@ -50,21 +50,7 @@ static void debugLog(StringRef opName, ArrayRef operands, }); }; -LevelState transferForward(mgmt::ModReduceOp op, - ArrayRef operands) { - LevelState result = std::visit( - Overloaded{ - [](MaxLevel) -> LevelState { return LevelState(Invalid{}); }, - [](Uninit) -> LevelState { return LevelState(Invalid{}); }, - [](Invalid) -> LevelState { return LevelState(Invalid{}); }, - [](int val) -> LevelState { return LevelState(val + 1); }, - }, - operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("mod_reduce", operands, result)); - return result; -} - -LevelState transferForward(mgmt::LevelReduceOp op, +LevelState transferForward(ReducesLevelOpInterface op, ArrayRef operands) { LevelState result = std::visit( Overloaded{ @@ -72,15 +58,15 @@ LevelState transferForward(mgmt::LevelReduceOp op, [](Uninit) -> LevelState { return LevelState(Invalid{}); }, [](Invalid) -> LevelState { return LevelState(Invalid{}); }, [&](int val) -> LevelState { - return LevelState(val + (int)op.getLevelToDrop()); + return LevelState(val + op.getLevelsToDrop()); }, }, operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("level_reduce", operands, result)); + LLVM_DEBUG(debugLog("ReduceLevelOpInterface", operands, result)); return result; } -LevelState transferForward(mgmt::LevelReduceMinOp op, +LevelState transferForward(ReducesAllLevelsOpInterface op, ArrayRef operands) { LevelState result = std::visit( Overloaded{ @@ -92,11 +78,11 @@ LevelState transferForward(mgmt::LevelReduceMinOp op, [](int val) -> LevelState { return LevelState(MaxLevel{}); }, }, operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("level_reduce_min", operands, result)); + LLVM_DEBUG(debugLog("ReduceAllLevelsOpInterface", operands, result)); return result; } -LevelState transferForward(mgmt::BootstrapOp op, +LevelState transferForward(ResetsLevelOpInterface op, ArrayRef operands) { LevelState result = std::visit( Overloaded{ @@ -106,15 +92,18 @@ LevelState transferForward(mgmt::BootstrapOp op, [](int val) -> LevelState { return LevelState(0); }, }, operands[0]->getValue().get()); - LLVM_DEBUG(debugLog("bootstrap", operands, result)); + LLVM_DEBUG(debugLog("ResetsLevelOpInterface", operands, result)); return result; } LevelState deriveResultLevel(Operation* op, ArrayRef operands) { return llvm::TypeSwitch(*op) - .Case( + .Case( + [&](auto op) -> LevelState { return transferForward(op, operands); }) + .Case( + [&](auto op) -> LevelState { return transferForward(op, operands); }) + .Case( [&](auto op) -> LevelState { return transferForward(op, operands); }) .Default([&](auto& op) -> LevelState { LevelState result; diff --git a/lib/Dialect/HEIRInterfaces.td b/lib/Dialect/HEIRInterfaces.td index a6a30931bf..a760be0e3e 100644 --- a/lib/Dialect/HEIRInterfaces.td +++ b/lib/Dialect/HEIRInterfaces.td @@ -44,6 +44,41 @@ def ResetsMulDepthOpInterface : OpInterface<"ResetsMulDepthOpInterface"> { }]; } +def ResetsLevelOpInterface : OpInterface<"ResetsLevelOpInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + An interface that signals when an operation resets level + among its results, such as a `mgmt.bootstrap`. + }]; +} + +def ReducesLevelOpInterface : OpInterface<"ReducesLevelOpInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + An interface that signals when an operation reduces level + among its results, such as a `mgmt.mod_reduce` or `ckks.rescale`. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/"Return the number of levels to reduce by.", + /*retTy=*/"int", + /*methodName=*/"getLevelsToDrop", + /*args=*/(ins ), + /*body=*/[{}], + /*defaultBody=*/[{ return 1; }] + >, + ]; +} + +def ReducesAllLevelsOpInterface : OpInterface<"ReducesAllLevelsOpInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + An interface that signals when an operation reduces all level + among its results, such as a `mgmt.level_reduce_min`. + }]; +} + def LUTOpInterface : OpInterface<"LUTOpInterface"> { let cppNamespace = "::mlir::heir"; let description = [{ diff --git a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td index 42f57f6d77..f128b074c6 100644 --- a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td @@ -185,8 +185,8 @@ def Lattigo_BGVMulOp : Lattigo_BGVBinaryInPlaceOp<"mul", [IncreasesMulDepthOpInt }]; } -class Lattigo_BGVUnaryOp : - Lattigo_BGVOp { +class Lattigo_BGVUnaryOp traits = []> : + Lattigo_BGVOp { let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$input @@ -201,7 +201,7 @@ def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> { }]; } -def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> { +def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo BGV dialect. @@ -261,7 +261,7 @@ def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInPlaceOp<"relinearize"> { }]; } -def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale"> { +def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo BGV dialect. diff --git a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td index 015cbb8547..c0a4be23ab 100644 --- a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td @@ -234,7 +234,7 @@ def Lattigo_CKKSRelinearizeNewOp : Lattigo_CKKSUnaryOp<"relinearize_new"> { }]; } -def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new"> { +def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo CKKS dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo CKKS dialect. @@ -288,7 +288,7 @@ def Lattigo_CKKSRelinearizeOp : Lattigo_CKKSUnaryInPlaceOp<"relinearize"> { }]; } -def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale"> { +def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> { let summary = "Rescale a ciphertext in the Lattigo CKKS dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo CKKS dialect. @@ -326,7 +326,7 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [ let hasVerifier = 1; } -def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap", [ResetsMulDepthOpInterface]> { +def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> { let summary = "Bootstrap a ciphertext in the Lattigo CKKS dialect"; let description = [{ Bootstraps a ciphertext value in the Lattigo CKKS dialect. diff --git a/lib/Dialect/Lattigo/IR/LattigoOps.cpp b/lib/Dialect/Lattigo/IR/LattigoOps.cpp index 10336f9a85..49e55a9c49 100644 --- a/lib/Dialect/Lattigo/IR/LattigoOps.cpp +++ b/lib/Dialect/Lattigo/IR/LattigoOps.cpp @@ -47,6 +47,10 @@ LogicalResult RLWENewEncryptorOp::verify() { return success(); } +int RLWEDropLevelNewOp::getLevelsToDrop() { return getLevelToDrop(); } + +int RLWEDropLevelOp::getLevelsToDrop() { return getLevelToDrop(); } + LogicalResult BGVRotateColumnsNewOp::verify() { return containsExactlyOneOrEmitError(getOperation(), getDynamicShift(), getStaticShift()); diff --git a/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td b/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td index 6ad4129cb5..b3fbfbbee2 100644 --- a/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td @@ -122,7 +122,8 @@ def Lattigo_RLWEDecryptOp : Lattigo_RLWEOp<"decrypt", [Pure]> { let results = (outs Lattigo_RLWEPlaintext:$plaintext); } -def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new", [Pure]> { +def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new", + [Pure, DeclareOpInterfaceMethods]> { let summary = "Drop level of a ciphertext"; let arguments = (ins Lattigo_RLWEEvaluator:$evaluator, @@ -132,7 +133,8 @@ def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new", [Pure]> { let results = (outs Lattigo_RLWECiphertext:$output); } -def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InPlaceOpInterface]> { +def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", + [InPlaceOpInterface, DeclareOpInterfaceMethods]> { let summary = "Drop level of a ciphertext"; let description = [{ This operation drops the level of a ciphertext diff --git a/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp b/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp index e4c8d29ba4..055298806a 100644 --- a/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp +++ b/lib/Dialect/Lattigo/Transforms/AllocToInPlace.cpp @@ -2,36 +2,56 @@ #include +#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h" +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" #include "lib/Dialect/Lattigo/IR/LattigoOps.h" #include "lib/Dialect/Lattigo/IR/LattigoTypes.h" #include "lib/Utils/AllocToInPlaceUtils.h" -#include "mlir/include/mlir/Analysis/Liveness.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/Liveness.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project +#define DEBUG_TYPE "alloc-to-inplace" + namespace mlir { namespace heir { namespace lattigo { +namespace { + +// Sets the level of a potentially newly created value. +static inline void setValueToLevel(DataFlowSolver* solver, Value value, + int level) { + auto* lattice = solver->getOrCreateState(value); + lattice->getValue().setLevel(level); +} + +} // namespace + template struct ConvertBinOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertBinOp(mlir::MLIRContext* context, Liveness* liveness, + DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(BinOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -45,13 +65,15 @@ struct ConvertBinOp : public OpRewritePattern { // Update storage info, which must happen before the op is removed storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); - + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -60,16 +82,17 @@ struct ConvertUnaryOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertUnaryOp( - mlir::MLIRContext* context, Liveness* liveness, + mlir::MLIRContext* context, Liveness* liveness, DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(UnaryOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -82,12 +105,15 @@ struct ConvertUnaryOp : public OpRewritePattern { op.getOperand(0), op.getOperand(1), storage); storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -96,16 +122,17 @@ struct ConvertRotateOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertRotateOp( - mlir::MLIRContext* context, Liveness* liveness, + mlir::MLIRContext* context, Liveness* liveness, DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(RotateOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -132,12 +159,15 @@ struct ConvertRotateOp : public OpRewritePattern { // update storage info storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -146,16 +176,17 @@ struct ConvertDropLevelOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; ConvertDropLevelOp( - mlir::MLIRContext* context, Liveness* liveness, + mlir::MLIRContext* context, Liveness* liveness, DataFlowSolver* solver, DenseMap* blockToStorageInfo) : OpRewritePattern(context), liveness(liveness), + solver(solver), blockToStorageInfo(blockToStorageInfo) {} LogicalResult matchAndRewrite(DropLevelOp op, PatternRewriter& rewriter) const override { auto& storageInfo = (*blockToStorageInfo)[op->getBlock()]; - auto storage = storageInfo.getAvailableStorage(op, liveness); + auto storage = storageInfo.getAvailableStorage(op, liveness, solver); if (!storage) { return rewriter.notifyMatchFailure(op, "no available storage found"); } @@ -169,12 +200,15 @@ struct ConvertDropLevelOp : public OpRewritePattern { // update storage info storageInfo.replaceAllocWithInPlace(op, inplaceOp, storage); + setValueToLevel(solver, inplaceOp->getResult(0), + getLevel(storage, solver).value().getInt()); rewriter.replaceOp(op, inplaceOp); return success(); } private: Liveness* liveness; + DataFlowSolver* solver; DenseMap* blockToStorageInfo; }; @@ -185,6 +219,14 @@ struct AllocToInPlace : impl::AllocToInPlaceBase { using AllocToInPlaceBase::AllocToInPlaceBase; void runOnOperation() override { + DataFlowSolver solver; + dataflow::loadBaselineAnalyses(solver); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(getOperation()))) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + } Liveness liveness(getOperation()); MLIRContext* context = &getContext(); @@ -213,8 +255,8 @@ struct AllocToInPlace : impl::AllocToInPlaceBase { // RLWE ConvertUnaryOp, ConvertDropLevelOp>(context, &liveness, - &blockToStorageInfo); + lattigo::RLWEDropLevelOp>>( + context, &liveness, &solver, &blockToStorageInfo); // The greedy policy relies on the order of processing the operations. walkAndApplyPatterns(getOperation(), std::move(patterns)); diff --git a/lib/Dialect/Lattigo/Transforms/BUILD b/lib/Dialect/Lattigo/Transforms/BUILD index c4ec609e6f..4260bdd090 100644 --- a/lib/Dialect/Lattigo/Transforms/BUILD +++ b/lib/Dialect/Lattigo/Transforms/BUILD @@ -23,9 +23,11 @@ cc_library( hdrs = ["AllocToInPlace.h"], deps = [ ":pass_inc_gen", + "@heir//lib/Analysis/LevelAnalysis", + "@heir//lib/Analysis/SecretnessAnalysis", "@heir//lib/Dialect/Lattigo/IR:Dialect", "@heir//lib/Utils:AllocToInPlaceUtils", - "@heir//lib/Utils/Tablegen:InPlaceOpInterface", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/lib/Dialect/Mgmt/IR/MgmtOps.cpp b/lib/Dialect/Mgmt/IR/MgmtOps.cpp index 9b7764e5eb..8913dc18b5 100644 --- a/lib/Dialect/Mgmt/IR/MgmtOps.cpp +++ b/lib/Dialect/Mgmt/IR/MgmtOps.cpp @@ -1,5 +1,6 @@ #include "lib/Dialect/Mgmt/IR/MgmtOps.h" +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Mgmt/IR/MgmtPatterns.h" #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project @@ -50,6 +51,8 @@ void cleanupInitOp(Operation* top) { }); } +int LevelReduceOp::getLevelsToDrop() { return getLevelToDrop(); } + } // namespace mgmt } // namespace heir } // namespace mlir diff --git a/lib/Dialect/Mgmt/IR/MgmtOps.td b/lib/Dialect/Mgmt/IR/MgmtOps.td index 8b899bbe43..82f9095a63 100644 --- a/lib/Dialect/Mgmt/IR/MgmtOps.td +++ b/lib/Dialect/Mgmt/IR/MgmtOps.td @@ -15,7 +15,7 @@ class Mgmt_Op traits = []> : let cppNamespace = "::mlir::heir::mgmt"; } -def Mgmt_ModReduceOp : Mgmt_Op<"modreduce"> { +def Mgmt_ModReduceOp : Mgmt_Op<"modreduce", [ReducesLevelOpInterface]> { let summary = "Modulus switch the input ciphertext down by one limb (RNS assumed)"; let description = [{ @@ -35,7 +35,8 @@ def Mgmt_ModReduceOp : Mgmt_Op<"modreduce"> { let hasCanonicalizer = 1; } -def Mgmt_LevelReduceOp : Mgmt_Op<"level_reduce"> { +def Mgmt_LevelReduceOp : Mgmt_Op<"level_reduce", + [DeclareOpInterfaceMethods]> { let summary = "Reduce the level of input ciphertext by dropping the last k RNS limbs"; let description = [{ @@ -60,7 +61,7 @@ def Mgmt_LevelReduceOp : Mgmt_Op<"level_reduce"> { let hasCanonicalizer = 1; } -def Mgmt_LevelReduceMinOp : Mgmt_Op<"level_reduce_min"> { +def Mgmt_LevelReduceMinOp : Mgmt_Op<"level_reduce_min", [ReducesAllLevelsOpInterface]> { let summary = "Reduce the level of input ciphertext to the minimum level"; let description = [{ This scheme-agonistic operation reduces the ciphertext level @@ -99,7 +100,7 @@ def Mgmt_RelinearizeOp : Mgmt_Op<"relinearize"> { let assemblyFormat = "operands attr-dict `:` type($output)"; } -def Mgmt_BootstrapOp : Mgmt_Op<"bootstrap", [ResetsMulDepthOpInterface]> { +def Mgmt_BootstrapOp : Mgmt_Op<"bootstrap", [ResetsMulDepthOpInterface, ResetsLevelOpInterface]> { let summary = "Bootstrap the input ciphertext to refresh its noise budget"; let description = [{ diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index 909b8ae3cf..3d1bd478bf 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -510,8 +510,7 @@ BackendPipelineBuilder toLattigoPipelineBuilder() { pm.addPass(lwe::createLWEToLattigo()); // Convert Alloc Ops to InPlace Ops - // TODO(#2635): Disable until this is fixed. - // pm.addPass(lattigo::createAllocToInPlace()); + pm.addPass(lattigo::createAllocToInPlace()); // Simplify, in case the lowering revealed redundancy pm.addPass(createCanonicalizerPass()); diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index e20757e3dc..ef2885ebf2 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -1379,16 +1379,17 @@ LogicalResult LattigoEmitter::printOperation(RLWEDropLevelNewOp op) { // there is no DropLevelNew method in Lattigo BGV Evaluator, manually create // new ciphertext std::string resultName = getName(op.getOutput()); - os << resultName << " := " << getName(op.getInput()) << ".CopyNew()\n"; + emitAssignment(resultName, getName(op.getInput()) + ".CopyNew()"); os << getName(op.getEvaluator()) << ".DropLevel(" << resultName << ", " << op.getLevelToDrop() << ")\n"; return success(); } LogicalResult LattigoEmitter::printOperation(RLWEDropLevelOp op) { + // Check if we need to declare a new variable for the output. if (getName(op.getOutput()) != getName(op.getInput())) { std::string resultName = getName(op.getOutput()); - os << resultName << " := " << getName(op.getInput()) << ".CopyNew()\n"; + emitAssignment(resultName, getName(op.getInput()) + ".CopyNew()"); } os << getName(op.getEvaluator()) << ".DropLevel(" << getName(op.getOutput()) << ", " << op.getLevelToDrop() << ")\n"; @@ -1407,7 +1408,7 @@ LogicalResult LattigoEmitter::printOperation(RLWENegateNewOp op) { // there is no NegateNew method in Lattigo, manually create new // ciphertext std::string resultName = getName(op.getOutput()); - os << resultName << " := " << getName(op.getInput()) << ".CopyNew()\n"; + emitAssignment(resultName, getName(op.getInput()) + ".CopyNew()"); auto indexName = getName(op.getOutput()) + "_index"; auto res = llvm::formatv(negateTemplate, indexName, getName(op.getOutput()), getName(op.getEvaluator())); @@ -1586,7 +1587,7 @@ LogicalResult LattigoEmitter::printOperation(BGVRelinearizeNewOp op) { LogicalResult LattigoEmitter::printOperation(BGVRescaleNewOp op) { // there is no RescaleNew method in Lattigo, manually create new ciphertext std::string resultName = getName(op.getOutput()); - os << resultName << " := " << getName(op.getInput()) << ".CopyNew()\n"; + emitAssignment(resultName, getName(op.getInput()) + ".CopyNew()"); return printEvalInPlaceMethod( op.getEvaluator(), {op.getInput(), op.getOutput()}, "Rescale", true); } @@ -1883,7 +1884,7 @@ LogicalResult LattigoEmitter::printOperation(CKKSRelinearizeNewOp op) { LogicalResult LattigoEmitter::printOperation(CKKSRescaleNewOp op) { // there is no RescaleNew method in Lattigo, manually create new ciphertext std::string resultName = getName(op.getOutput()); - os << resultName << " := " << getName(op.getInput()) << ".CopyNew()\n"; + emitAssignment(resultName, getName(op.getInput()) + ".CopyNew()"); return printEvalInPlaceMethod( op.getEvaluator(), {op.getInput(), op.getOutput()}, "Rescale", true); } @@ -1924,6 +1925,10 @@ LogicalResult LattigoEmitter::printOperation(CKKSRescaleOp op) { } LogicalResult LattigoEmitter::printOperation(CKKSRotateOp op) { + auto inputName = getName(op.getInput()); + auto inplaceName = getName(op.getInplace()); + os << inplaceName << ".Resize(" << inputName << ".Degree()," << inputName + << ".Level())\n"; auto errName = getErrName(); os << errName << " := " << getName(op.getEvaluator()) << ".Rotate("; os << getName(op.getInput()) << ", "; diff --git a/lib/Utils/AllocToInPlaceUtils.h b/lib/Utils/AllocToInPlaceUtils.h index 026687cf6a..cdf293a0ad 100644 --- a/lib/Utils/AllocToInPlaceUtils.h +++ b/lib/Utils/AllocToInPlaceUtils.h @@ -1,3 +1,7 @@ +#include + +#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h" +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #ifndef LIB_UTILS_ALLOCTOINPLACEUTILS_H_ #define LIB_UTILS_ALLOCTOINPLACEUTILS_H_ @@ -21,6 +25,18 @@ namespace mlir { namespace heir { +static std::optional getLevel(Value value, DataFlowSolver* solver) { + auto* lattice = solver->lookupState(value); + if (!lattice || !lattice->getValue().isInitialized()) { + return std::nullopt; + } + auto latticeVal = lattice->getValue(); + if (!latticeVal.isInt()) { + return std::nullopt; + } + return lattice->getValue().getInt(); +} + // CallerProvidedStorageInfo provides an analysis of SSA values that // can be reused for in-place operations that require the caller to pass // in pre-allocated memory for the operation to use. @@ -104,11 +120,22 @@ class CallerProvidedStorageInfo { // various accelerators. One basic optimization is to use the dead value that // is closest to the current operation in the block. But as we do not have the // information of the memory layout, we do not implement this optimization. - Value getAvailableStorage(Operation* op, Liveness* liveness) const { + Value getAvailableStorage(Operation* op, Liveness* liveness, + DataFlowSolver* solver) const { LLVM_DEBUG(llvm::dbgs() << "getAvailableStorage for op " << op->getName() << "\n"); for (auto& [storage, values] : storageToReferringValues) { // storage and all referring values are dead + if (solver) { + auto opLevel = getLevel(op->getResult(0), solver); + auto storageLevel = getLevel(storage, solver); + if (!opLevel.has_value() || !storageLevel.has_value()) { + continue; + } + if (opLevel.value() != storageLevel.value()) { + continue; + } + } if (std::all_of( values.begin(), values.end(), [&](Value value) { return liveness->isDeadAfter(value, op); }) && diff --git a/lib/Utils/BUILD b/lib/Utils/BUILD index 2437d74b64..01e4fae38e 100644 --- a/lib/Utils/BUILD +++ b/lib/Utils/BUILD @@ -232,7 +232,9 @@ cc_library( srcs = ["AllocToInPlaceUtils.cpp"], hdrs = ["AllocToInPlaceUtils.h"], deps = [ + "@heir//lib/Analysis/LevelAnalysis", "@heir//lib/Utils/Tablegen:InPlaceOpInterface", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_in_place_levels.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_in_place_levels.mlir new file mode 100644 index 0000000000..6fe69c3ece --- /dev/null +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_in_place_levels.mlir @@ -0,0 +1,23 @@ +// RUN: heir-opt --lattigo-alloc-to-inplace %s | FileCheck %s + +// Use the minimum level level of the two operands for the result storage + +!evaluator = !lattigo.bgv.evaluator +!ct = !lattigo.rlwe.ciphertext + +// CHECK: ![[evaluator:.*]] = !lattigo.bgv.evaluator + +// CHECK: func.func @drop_level +// CHECK-SAME: %[[evaluator:.*]]: ![[evaluator]] +func.func @drop_level(%evaluator : !evaluator, %ct : !ct) -> !ct { + %ct_level_0 = lattigo.bgv.rotate_columns_new %evaluator, %ct {static_shift = 4} : (!evaluator, !ct) -> !ct + // CHECK: %[[ct_level_2:.*]] = lattigo.rlwe.drop_level_new + // CHECK-SAME: levelToDrop = 2 + // CHECK: %[[ct_level_4:.*]] = lattigo.rlwe.drop_level_new + // CHECK-SAME: levelToDrop = 4 + %0 = lattigo.rlwe.drop_level_new %evaluator, %ct { levelToDrop = 2 } : (!evaluator, !ct) -> !ct + %1 = lattigo.rlwe.drop_level_new %evaluator, %ct_level_0 { levelToDrop = 4 } : (!evaluator, !ct) -> !ct + // CHECK: lattigo.bgv.add %[[evaluator]], %[[ct_level_2]], %[[ct_level_4]], %[[ct_level_4]] + %2 = lattigo.bgv.add_new %evaluator, %0, %1 : (!evaluator, !ct, !ct) -> !ct + return %2 : !ct +} diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir index 481f73df66..de8729b2d2 100644 --- a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_dot_product.mlir @@ -3,7 +3,15 @@ // CHECK: func.func @dot_product func.func @dot_product(%evaluator: !lattigo.bgv.evaluator, %param: !lattigo.bgv.parameter, %encoder: !lattigo.bgv.encoder, %ct: !lattigo.rlwe.ciphertext, %ct_0: !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { // no new allocation found as the two ciphertexts in function argument are enough to store the imtermediate results - // CHECK-NOT: _new + // a new allocation is only needed for the rescale because of level change + // CHECK-NOT: mul_new + // CHECK-NOT: relinearize_new + // CHECK-NOT: rotate_columns_new + // CHECK-NOT: add_new + // CHECK: rescale_new + // CHECK-NOT: mul_new + // CHECK-NOT: rotate_columns_new + // CHECK: return %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir index db39d939d5..290bd7f103 100644 --- a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_multi_func.mlir @@ -12,7 +12,15 @@ module attributes {bgv.schemeParam = #bgv.scheme_param !ct attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { // no new allocation found as the two ciphertexts in function argument are enough to store the imtermediate results - // CHECK-NOT: _new + // a new allocation is only needed for the rescale because of level change + // CHECK-NOT: mul_new + // CHECK-NOT: relinearize_new + // CHECK-NOT: rotate_columns_new + // CHECK-NOT: add_new + // CHECK: rescale_new + // CHECK-NOT: mul_new + // CHECK-NOT: rotate_columns_new + // CHECK: return %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index @@ -37,7 +45,14 @@ module attributes {bgv.schemeParam = #bgv.scheme_param !ct attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { // no new allocation found as the two ciphertexts in function argument are enough to store the imtermediate results - // CHECK-NOT: _new + // CHECK-NOT: mul_new + // CHECK-NOT: relinearize_new + // CHECK-NOT: rotate_columns_new + // CHECK-NOT: add_new + // CHECK: rescale_new + // CHECK-NOT: mul_new + // CHECK-NOT: rotate_columns_new + // CHECK: return %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c4 = arith.constant 4 : index diff --git a/tests/Dialect/Mgmt/Transforms/level_reduce.mlir b/tests/Dialect/Mgmt/Transforms/level_reduce.mlir new file mode 100644 index 0000000000..31f7054ee5 --- /dev/null +++ b/tests/Dialect/Mgmt/Transforms/level_reduce.mlir @@ -0,0 +1,13 @@ +// RUN: heir-opt --annotate-mgmt %s | FileCheck %s + +func.func @main(%arg0: !secret.secret>) -> !secret.secret> { + // CHECK: secret.generic + // CHECK-SAME: level = 2 + %b = secret.generic(%arg0: !secret.secret>) { + ^body(%clear_a: tensor<8xi8>): + %c = mgmt.level_reduce %clear_a { levelToDrop = 2 }: tensor<8xi8> + secret.yield %c : tensor<8xi8> + // CHECK: } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> !secret.secret> + func.return %b : !secret.secret> +} diff --git a/tests/Emitter/Lattigo/emit_lattigo.mlir b/tests/Emitter/Lattigo/emit_lattigo.mlir index 2924d4d7c9..07bf4cd257 100644 --- a/tests/Emitter/Lattigo/emit_lattigo.mlir +++ b/tests/Emitter/Lattigo/emit_lattigo.mlir @@ -374,3 +374,19 @@ module attributes {scheme.bgv} { return %0 : tensor<4xi32> } } + +// ----- + +module attributes {scheme.bgv} { + // CHECK: func test_drop_level_inplace_declared + // CHECK-SAME: ([[evaluator:.*]] *bgv.Evaluator, [[ct:.*]] *rlwe.Ciphertext, [[alloc:.*]] []*rlwe.Ciphertext) + func.func @test_drop_level_inplace_declared(%evaluator: !lattigo.bgv.evaluator, %ct: !lattigo.rlwe.ciphertext, %alloc: memref<1x!lattigo.rlwe.ciphertext>) -> (!lattigo.rlwe.ciphertext) { + %c0 = arith.constant 0 : index + // CHECK: [[val:[^, ]*]] := [[alloc]][{{.*}}] + %val = memref.load %alloc[%c0] : memref<1x!lattigo.rlwe.ciphertext> + // CHECK: [[val]] = [[ct]].CopyNew() + // CHECK: [[evaluator]].DropLevel([[val]], 2) + %ct1 = lattigo.rlwe.drop_level %evaluator, %ct, %val {levelToDrop = 2}: (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext + return %ct1 : !lattigo.rlwe.ciphertext + } +} diff --git a/tests/Examples/lattigo/ckks/in_place/BUILD b/tests/Examples/lattigo/ckks/in_place/BUILD new file mode 100644 index 0000000000..9e93ad639c --- /dev/null +++ b/tests/Examples/lattigo/ckks/in_place/BUILD @@ -0,0 +1,30 @@ +# See README.md for setup required to run these tests + +load("@heir//tests/Examples/lattigo:test.bzl", "heir_lattigo_lib") +load("@rules_go//go:def.bzl", "go_test") + +package(default_applicable_licenses = ["@heir//:license"]) + +heir_lattigo_lib( + name = "in_place", + go_library_name = "main", + heir_opt_flags = [ + "--canonicalize", + "--cse", + "--scheme-to-lattigo", + ], + mlir_src = "in_place.mlir", +) + +# For Google-internal reasons we must separate the go_test rules from the macro +# above. + +go_test( + name = "in_place_test", + srcs = ["in_place_test.go"], + embed = [":main"], + deps = [ + "@com_github_tuneinsight_lattigo_v6//core/rlwe", + "@com_github_tuneinsight_lattigo_v6//schemes/ckks", + ], +) diff --git a/tests/Examples/lattigo/ckks/in_place/in_place.mlir b/tests/Examples/lattigo/ckks/in_place/in_place.mlir new file mode 100644 index 0000000000..892595f837 --- /dev/null +++ b/tests/Examples/lattigo/ckks/in_place/in_place.mlir @@ -0,0 +1,46 @@ +!Z1056763241666817029_i64 = !mod_arith.int<1056763241666817029 : i64> +!Z1106058412451299513_i64 = !mod_arith.int<1106058412451299513 : i64> +!Z957769724367225479_i64 = !mod_arith.int<957769724367225479 : i64> +#inverse_canonical_encoding = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding1 = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding2 = #lwe.inverse_canonical_encoding +#key = #lwe.key<> +#modulus_chain_L1_C1 = #lwe.modulus_chain, current = 1> +#modulus_chain_L2_C2 = #lwe.modulus_chain, current = 2> +#ring_f64_1_x131072 = #polynomial.ring> +!rns_L1 = !rns.rns +!rns_L2 = !rns.rns +!pt = !lwe.lwe_plaintext> +#ring_rns_L1_1_x131072 = #polynomial.ring> +#ring_rns_L2_1_x131072 = #polynomial.ring> +#ciphertext_space_L1 = #lwe.ciphertext_space +#ciphertext_space_L2 = #lwe.ciphertext_space +!ct_L1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1> +!ct_L2 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L2_C2> +!ct_L2_1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L2_C2> +module attributes {ckks.schemeParam = #ckks.scheme_param, scheme.ckks} { + func.func @in_place(%ct: !ct_L2) -> !ct_L1 { + %cst = arith.constant dense<0.000000e+00> : tensor<65536xf64> + %ct_0 = ckks.rotate %ct {static_shift = 0 : i32} : !ct_L2 + %pt = lwe.rlwe_encode %cst {encoding = #inverse_canonical_encoding1, ring = #ring_f64_1_x131072} : tensor<65536xf64> -> !pt + %ct_1 = ckks.mul_plain %ct_0, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_2 = ckks.rescale %ct_1 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_3 = ckks.rotate %ct {static_shift = 1 : i32} : !ct_L2 + %ct_4 = ckks.mul_plain %ct_3, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_5 = ckks.rescale %ct_4 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_6 = ckks.add %ct_2, %ct_5 : (!ct_L1, !ct_L1) -> !ct_L1 + %ct_7 = ckks.rotate %ct {static_shift = 2 : i32} : !ct_L2 + %ct_8 = ckks.mul_plain %ct_7, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_9 = ckks.rescale %ct_8 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_10 = ckks.add %ct_6, %ct_9 : (!ct_L1, !ct_L1) -> !ct_L1 + %ct_11 = ckks.rotate %ct {static_shift = 3 : i32} : !ct_L2 + %ct_12 = ckks.mul_plain %ct_11, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_13 = ckks.rescale %ct_12 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_14 = ckks.add %ct_10, %ct_13 : (!ct_L1, !ct_L1) -> !ct_L1 + %ct_15 = ckks.rotate %ct {static_shift = 4 : i32} : !ct_L2 + %ct_16 = ckks.mul_plain %ct_15, %pt : (!ct_L2, !pt) -> !ct_L2_1 + %ct_17 = ckks.rescale %ct_16 {to_ring = #ring_rns_L1_1_x131072} : !ct_L2_1 -> !ct_L1 + %ct_18 = ckks.add %ct_14, %ct_17 : (!ct_L1, !ct_L1) -> !ct_L1 + return %ct_18 : !ct_L1 + } +} diff --git a/tests/Examples/lattigo/ckks/in_place/in_place_test.go b/tests/Examples/lattigo/ckks/in_place/in_place_test.go new file mode 100644 index 0000000000..a7445e16aa --- /dev/null +++ b/tests/Examples/lattigo/ckks/in_place/in_place_test.go @@ -0,0 +1,105 @@ +package main + +import ( + "fmt" + "testing" + "time" + + "github.com/tuneinsight/lattigo/v6/core/rlwe" + "github.com/tuneinsight/lattigo/v6/schemes/ckks" +) + +// MakeFlattenedOnes creates a slice of float64 filled with 1.0s. +// The size of the slice is determined by the product of the input 2D dimensions (rows * cols). +func MakeFlattenedOnes(rows, cols int) []float64 { + size := rows * cols + tensor := make([]float64, size) + for i := range tensor { + tensor[i] = 1.0 + } + return tensor +} + +func makeRange(n int) []int { + a := make([]int, n) + for i := range a { + a[i] = i + } + return a +} + +func generateGalEls(param ckks.Parameters, indices []int) []uint64 { + var galEls []uint64 + for _, index := range indices { + galEls = append(galEls, param.GaloisElement(index)) + } + return galEls +} + +func TestMLP(t *testing.T) { + logN := 14 + numSlots := 1 << (logN - 1) + + // Input is arbitrary, doesn't matter since we're just testing + // performance + inputClear := make([]float64, numSlots) + for i := range inputClear { + inputClear[i] = 1.0 + } + + // Function args: + // + // %ct: encrypted input, + + // These parameters should match the mlir file, though due to the weird + // nature of this test, this is the source of truth for what is used, + // not the mlir file. + logQ := make([]int, 7) + for i := range logQ { + logQ[i] = 60 + } + param, err := ckks.NewParametersFromLiteral(ckks.ParametersLiteral{ + LogN: logN, + LogQ: logQ, + LogP: []int{60}, + LogDefaultScale: 40, + }) + if err != nil { + panic(err) + } + + encoder := ckks.NewEncoder(param) + kgen := rlwe.NewKeyGenerator(param) + sk, pk := kgen.GenKeyPairNew() + encryptor := rlwe.NewEncryptor(param, pk) + rk := kgen.GenRelinearizationKeyNew(sk) + + // We have to do this once for each distinct linear_transform op to + // ensure we generate all the galois keys needed by lattigo + var galEls []uint64 + // Manually add Galois key for extra rotation indices used in the + // mlir file, outside of linear_transform + // + // For some reason I need to manually add rotation keys used in + // linear_transform! That should have been handled by the above code... + rotIndices := makeRange(10) + galEls = append(galEls, generateGalEls(param, rotIndices)...) + + fmt.Printf("Final galEls: %v\n", galEls) + + evk := rlwe.NewMemEvaluationKeySet(rk, kgen.GenGaloisKeysNew(galEls, sk)...) + evaluator := ckks.NewEvaluator(param, evk) + + pt := ckks.NewPlaintext(param, 2) + encoder.Encode(inputClear, pt) + ctInput, err25 := encryptor.EncryptNew(pt) + if err25 != nil { + panic(err25) + } + + fmt.Printf("Starting call") + startTime := time.Now() + in_place(evaluator, param, encoder, ctInput) + duration := time.Since(startTime) + fmt.Printf("MLP call took: %v\n", duration) +}