From a52cab55d0f0c0e46803211a2a9a73db3515cf5c Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 4 Jun 2026 14:06:08 -0700 Subject: [PATCH] explicitly pass slotCount to MakeCKKSPackedPlaintext PiperOrigin-RevId: 926877435 --- lib/Dialect/LWE/Transforms/AddDebugPort.cpp | 106 ++++++++++++------ lib/Dialect/LWE/Transforms/BUILD | 1 - .../ArithmeticPipelineRegistration.cpp | 35 +++--- .../ArithmeticPipelineRegistration.h | 4 + lib/Pipelines/BUILD | 2 +- lib/Pipelines/BooleanPipelineRegistration.cpp | 8 +- lib/Pipelines/BooleanPipelineRegistration.h | 8 ++ .../OpenFhePke/OpenFhePkeDebugEmitter.cpp | 70 +++++++----- .../OpenFhePke/OpenFhePkeDebugEmitter.h | 5 +- lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp | 44 ++++++-- lib/Target/OpenFhePke/OpenFhePkeTemplates.h | 22 +++- .../add_debug_port_idempotency.mlir | 85 ++++++++++++++ .../Transforms/add_debug_port_no_debug.mlir | 27 +++++ tests/Emitter/Openfhe/emit_debug_helper.mlir | 2 +- tests/Emitter/Openfhe/emit_pybind.mlir | 4 +- .../lattigo/bfv/dot_product_8_debug/BUILD | 1 - .../lattigo/bfv/noise/mult_dep_16_debug/BUILD | 1 - .../lattigo/bfv/noise/mult_dep_8_debug/BUILD | 1 - .../bfv/noise/mult_indep_32_debug/BUILD | 1 - .../bfv/noise/mult_indep_8_debug/BUILD | 1 - tests/Examples/lattigo/bgv/bgv_debug.go | 4 + tests/Examples/lattigo/bgv/cross_level/BUILD | 1 - .../bgv/cross_level/cross_level_debug.go | 4 + .../lattigo/bgv/dot_product_8_debug/BUILD | 1 - .../bgv/dot_product_8_debug_mono/BUILD | 1 - tests/Examples/lattigo/ckks/ckks_debug.go | 17 ++- tests/Examples/lattigo/ckks/cross_level/BUILD | 1 - .../ckks/cross_level/cross_level_debug.go | 4 + .../lattigo/ckks/dot_product_8f_debug/BUILD | 1 - tests/Examples/openfhe/bfv/debug_helper.cpp | 12 +- tests/Examples/openfhe/bfv/debug_helper.h | 4 + tests/Examples/openfhe/bgv/debug_helper.cpp | 2 +- tests/Examples/openfhe/bgv/debug_helper.h | 4 + tests/Examples/openfhe/ckks/BUILD | 1 + tests/Examples/openfhe/ckks/debug_helper.cpp | 7 ++ tests/Examples/openfhe/ckks/debug_helper.h | 6 + .../openfhe/ckks/debug_validate/BUILD | 31 +++++ .../ckks/debug_validate/debug_helper.cpp | 19 ++++ .../ckks/debug_validate/debug_helper.h | 24 ++++ .../ckks/debug_validate/validate_lower.mlir | 6 + .../debug_validate/validate_lower_test.cpp | 32 ++++++ .../Examples/openfhe/ckks/loop_support/BUILD | 4 +- .../openfhe/ckks/loop_support/loop_test.cpp | 7 +- 43 files changed, 496 insertions(+), 125 deletions(-) create mode 100644 tests/Dialect/LWE/Transforms/add_debug_port_idempotency.mlir create mode 100644 tests/Dialect/LWE/Transforms/add_debug_port_no_debug.mlir create mode 100644 tests/Examples/openfhe/ckks/debug_validate/BUILD create mode 100644 tests/Examples/openfhe/ckks/debug_validate/debug_helper.cpp create mode 100644 tests/Examples/openfhe/ckks/debug_validate/debug_helper.h create mode 100644 tests/Examples/openfhe/ckks/debug_validate/validate_lower.mlir create mode 100644 tests/Examples/openfhe/ckks/debug_validate/validate_lower_test.cpp diff --git a/lib/Dialect/LWE/Transforms/AddDebugPort.cpp b/lib/Dialect/LWE/Transforms/AddDebugPort.cpp index b50ed8c8e3..55b73da39c 100644 --- a/lib/Dialect/LWE/Transforms/AddDebugPort.cpp +++ b/lib/Dialect/LWE/Transforms/AddDebugPort.cpp @@ -2,7 +2,6 @@ #include #include -#include #include "lib/Dialect/Debug/IR/DebugOps.h" #include "lib/Dialect/LWE/IR/LWETypes.h" @@ -63,9 +62,39 @@ FailureOr getPrivateKeyType(func::FuncOp op) { return lwePrivateKeyType; } +void populateDebugFuncCache(ModuleOp module, + llvm::DenseMap& typeToDebugFunc, + llvm::DenseSet& debugFuncNames) { + for (auto funcOp : module.getOps()) { + if (!funcOp.isExternal()) continue; + if (!funcOp.getName().starts_with("__heir_debug_")) continue; + if (funcOp.getArgumentTypes().size() != 2) continue; + if (!funcOp.getResultTypes().empty()) continue; + + typeToDebugFunc[funcOp.getFunctionType()] = funcOp; + debugFuncNames.insert(funcOp.getName()); + } +} + +static bool isAlreadyDebugged(Value value, + const llvm::DenseSet& debugFuncNames) { + for (auto& use : value.getUses()) { + Operation* user = use.getOwner(); + if (isa(user)) { + return true; + } + if (auto callOp = dyn_cast(user)) { + if (debugFuncNames.contains(callOp.getCallee())) { + return true; + } + } + } + return false; +} + func::FuncOp getOrCreateExternalDebugFunc( - ModuleOp module, Type lwePrivateKeyType, Type valueType, - llvm::DenseMap& typeToDebugFunc) { + ModuleOp module, SymbolTable& symbolTable, Type lwePrivateKeyType, + Type valueType, llvm::DenseMap& typeToDebugFunc) { auto* context = module.getContext(); auto debugFuncType = FunctionType::get(context, {lwePrivateKeyType, valueType}, {}); @@ -75,28 +104,29 @@ func::FuncOp getOrCreateExternalDebugFunc( return it->second; } - int counter = typeToDebugFunc.size(); - std::string funcName = "__heir_debug_" + std::to_string(counter); + unsigned uniquingCounter = typeToDebugFunc.size(); + SmallString<128> funcName = SymbolTable::generateSymbolName<128>( + "__heir_debug", + [&](StringRef name) { return symbolTable.lookup(name) != nullptr; }, + uniquingCounter); - // Assert that this name is not already in use. - assert(!module.lookupSymbol(funcName) && - "Symbol already exists"); - - ImplicitLocOpBuilder b = - ImplicitLocOpBuilder::atBlockBegin(module.getLoc(), module.getBody()); - auto funcOp = func::FuncOp::create(b, funcName, debugFuncType); + auto funcOp = func::FuncOp::create(module.getLoc(), funcName, debugFuncType); // required for external func call funcOp.setPrivate(); + symbolTable.insert(funcOp, module.getBody()->begin()); + typeToDebugFunc[debugFuncType] = funcOp; return funcOp; } -void insertValidationOps(func::FuncOp op) { +void insertValidationOps(func::FuncOp op, + const llvm::DenseSet& debugFuncNames) { int count = 0; auto insertValidate = [&](Value value, OpBuilder& b) { Type valueType = value.getType(); if (isa(getElementTypeOrSelf(valueType))) { + if (isAlreadyDebugged(value, debugFuncNames)) return; debug::ValidateOp::create(b, value.getLoc(), value, "heir_debug_" + std::to_string(count++), nullptr); @@ -121,8 +151,8 @@ void insertValidationOps(func::FuncOp op) { } LogicalResult lowerValidationOps( - func::FuncOp op, Value privateKey, int messageSize, - llvm::DenseMap& typeToDebugFunc) { + func::FuncOp op, SymbolTable& symbolTable, Value privateKey, + int messageSize, llvm::DenseMap& typeToDebugFunc) { auto module = op->getParentOfType(); Type lwePrivateKeyType = privateKey.getType(); @@ -141,10 +171,10 @@ LogicalResult lowerValidationOps( attrs.push_back(b.getNamedAttr( "message.size", b.getStringAttr(std::to_string(messageSize)))); - auto debugFunc = getOrCreateExternalDebugFunc(module, lwePrivateKeyType, - valueType, typeToDebugFunc); - auto callOp = - b.create(debugFunc, ArrayRef{privateKey, value}); + auto debugFunc = getOrCreateExternalDebugFunc( + module, symbolTable, lwePrivateKeyType, valueType, typeToDebugFunc); + auto callOp = func::CallOp::create(b, b.getLoc(), debugFunc, + ArrayRef{privateKey, value}); callOp->setDialectAttrs(attrs); validateOp.erase(); @@ -166,6 +196,10 @@ struct AddDebugPort : impl::AddDebugPortBase { ModuleOp module = cast(getOperation()); SymbolTable symbolTable(module); + llvm::DenseMap typeToDebugFunc; + llvm::DenseSet debugFuncNames; + populateDebugFuncCache(module, typeToDebugFunc, debugFuncNames); + SmallVector worklist; llvm::DenseMap funcToKeyType; if (failed(identifyInitialTargets(module, symbolTable, funcToKeyType, @@ -183,7 +217,7 @@ struct AddDebugPort : impl::AddDebugPortBase { if (insertDebugAfterEveryOp) { for (auto& [func, _] : funcToKeyType) { - insertValidationOps(func); + insertValidationOps(func, debugFuncNames); } } @@ -198,8 +232,8 @@ struct AddDebugPort : impl::AddDebugPortBase { return; } - llvm::DenseMap typeToDebugFunc; - if (failed(lowerAllValidationOps(module, funcToKeyType, typeToDebugFunc))) { + if (failed(lowerAllValidationOps(module, symbolTable, funcToKeyType, + typeToDebugFunc))) { signalPassFailure(); return; } @@ -229,18 +263,21 @@ struct AddDebugPort : impl::AddDebugPortBase { } if (entryFunc) { - auto type = getPrivateKeyType(entryFunc); - if (succeeded(type)) { + bool shouldProcess = + containsAnyOperations(entryFunc) || + insertDebugAfterEveryOp; + + if (shouldProcess) { + auto type = getPrivateKeyType(entryFunc); + if (failed(type)) { + entryFunc.emitError( + "Cannot infer LWE private key type for entry function"); + return failure(); + } funcToKeyType[entryFunc] = *type; worklist.push_back(entryFunc); - return success(); - } - - if (containsAnyOperations(entryFunc)) { - entryFunc.emitError( - "Cannot infer LWE private key type for entry function"); - return failure(); } + return success(); } for (auto funcOp : module.getOps()) { @@ -404,7 +441,8 @@ struct AddDebugPort : impl::AddDebugPortBase { /// \param typePairToInt Map to track generated debug function names. /// \return success() if successful, failure() otherwise. LogicalResult lowerAllValidationOps( - ModuleOp module, const llvm::DenseMap& funcToKeyType, + ModuleOp module, SymbolTable& symbolTable, + const llvm::DenseMap& funcToKeyType, llvm::DenseMap& typeToDebugFunc) { for (auto funcOp : module.getOps()) { if (funcOp.isExternal()) continue; @@ -428,8 +466,8 @@ struct AddDebugPort : impl::AddDebugPortBase { } if (privateKey) { - if (failed(lowerValidationOps(funcOp, privateKey, messageSize, - typeToDebugFunc))) { + if (failed(lowerValidationOps(funcOp, symbolTable, privateKey, + messageSize, typeToDebugFunc))) { funcOp.emitError("failed to lower validation ops"); return failure(); } diff --git a/lib/Dialect/LWE/Transforms/BUILD b/lib/Dialect/LWE/Transforms/BUILD index 29df8e43ee..45404f1162 100644 --- a/lib/Dialect/LWE/Transforms/BUILD +++ b/lib/Dialect/LWE/Transforms/BUILD @@ -28,7 +28,6 @@ cc_library( ], deps = [ ":pass_inc_gen", - "@heir//lib/Dialect:FuncUtils", "@heir//lib/Dialect/Debug/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Utils", diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index fa25069c42..46d9819cae 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -219,12 +219,9 @@ void mlirToPlaintextPipelineBuilder(OpPassManager& pm, mlirToRLWEPipelineOptions.ciphertextDegree = options.plaintextSize; mlirToSecretArithmeticPipelineBuilder(pm, mlirToRLWEPipelineOptions); - if (options.debug) { - // Insert debug handler calls - secret::SecretAddDebugPortOptions debugOptions; - debugOptions.insertDebugAfterEveryOp = true; - pm.addPass(secret::createSecretAddDebugPort(debugOptions)); - } + // Insert debug handler calls and/or lower debug.validate + pm.addPass(secret::createSecretAddDebugPort(secret::SecretAddDebugPortOptions{ + .insertDebugAfterEveryOp = options.debug})); pm.addPass(secret::createSecretDistributeGeneric()); pm.addPass(createCanonicalizerPass()); @@ -411,9 +408,15 @@ void mlirToRLWEPipeline(OpPassManager& pm, exit(EXIT_FAILURE); } + // Lower debug.validate ops to function calls with private key + pm.addPass(lwe::createAddDebugPort( + lwe::AddDebugPortOptions{.messageSize = (int)options.ciphertextDegree, + .insertDebugAfterEveryOp = options.debug})); + pm.addPass(createForwardInsertToExtract()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); + pm.addPass(createSymbolDCEPass()); // TODO(#2554): skip this pass if the backend supports trivial encryption pm.addPass(lwe::createImplementTrivialEncryptionAsAddition()); @@ -459,11 +462,11 @@ BackendPipelineBuilder toOpenFhePipelineBuilder() { pm.addPass(ckks::createCKKSToLWE()); // insert debug handler calls - if (options.debug) { - lwe::AddDebugPortOptions addDebugPortOptions; - addDebugPortOptions.entryFunction = options.entryFunction; - pm.addPass(lwe::createAddDebugPort(addDebugPortOptions)); - } + lwe::AddDebugPortOptions addDebugPortOptions{ + .entryFunction = options.entryFunction, + .insertDebugAfterEveryOp = options.debug, + }; + pm.addPass(lwe::createAddDebugPort(addDebugPortOptions)); // Convert LWE (and scheme-specific CKKS/BGV ops) to OpenFHE pm.addPass(lwe::createLWEToOpenfhe()); @@ -501,11 +504,11 @@ BackendPipelineBuilder toLattigoPipelineBuilder() { pm.addPass(ckks::createCKKSToLWE()); // insert debug handler calls - if (options.debug) { - lwe::AddDebugPortOptions addDebugPortOptions; - addDebugPortOptions.entryFunction = options.entryFunction; - pm.addPass(lwe::createAddDebugPort(addDebugPortOptions)); - } + lwe::AddDebugPortOptions addDebugPortOptions{ + .entryFunction = options.entryFunction, + .insertDebugAfterEveryOp = options.debug, + }; + pm.addPass(lwe::createAddDebugPort(addDebugPortOptions)); // Convert LWE (and scheme-specific BGV ops) to Lattigo pm.addPass(lwe::createLWEToLattigo()); diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.h b/lib/Pipelines/ArithmeticPipelineRegistration.h index 3a8731ca93..b26ceaf6f8 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.h +++ b/lib/Pipelines/ArithmeticPipelineRegistration.h @@ -94,6 +94,10 @@ struct MlirToRLWEPipelineOptions : public LoopOptions { llvm::cl::desc( "The level budget excluding levels required for bootstrap"), llvm::cl::init(10)}; + PassOptions::Option debug{ + *this, "debug", + llvm::cl::desc("Insert debug ports after every secret operation."), + llvm::cl::init(false)}; PassOptions::Option plaintextExecutionResultFileName{ *this, "plaintext-execution-result-file-name", llvm::cl::desc("File name to import execution result from (c.f. --secret-" diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index 6985c58566..291455c515 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -60,6 +60,7 @@ cc_library( "@heir//lib/Dialect/Debug/Transforms:ValidateNames", "@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial", "@heir//lib/Dialect/Secret/Conversions/SecretToCGGI", + "@heir//lib/Dialect/Secret/Transforms:AddDebugPort", "@heir//lib/Dialect/Secret/Transforms:DistributeGeneric", "@heir//lib/Transforms/AddClientInterface", "@heir//lib/Transforms/BooleanVectorizer", @@ -102,7 +103,6 @@ cc_library( ":PipelineRegistration", "@heir//lib/Dialect/BGV/Conversions/BGVToLWE", "@heir//lib/Dialect/CKKS/Transforms:CKKSToLWE", - "@heir//lib/Dialect/Debug/Transforms", "@heir//lib/Dialect/Debug/Transforms:ValidateNames", "@heir//lib/Dialect/LWE/Conversions/LWEToLattigo", "@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe", diff --git a/lib/Pipelines/BooleanPipelineRegistration.cpp b/lib/Pipelines/BooleanPipelineRegistration.cpp index f8ae7842cb..d8ee4d4ab3 100644 --- a/lib/Pipelines/BooleanPipelineRegistration.cpp +++ b/lib/Pipelines/BooleanPipelineRegistration.cpp @@ -9,11 +9,10 @@ #include "lib/Dialect/CGGI/Conversions/CGGIToSCIFRBool/CGGIToSCIFRBool.h" #include "lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.h" #include "lib/Dialect/CGGI/Conversions/CGGIToTfheRustBool/CGGIToTfheRustBool.h" -#include "lib/Dialect/Debug/Transforms/Passes.h" #include "lib/Dialect/Debug/Transforms/ValidateNames.h" #include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h" +#include "lib/Dialect/Secret/Transforms/AddDebugPort.h" #include "lib/Dialect/Secret/Transforms/DistributeGeneric.h" -#include "lib/Pipelines/PipelineRegistration.h" #include "lib/Transforms/BooleanVectorizer/BooleanVectorizer.h" #include "lib/Transforms/FoldConstantTensors/FoldConstantTensors.h" #include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h" @@ -23,7 +22,6 @@ #include "lib/Transforms/MemrefToArith/MemrefToArith.h" #include "lib/Transforms/Secretize/Passes.h" #include "lib/Transforms/TensorLinalgToAffineLoops/TensorLinalgToAffineLoops.h" -#include "lib/Transforms/UnusedMemRef/UnusedMemRef.h" #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/Transforms/Passes.h" // from @llvm-project @@ -64,6 +62,8 @@ void mlirToCGGIPipeline(OpPassManager& pm, const std::string& yosysFilesPath, const std::string& abcPath) { pm.addPass(debug::createDebugValidateNames()); + pm.addPass(secret::createSecretAddDebugPort(secret::SecretAddDebugPortOptions{ + .insertDebugAfterEveryOp = options.debug})); pm.addPass(createConvertTensorToLinalgPass()); pm.addPass(createLinalgGeneralizeNamedOpsPass()); @@ -154,6 +154,8 @@ CGGIPipelineBuilder mlirToCGGIPipelineBuilder() { void mlirToCGGIPipeline(OpPassManager& pm, const MLIRToCGGIPipelineOptions& options) { pm.addPass(debug::createDebugValidateNames()); + pm.addPass(secret::createSecretAddDebugPort(secret::SecretAddDebugPortOptions{ + .insertDebugAfterEveryOp = options.debug})); // Bufferize ::mlir::heir::oneShotBufferize(pm); diff --git a/lib/Pipelines/BooleanPipelineRegistration.h b/lib/Pipelines/BooleanPipelineRegistration.h index 4d982b7712..b4d2c002bf 100644 --- a/lib/Pipelines/BooleanPipelineRegistration.h +++ b/lib/Pipelines/BooleanPipelineRegistration.h @@ -20,6 +20,10 @@ enum DataType { Bool, Integer }; #ifndef HEIR_NO_YOSYS // If Yosys is enabled, also add all yosys optimizer pipeline options. struct MLIRToCGGIPipelineOptions : public YosysOptimizerPipelineOptions { + PassOptions::Option debug{ + *this, "debug", + llvm::cl::desc("Insert debug ports after every secret operation."), + llvm::cl::init(false)}; PassOptions::Option dataType{ *this, "data-type", llvm::cl::desc("Data type to use for arithmetization, yosys must be " @@ -44,6 +48,10 @@ void mlirToCGGIPipeline(OpPassManager& pm, #else struct MLIRToCGGIPipelineOptions : public PassPipelineOptions { + PassOptions::Option debug{ + *this, "debug", + llvm::cl::desc("Insert debug ports after every secret operation."), + llvm::cl::init(false)}; PassOptions::Option dataType{ *this, "data-type", llvm::cl::desc("Data type to use for arithmetization."), diff --git a/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.cpp index 6b36fb45a8..615bcf73d7 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.cpp @@ -7,10 +7,12 @@ #include "lib/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Utils/TargetUtils.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/include/mlir/IR/Location.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project @@ -74,34 +76,48 @@ LogicalResult OpenFhePkeDebugEmitter::printOperation(ModuleOp moduleOp) { return success(); } -LogicalResult OpenFhePkeDebugEmitter::emitDebugHelperImpl() { - os << "auto " << kIsBlockArgVar << " = " << kDebugAttrMapParam - << ".at(\"asm.is_block_arg\");\n"; +LogicalResult OpenFhePkeDebugEmitter::emitDebugHelperImpl(Type ctType, + Location loc) { + os << llvm::formatv(kDebugBlockArgCheckTemplate.data(), kIsBlockArgVar, + kDebugAttrMapParam); + + auto emitDecrypt = [&](llvm::StringRef ciphertextAccess) { + os << llvm::formatv(kDebugDecryptTemplate.data(), kPlaintxtVar, kCctxtVar, + kPrivKeyTVar, ciphertextAccess, kDebugAttrMapParam); + }; + + if (llvm::isa(ctType)) { + os << "for (size_t i = 0; i < " << kCiphertxtVar << ".size(); ++i) {\n"; + os.indent(); + os << "if (" << kCiphertxtVar + << ".size() > 1) std::cout << \"Tensor index \" << i << \":\" << " + "std::endl;\n"; + emitDecrypt(llvm::formatv("{0}[i]", kCiphertxtVar).str()); + os.unindent(); + os << "}\n"; + } else { + emitDecrypt(kCiphertxtVar); + } - os << llvm::formatv("if ({0} == \"1\") {{\n", kIsBlockArgVar); - os.indent(); - os << "std::cout << \"Input\" << std::endl;\n"; - os.unindent(); - os << "}"; - os << llvm::formatv(" else {{\n"); - os.indent(); - os << "std::cout" << " " << "<< "; - os << kDebugAttrMapParam << ".at" << "(\"asm.op_name\") << std::endl;\n"; - os.unindent(); - os << "}\n"; - os << "\n"; - - os << "PlaintextT " << kPlaintxtVar << ";\n"; - os << kCctxtVar << "->Decrypt(" << kPrivKeyTVar << ", " << kCiphertxtVar - << ", &" << kPlaintxtVar << ");\n"; - os << kPlaintxtVar << "->SetLength(std::stod(" << kDebugAttrMapParam - << ".at(\"message.size\")));\n"; - os << "std::cout << \" \" << " << kPlaintxtVar << " << std::endl;\n"; return success(); } LogicalResult OpenFhePkeDebugEmitter::printOperation(func::FuncOp funcOp) { - if ((!isDebugPort(funcOp.getName())) || isEmitted) { + if (!isDebugPort(funcOp.getName())) { + return success(); + } + + auto argTypes = funcOp.getArgumentTypes(); + if (argTypes.size() != 3) { + return emitError(funcOp.getLoc(), "Unexpected debug port signature"); + } + auto ctTy = convertType(argTypes[2], funcOp.getLoc()); + if (failed(ctTy)) { + return emitError(funcOp.getLoc(), "Failed to convert type"); + } + + std::string sig = ctTy.value(); + if (!emittedSignatures.insert(sig).second) { return success(); } @@ -116,23 +132,19 @@ LogicalResult OpenFhePkeDebugEmitter::printOperation(func::FuncOp funcOp) { os << " {\n"; os.indent(); - res = emitDebugHelperImpl(); + res = emitDebugHelperImpl(argTypes[2], funcOp.getLoc()); if (failed(res)) { return res; } os.unindent(); os << "}\n"; - isEmitted = true; return success(); } OpenFhePkeDebugEmitter::OpenFhePkeDebugEmitter( raw_ostream& os, OpenfheImportType importType, const std::string& debugImportPath) - : importType_(importType), - os(os), - debugImportPath(debugImportPath), - isEmitted(false) {} + : importType_(importType), os(os), debugImportPath(debugImportPath) {} } // namespace openfhe } // namespace heir diff --git a/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.h b/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.h index c3e8dac5e2..60dbccaf9f 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.h +++ b/lib/Target/OpenFhePke/OpenFhePkeDebugEmitter.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -46,12 +47,12 @@ class OpenFhePkeDebugEmitter { LogicalResult printOperation(::mlir::ModuleOp op); LogicalResult printOperation(::mlir::func::FuncOp op); - LogicalResult emitDebugHelperImpl(); + LogicalResult emitDebugHelperImpl(::mlir::Type ctType, ::mlir::Location loc); /// Include path for debug imports std::string debugImportPath; - bool isEmitted; + std::set emittedSignatures; }; } // namespace openfhe diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 5c36afa721..047a439613 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -1,8 +1,6 @@ #include "lib/Target/OpenFhePke/OpenFhePkeEmitter.h" -#include #include -#include #include #include #include @@ -368,13 +366,17 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::CallOp op) { if (attr.getName().getValue() == "callee") { continue; } - os << debugAttrMapName << "[\"" << attr.getName().getValue() - << "\"] = \""; + os << debugAttrMapName << "[\"" << attr.getName().getValue() << "\"] = "; // Use AsmPrinter to print Attribute if (mlir::isa(attr.getValue())) { - os << mlir::cast(attr.getValue()).getValue() << "\";\n"; + StringRef val = mlir::cast(attr.getValue()).getValue(); + if (val.contains('\n')) { + os << "R\"HEIR_RAW(" << val << ")HEIR_RAW\";\n"; + } else { + os << "\"" << val << "\";\n"; + } } else { - os << attr.getValue() << "\";\n"; + os << "\"" << attr.getValue() << "\";\n"; } } auto ciphertext = op->getOperand(op->getNumOperands() - 1); @@ -1727,13 +1729,27 @@ LogicalResult OpenFhePkeEmitter::printOperation( variableNames->getNameForValue(op.getResult()) + "_filled"; std::string inputVarFilledName = filledPrefix; std::string inputVarFilledLengthName = filledPrefix + "_n"; + auto moduleOp = op->getParentOfType(); + int64_t slotCount = 0; + if (moduleOp) { + auto slotCountAttr = + moduleOp->getAttrOfType("scheme.requested_slot_count"); + if (slotCountAttr) { + slotCount = slotCountAttr.getValue().getSExtValue(); + } + } + os << "auto " << inputVarFilledLengthName << " = " << cc << "->GetCryptoParameters()->GetElementParams()->GetRingDimension() / " "2;\n"; os << "auto " << inputVarFilledName << " = " << inputVarName << ";\n"; os << inputVarFilledName << ".clear();\n"; os << inputVarFilledName << ".reserve(" << inputVarFilledLengthName << ");\n"; - os << "for (auto i = 0; i < " << inputVarFilledLengthName << "; ++i) {\n"; + if (slotCount != 0) { + os << "for (auto i = 0; i < " << slotCount << "; ++i) {\n"; + } else { + os << "for (auto i = 0; i < " << inputVarFilledLengthName << "; ++i) {\n"; + } os << " " << inputVarFilledName << ".push_back(" << inputVarName << "[i % " << inputVarName << ".size()]);\n"; os << "}\n"; @@ -1742,7 +1758,8 @@ LogicalResult OpenFhePkeEmitter::printOperation( // https://github.com/openfheorg/openfhe-development/issues/1046 os << "auto " << variableNames->getNameForValue(op.getResult()) << " = "; os << variableNames->getNameForValue(resultCC.value()) - << "->MakeCKKSPackedPlaintext(" << inputVarFilledName << ");\n"; + << "->MakeCKKSPackedPlaintext(" << inputVarFilledName + << ", 1, 0, nullptr, " << slotCount << ");\n"; return success(); } @@ -1870,6 +1887,17 @@ LogicalResult OpenFhePkeEmitter::printOperation(GenParamsOp op) { } if (op.getBatchSize() != 0) { os << paramsName << ".SetBatchSize(" << op.getBatchSize() << ");\n"; + } else { + // Fallback to scheme.requested_slot_count from module attributes + auto moduleOp = op->getParentOfType(); + if (moduleOp) { + auto slotCountAttr = + moduleOp->getAttrOfType("scheme.requested_slot_count"); + if (slotCountAttr) { + os << paramsName << ".SetBatchSize(" + << slotCountAttr.getValue().getSExtValue() << ");\n"; + } + } } // Modulus chain parameters if (op.getFirstModSize() != 0) { diff --git a/lib/Target/OpenFhePke/OpenFhePkeTemplates.h b/lib/Target/OpenFhePke/OpenFhePkeTemplates.h index 4c47925153..aedc50e2c6 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeTemplates.h +++ b/lib/Target/OpenFhePke/OpenFhePkeTemplates.h @@ -91,8 +91,8 @@ void bind_common(py::module &m) py::class_, std::shared_ptr>>(m, "PrivateKey", py::module_local()) .def(py::init<>()); py::class_>(m, "KeyPair", py::module_local()) - .def_readwrite("publicKey", &KeyPair::publicKey) - .def_readwrite("secretKey", &KeyPair::secretKey); + .def_property_readonly("publicKey", [](const KeyPair &kp) { return kp.publicKey; }) + .def_property_readonly("secretKey", [](const KeyPair &kp) { return kp.secretKey; }); py::class_, std::shared_ptr>>(m, "Ciphertext", py::module_local()) .def(py::init<>()); py::class_, std::shared_ptr>>(m, "CryptoContext", py::module_local()) @@ -124,6 +124,24 @@ constexpr std::string_view KdebugHeaderImports = R"cpp( )cpp"; // clang-format on +// clang-format off +constexpr std::string_view kDebugBlockArgCheckTemplate = R"cpp( +auto {0} = {1}.at("asm.is_block_arg"); +if ({0} == "1") {{ + std::cout << "Input" << std::endl; +}} else {{ + std::cout << {1}.at("asm.op_name") << std::endl; +}} +)cpp"; + +constexpr std::string_view kDebugDecryptTemplate = R"cpp( +PlaintextT {0}; +{1}->Decrypt({2}, {3}, &{0}); +{0}->SetLength(std::stoul({4}.at("message.size"))); +std::cout << " " << {0} << std::endl; +)cpp"; +// clang-format on + } // namespace openfhe } // namespace heir } // namespace mlir diff --git a/tests/Dialect/LWE/Transforms/add_debug_port_idempotency.mlir b/tests/Dialect/LWE/Transforms/add_debug_port_idempotency.mlir new file mode 100644 index 0000000000..6fd918b64a --- /dev/null +++ b/tests/Dialect/LWE/Transforms/add_debug_port_idempotency.mlir @@ -0,0 +1,85 @@ +// RUN: heir-opt --mlir-print-local-scope --lwe-add-debug-port=insert-debug-after-every-op=true --lwe-add-debug-port=insert-debug-after-every-op=true %s | FileCheck %s + +// CHECK: func.func private @__heir_debug_1(!lwe.lwe_secret_key<{{.*}}>, !lwe.lwe_ciphertext<{{.*}}x{{.*}}64{{.*}}>) + +!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> +!Z65537_i64_ = !mod_arith.int<65537 : i64> + +!rns_L0_ = !rns.rns + +#ring_Z65537_i64_1_x32_ = #polynomial.ring> +#ring_rns_L0_1_x32_ = #polynomial.ring> + +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#key = #lwe.key<> + +#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> + +#plaintext_space = #lwe.plaintext_space + +#ciphertext_space_L0_ = #lwe.ciphertext_space + +!ty = !lwe.lwe_ciphertext + +!pt = !lwe.lwe_plaintext> + +func.func @simple_sum(%arg0: !ty) -> !ty { + %c31 = arith.constant 31 : index + %c1_i16 = arith.constant 1 : i16 + %cst = arith.constant dense<0> : tensor<32xi16> + %inserted = tensor.insert %c1_i16 into %cst[%c31] : tensor<32xi16> + %0 = bgv.rotate_cols %arg0 { static_shift = 16 } : !ty + %1 = bgv.add %arg0, %0 : (!ty, !ty) -> !ty + %2 = bgv.rotate_cols %1 { static_shift = 8 } : !ty + %3 = bgv.add %1, %2 : (!ty, !ty) -> !ty + %4 = bgv.rotate_cols %3 { static_shift = 4 } : !ty + %5 = bgv.add %3, %4 : (!ty, !ty) -> !ty + %6 = bgv.rotate_cols %5 { static_shift = 2 } : !ty + %7 = bgv.add %5, %6 : (!ty, !ty) -> !ty + %8 = bgv.rotate_cols %7 { static_shift = 1 } : !ty + %9 = bgv.add %7, %8 : (!ty, !ty) -> !ty + %pt = lwe.rlwe_encode %inserted {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32_} : tensor<32xi16> -> !pt + %10 = bgv.mul_plain %9, %pt : (!ty, !pt) -> !ty + %11 = bgv.rotate_cols %10 {static_shift = 31 : index} : !ty + return %11 : !ty +} + +// CHECK: @simple_sum +// CHECK-SAME: (%[[sk:[^:]*]]: [[sk_ty:[^,]*]], +// CHECK-SAME: %[[original_input:[^:]*]]: [[in_ty:[^)]*]]) +// CHECK-SAME: -> [[out_ty:[^{]*]] { + +// CHECK: call @__heir_debug +// CHECK-SAME: (%[[sk]], %[[original_input]]) + +// CHECK-COUNT-12: call @__heir_debug +// CHECK-NOT: call @__heir_debug + +#ring_rns_L0_1_x64_ = #polynomial.ring> +#ciphertext_space_L0_2_ = #lwe.ciphertext_space +!ty2 = !lwe.lwe_ciphertext +!sk1 = !lwe.lwe_secret_key + +func.func private @__heir_debug_0(!sk1, !ty) + +// CHECK: func.func private @__heir_debug_0(!lwe.lwe_secret_key<{{.*}}>, !lwe.lwe_ciphertext<{{.*}}>) + +// CHECK: func.func @mixed_test +// CHECK-SAME: (%[[sk:[^:]*]]: [[sk_ty:[^,]*]], +// CHECK-SAME: %[[arg1:[^:]*]]: [[ty1:[^,]*]], +// CHECK-SAME: %[[arg2:[^:]*]]: [[ty2:[^)]*]]) +func.func @mixed_test(%arg0: !sk1, %arg1: !ty, %arg2: !ty2) -> (!ty, !ty2) { + // CHECK: call @__heir_debug_1(%[[sk]], %[[arg2]]) + // CHECK: call @__heir_debug_0(%[[sk]], %[[arg1]]) + func.call @__heir_debug_0(%arg0, %arg1) : (!sk1, !ty) -> () + + // CHECK: %[[rotate1:.*]] = bgv.rotate_cols %[[arg1]] + %0 = bgv.rotate_cols %arg1 { static_shift = 16 } : !ty + // CHECK: call @__heir_debug_0(%[[sk]], %[[rotate1]]) + + // CHECK: %[[rotate2:.*]] = bgv.rotate_cols %[[arg2]] + %1 = bgv.rotate_cols %arg2 { static_shift = 16 } : !ty2 + // CHECK: call @__heir_debug_1(%[[sk]], %[[rotate2]]) + + return %0, %1 : !ty, !ty2 +} diff --git a/tests/Dialect/LWE/Transforms/add_debug_port_no_debug.mlir b/tests/Dialect/LWE/Transforms/add_debug_port_no_debug.mlir new file mode 100644 index 0000000000..63f59647b8 --- /dev/null +++ b/tests/Dialect/LWE/Transforms/add_debug_port_no_debug.mlir @@ -0,0 +1,27 @@ +// RUN: heir-opt --mlir-print-local-scope --lwe-add-debug-port="entry-function=simple_sum" %s | FileCheck %s + +!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> +!Z65537_i64_ = !mod_arith.int<65537 : i64> + +!rns_L0_ = !rns.rns + +#ring_Z65537_i64_1_x32_ = #polynomial.ring> +#ring_rns_L0_1_x32_ = #polynomial.ring> + +#full_crt_packing_encoding = #lwe.full_crt_packing_encoding +#key = #lwe.key<> + +#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> + +#plaintext_space = #lwe.plaintext_space + +#ciphertext_space_L0_ = #lwe.ciphertext_space + +!ty = !lwe.lwe_ciphertext + +// CHECK: func.func @simple_sum(%[[arg0:.*]]: !lwe.lwe_ciphertext<{{.*}}>) -> !lwe.lwe_ciphertext<{{.*}}> { +// CHECK-NOT: lwe.lwe_secret_key +func.func @simple_sum(%arg0: !ty) -> !ty { + %0 = bgv.rotate_cols %arg0 { static_shift = 16 } : !ty + return %0 : !ty +} diff --git a/tests/Emitter/Openfhe/emit_debug_helper.mlir b/tests/Emitter/Openfhe/emit_debug_helper.mlir index baf4287b0d..1805f44ded 100644 --- a/tests/Emitter/Openfhe/emit_debug_helper.mlir +++ b/tests/Emitter/Openfhe/emit_debug_helper.mlir @@ -18,7 +18,7 @@ // CHECK: } // CHECK: PlaintextT [[ptxt:.*]]; // CHECK: [[cc]]->Decrypt([[sk]], [[ct]], &[[ptxt]]); -// CHECK: [[ptxt]]->SetLength(std::stod([[m]].at("message.size"))); +// CHECK: [[ptxt]]->SetLength(std::stoul([[m]].at("message.size"))); // CHECK: std::cout << " " << [[ptxt]] << std::endl; // CHECK: } diff --git a/tests/Emitter/Openfhe/emit_pybind.mlir b/tests/Emitter/Openfhe/emit_pybind.mlir index 6ba9f3c18f..9aafb6527c 100644 --- a/tests/Emitter/Openfhe/emit_pybind.mlir +++ b/tests/Emitter/Openfhe/emit_pybind.mlir @@ -17,8 +17,8 @@ // CHECK: py::class_, std::shared_ptr>>(m, "PrivateKey", py::module_local()) // CHECK: .def(py::init<>()); // CHECK: py::class_>(m, "KeyPair", py::module_local()) -// CHECK: .def_readwrite("publicKey", &KeyPair::publicKey) -// CHECK: .def_readwrite("secretKey", &KeyPair::secretKey); +// CHECK: .def_property_readonly("publicKey", [](const KeyPair &kp) { return kp.publicKey; }) +// CHECK: .def_property_readonly("secretKey", [](const KeyPair &kp) { return kp.secretKey; }); // CHECK: py::class_, std::shared_ptr>>(m, "Ciphertext", py::module_local()) // CHECK: .def(py::init<>()); // CHECK: py::class_, std::shared_ptr>>(m, "CryptoContext", py::module_local()) diff --git a/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD b/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD index 4691ead83e..6db6d09e62 100644 --- a/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD +++ b/tests/Examples/lattigo/bfv/dot_product_8_debug/BUILD @@ -12,7 +12,6 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=ciphertext-degree=1024 annotate-noise-bound=true", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD index 897bb9129b..43a6576240 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_dep_16_debug/BUILD @@ -13,7 +13,6 @@ heir_lattigo_lib( "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 \ annotate-noise-bound=true", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_dep_16.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD index 1d88da97fb..927999e94a 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_dep_8_debug/BUILD @@ -13,7 +13,6 @@ heir_lattigo_lib( "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 \ annotate-noise-bound=true", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_dep_8.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD index 42d04ef85d..f1fc0b34a9 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_indep_32_debug/BUILD @@ -12,7 +12,6 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 annotate-noise-bound=true", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_indep_32.mlir", diff --git a/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD b/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD index 6e918c9fde..fbc8fbf30d 100644 --- a/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD +++ b/tests/Examples/lattigo/bfv/noise/mult_indep_8_debug/BUILD @@ -13,7 +13,6 @@ heir_lattigo_lib( "--annotate-module=backend=lattigo scheme=bfv", "--mlir-to-bfv=noise-model=bfv-noise-bmcm23 \ annotate-noise-bound=true", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:mult_indep_8.mlir", diff --git a/tests/Examples/lattigo/bgv/bgv_debug.go b/tests/Examples/lattigo/bgv/bgv_debug.go index 2c8c76d98b..766651a0b0 100644 --- a/tests/Examples/lattigo/bgv/bgv_debug.go +++ b/tests/Examples/lattigo/bgv/bgv_debug.go @@ -21,6 +21,10 @@ func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.E } fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) ct = v[0] + if ct == nil { + fmt.Println("First ciphertext element is nil") + return + } default: panic(fmt.Sprintf("unexpected type %T", ctObj)) } diff --git a/tests/Examples/lattigo/bgv/cross_level/BUILD b/tests/Examples/lattigo/bgv/cross_level/BUILD index c27c00a597..b6e381edf1 100644 --- a/tests/Examples/lattigo/bgv/cross_level/BUILD +++ b/tests/Examples/lattigo/bgv/cross_level/BUILD @@ -19,7 +19,6 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bgv", "--mlir-to-bgv=ciphertext-degree=1024", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "cross_level.mlir", diff --git a/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go b/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go index 41748c14a5..85f800bf04 100644 --- a/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go +++ b/tests/Examples/lattigo/bgv/cross_level/cross_level_debug.go @@ -21,6 +21,10 @@ func __heir_debug(evaluator *bgv.Evaluator, param bgv.Parameters, encoder *bgv.E } fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) ct = v[0] + if ct == nil { + fmt.Println("First ciphertext element is nil") + return + } default: panic(fmt.Sprintf("unexpected type %T", ctObj)) } diff --git a/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD b/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD index 4efb60c2b6..3af96a3c77 100644 --- a/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD +++ b/tests/Examples/lattigo/bgv/dot_product_8_debug/BUILD @@ -12,7 +12,6 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bgv", "--mlir-to-bgv=ciphertext-degree=8192 annotate-noise-bound=true", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", diff --git a/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD b/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD index c3c3d7fbf6..82220d3920 100644 --- a/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD +++ b/tests/Examples/lattigo/bgv/dot_product_8_debug_mono/BUILD @@ -12,7 +12,6 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=bgv", "--mlir-to-bgv=ciphertext-degree=8192 noise-model=bgv-noise-mono annotate-noise-bound=true", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", diff --git a/tests/Examples/lattigo/ckks/ckks_debug.go b/tests/Examples/lattigo/ckks/ckks_debug.go index 1aa5842f26..6bf3e76f53 100644 --- a/tests/Examples/lattigo/ckks/ckks_debug.go +++ b/tests/Examples/lattigo/ckks/ckks_debug.go @@ -23,6 +23,10 @@ func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckk } fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) ct = v[0] + if ct == nil { + fmt.Println("First ciphertext element is nil") + return + } default: panic(fmt.Sprintf("unexpected type %T", ctObj)) } @@ -60,8 +64,14 @@ func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckk // calculate the precision if secretExecutionResult, ok := debugAttrMap["secret.execution_result"]; ok { // secretExecutionResult has the form "[1.0, 2.0, 3.0]", parse it into a slice of float64 - plaintextResultStr := strings.Trim(secretExecutionResult, "[]") - plaintextResultStrs := strings.Split(plaintextResultStr, ",") + if !strings.HasPrefix(secretExecutionResult, "[") || !strings.HasSuffix(secretExecutionResult, "]") { + panic(fmt.Sprintf("invalid secret.execution_result format: %s", secretExecutionResult)) + } + plaintextResultStr := secretExecutionResult[1 : len(secretExecutionResult)-1] + var plaintextResultStrs []string + if len(strings.TrimSpace(plaintextResultStr)) > 0 { + plaintextResultStrs = strings.Split(plaintextResultStr, ",") + } plaintextResult := make([]float64, len(plaintextResultStrs)) for i, s := range plaintextResultStrs { plaintextResult[i], err = strconv.ParseFloat(strings.TrimSpace(s), 64) @@ -69,6 +79,9 @@ func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckk panic(err) } } + if len(plaintextResult) != messageSize { + panic(fmt.Sprintf("dimension mismatch: message.size=%d, secret.execution_result length=%d", messageSize, len(plaintextResult))) + } maxError := math.Inf(-1) for i := 0; i < len(value) && i < len(plaintextResult); i++ { diff --git a/tests/Examples/lattigo/ckks/cross_level/BUILD b/tests/Examples/lattigo/ckks/cross_level/BUILD index d6151b4737..cc1badbf5e 100644 --- a/tests/Examples/lattigo/ckks/cross_level/BUILD +++ b/tests/Examples/lattigo/ckks/cross_level/BUILD @@ -19,7 +19,6 @@ heir_lattigo_lib( heir_opt_flags = [ "--annotate-module=backend=lattigo scheme=ckks", "--mlir-to-ckks=ciphertext-degree=4 modulus-switch-before-first-mul=true first-mod-bits=59 scaling-mod-bits=45", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "cross_level.mlir", diff --git a/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go b/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go index 325c044ce5..6ce181a241 100644 --- a/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go +++ b/tests/Examples/lattigo/ckks/cross_level/cross_level_debug.go @@ -22,6 +22,10 @@ func __heir_debug(evaluator *ckks.Evaluator, param ckks.Parameters, encoder *ckk } fmt.Printf("Ciphertext slice of size %d (debugging first element)\n", len(v)) ct = v[0] + if ct == nil { + fmt.Println("First ciphertext element is nil") + return + } default: panic(fmt.Sprintf("unexpected type %T", ctObj)) } diff --git a/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD b/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD index ebe7f943b1..dce52b6ea7 100644 --- a/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD +++ b/tests/Examples/lattigo/ckks/dot_product_8f_debug/BUILD @@ -17,7 +17,6 @@ heir_lattigo_lib( "--mlir-to-ckks=ciphertext-degree=2048 \ encryption-technique-extended=true \ plaintext-execution-result-file-name=$(location @heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log)", - "--lwe-add-debug-port=insert-debug-after-every-op=true", "--scheme-to-lattigo=insert-debug-handler-calls=true", ], mlir_src = "@heir//tests/Examples/common:dot_product_8f.mlir", diff --git a/tests/Examples/openfhe/bfv/debug_helper.cpp b/tests/Examples/openfhe/bfv/debug_helper.cpp index 42bb774127..02ccb701a5 100644 --- a/tests/Examples/openfhe/bfv/debug_helper.cpp +++ b/tests/Examples/openfhe/bfv/debug_helper.cpp @@ -1,8 +1,6 @@ #include "tests/Examples/openfhe/bfv/debug_helper.h" -#include #include -#include #include #include #include @@ -10,11 +8,9 @@ #include #include -#include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe -#include "src/core/include/math/hal/nativeintbackend.h" // from @openfhe -#include "src/core/include/utils/inttypes.h" // from @openfhe -#include "src/pke/include/encoding/plaintext-fwd.h" // from @openfhe -#include "src/pke/include/scheme/bfvrns/bfvrns-cryptoparameters.h" // from @openfhe +#include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe +#include "src/core/include/utils/inttypes.h" // from @openfhe +#include "src/pke/include/encoding/plaintext-fwd.h" // from @openfhe using lbcrypto::DCRTPoly; using PlaintextT = lbcrypto::Plaintext; @@ -51,7 +47,7 @@ DCRTPoly DecryptCore(const std::vector& cv, #define OP #define DECRYPT -#define NOISE +// #define NOISE void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, const std::map& debugAttrMap) { diff --git a/tests/Examples/openfhe/bfv/debug_helper.h b/tests/Examples/openfhe/bfv/debug_helper.h index 28f2d3ebfd..4830f7f828 100644 --- a/tests/Examples/openfhe/bfv/debug_helper.h +++ b/tests/Examples/openfhe/bfv/debug_helper.h @@ -17,4 +17,8 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, std::vector cts, const std::map& debugAttrMap); +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector ct, + const std::map& debugAttrMap); + #endif // TESTS_EXAMPLES_OPENFHE_BFV_DEBUG_HELPER_H_ diff --git a/tests/Examples/openfhe/bgv/debug_helper.cpp b/tests/Examples/openfhe/bgv/debug_helper.cpp index 0ff8cbccc7..d8b22125a1 100644 --- a/tests/Examples/openfhe/bgv/debug_helper.cpp +++ b/tests/Examples/openfhe/bgv/debug_helper.cpp @@ -67,7 +67,7 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, #ifdef DECRYPT PlaintextT ptxt; cc->Decrypt(sk, ct, &ptxt); - ptxt->SetLength(std::stod(debugAttrMap.at("message.size"))); + ptxt->SetLength(std::stoul(debugAttrMap.at("message.size"))); std::cout << " " << ptxt << std::endl; #endif diff --git a/tests/Examples/openfhe/bgv/debug_helper.h b/tests/Examples/openfhe/bgv/debug_helper.h index 4466fe2aaf..62a8449986 100644 --- a/tests/Examples/openfhe/bgv/debug_helper.h +++ b/tests/Examples/openfhe/bgv/debug_helper.h @@ -17,4 +17,8 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, std::vector cts, const std::map& debugAttrMap); +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector ct, + const std::map& debugAttrMap); + #endif // TESTS_EXAMPLES_OPENFHE_BGV_DEBUG_HELPER_H_ diff --git a/tests/Examples/openfhe/ckks/BUILD b/tests/Examples/openfhe/ckks/BUILD index b2ecd16a83..5824efcbc0 100644 --- a/tests/Examples/openfhe/ckks/BUILD +++ b/tests/Examples/openfhe/ckks/BUILD @@ -10,6 +10,7 @@ cc_library( srcs = ["debug_helper.cpp"], hdrs = ["debug_helper.h"], deps = [ + "@com_google_absl//absl/time", "@openfhe//:core", "@openfhe//:pke", ], diff --git a/tests/Examples/openfhe/ckks/debug_helper.cpp b/tests/Examples/openfhe/ckks/debug_helper.cpp index cacfd15a8e..0f6d337e67 100644 --- a/tests/Examples/openfhe/ckks/debug_helper.cpp +++ b/tests/Examples/openfhe/ckks/debug_helper.cpp @@ -12,6 +12,8 @@ #include #include +#include "absl/time/clock.h" // from @com_google_absl +#include "absl/time/time.h" // from @com_google_absl #include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe #include "src/core/include/utils/inttypes.h" // from @openfhe #include "src/pke/include/encoding/plaintext-fwd.h" // from @openfhe @@ -66,7 +68,12 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, #ifdef DECRYPT PlaintextT ptxt; + absl::Time start = absl::Now(); cc->Decrypt(sk, ct, &ptxt); + absl::Time end = absl::Now(); + absl::Duration elapsed = end - start; + std::cout << " Decryption took " << absl::ToDoubleSeconds(elapsed) + << " seconds." << std::endl; ptxt->SetLength(std::stod(debugAttrMap.at("message.size"))); std::vector result; result.reserve(ptxt->GetLength()); diff --git a/tests/Examples/openfhe/ckks/debug_helper.h b/tests/Examples/openfhe/ckks/debug_helper.h index 7fe13e1c5c..d69267fc18 100644 --- a/tests/Examples/openfhe/ckks/debug_helper.h +++ b/tests/Examples/openfhe/ckks/debug_helper.h @@ -5,7 +5,9 @@ #include #include +// IWYU pragma: begin_keep #include "src/pke/include/openfhe.h" // from @openfhe +// IWYU pragma: end_keep using CiphertextT = lbcrypto::Ciphertext; using CryptoContextT = lbcrypto::CryptoContext; @@ -17,4 +19,8 @@ void __heir_debug(CryptoContextT cc, PrivateKeyT sk, std::vector cts, const std::map& debugAttrMap); +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector ct, + const std::map& debugAttrMap); + #endif // TESTS_EXAMPLES_OPENFHE_CKKS_DEBUG_HELPER_H_ diff --git a/tests/Examples/openfhe/ckks/debug_validate/BUILD b/tests/Examples/openfhe/ckks/debug_validate/BUILD new file mode 100644 index 0000000000..3d13c2df0b --- /dev/null +++ b/tests/Examples/openfhe/ckks/debug_validate/BUILD @@ -0,0 +1,31 @@ +load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package(default_applicable_licenses = ["@heir//:license"]) + +cc_library( + name = "debug_helper", + srcs = ["debug_helper.cpp"], + hdrs = ["debug_helper.h"], + deps = [ + "@openfhe//:core", + "@openfhe//:pke", + ], +) + +openfhe_end_to_end_test( + name = "validate_lower_test", + generated_lib_header = "validate_lower_lib.h", + heir_opt_flags = [ + "--annotate-module=backend=openfhe scheme=ckks", + "--mlir-to-ckks=ciphertext-degree=8", + "--scheme-to-openfhe", + "--lwe-add-debug-port", + ], + heir_translate_flags = [ + "--openfhe-debug-helper-include-path=tests/Examples/openfhe/ckks/debug_validate/debug_helper.h", + ], + mlir_src = "validate_lower.mlir", + test_src = "validate_lower_test.cpp", + deps = [":debug_helper"], +) diff --git a/tests/Examples/openfhe/ckks/debug_validate/debug_helper.cpp b/tests/Examples/openfhe/ckks/debug_validate/debug_helper.cpp new file mode 100644 index 0000000000..00ae2ae490 --- /dev/null +++ b/tests/Examples/openfhe/ckks/debug_validate/debug_helper.cpp @@ -0,0 +1,19 @@ +#include "tests/Examples/openfhe/ckks/debug_validate/debug_helper.h" + +#include +#include +#include +#include + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, + const std::map& debugAttrMap) { + auto name = debugAttrMap.at("debug.name"); + assert(name == "input_val" || name == "output_val"); +} + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector ct, + const std::map& debugAttrMap) { + auto name = debugAttrMap.at("debug.name"); + assert(name == "input_val" || name == "output_val"); +} diff --git a/tests/Examples/openfhe/ckks/debug_validate/debug_helper.h b/tests/Examples/openfhe/ckks/debug_validate/debug_helper.h new file mode 100644 index 0000000000..3c2ef05b3f --- /dev/null +++ b/tests/Examples/openfhe/ckks/debug_validate/debug_helper.h @@ -0,0 +1,24 @@ +#ifndef TESTS_EXAMPLES_OPENFHE_CKKS_DEBUG_VALIDATE_DEBUG_HELPER_H_ +#define TESTS_EXAMPLES_OPENFHE_CKKS_DEBUG_VALIDATE_DEBUG_HELPER_H_ + +#include +#include +#include + +#include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe +#include "src/pke/include/ciphertext-fwd.h" // from @openfhe +#include "src/pke/include/cryptocontext-fwd.h" // from @openfhe +#include "src/pke/include/key/privatekey-fwd.h" // from @openfhe + +using CiphertextT = lbcrypto::Ciphertext; +using CryptoContextT = lbcrypto::CryptoContext; +using PrivateKeyT = lbcrypto::PrivateKey; + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct, + const std::map& debugAttrMap); + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, + std::vector ct, + const std::map& debugAttrMap); + +#endif // TESTS_EXAMPLES_OPENFHE_CKKS_DEBUG_VALIDATE_DEBUG_HELPER_H_ diff --git a/tests/Examples/openfhe/ckks/debug_validate/validate_lower.mlir b/tests/Examples/openfhe/ckks/debug_validate/validate_lower.mlir new file mode 100644 index 0000000000..d7b28f43e5 --- /dev/null +++ b/tests/Examples/openfhe/ckks/debug_validate/validate_lower.mlir @@ -0,0 +1,6 @@ +func.func @test_validate_lower(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { + debug.validate %arg0 {name = "input_val", metadata = "input_meta"} : tensor<8xf32> + %0 = arith.addf %arg0, %arg0 : tensor<8xf32> + debug.validate %0 {name = "output_val", metadata = "output_meta"} : tensor<8xf32> + return %0 : tensor<8xf32> +} diff --git a/tests/Examples/openfhe/ckks/debug_validate/validate_lower_test.cpp b/tests/Examples/openfhe/ckks/debug_validate/validate_lower_test.cpp new file mode 100644 index 0000000000..2660794398 --- /dev/null +++ b/tests/Examples/openfhe/ckks/debug_validate/validate_lower_test.cpp @@ -0,0 +1,32 @@ +#include + +#include "gtest/gtest.h" // from @googletest + +// Generated headers +#include "tests/Examples/openfhe/ckks/debug_validate/validate_lower_lib.h" + +namespace mlir { +namespace heir { +namespace openfhe { + +TEST(ValidateLowerTest, RunTest) { + auto cryptoContext = test_validate_lower__generate_crypto_context(); + auto keyPair = cryptoContext->KeyGen(); + auto publicKey = keyPair.publicKey; + auto secretKey = keyPair.secretKey; + cryptoContext = + test_validate_lower__configure_crypto_context(cryptoContext, secretKey); + + std::vector arg0 = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}; + + auto arg0Encrypted = + test_validate_lower__encrypt__arg0(cryptoContext, arg0, publicKey); + auto outputEncrypted = + test_validate_lower(cryptoContext, secretKey, arg0Encrypted); + // Just test that no assertions are hit in the callbacks defined in + // debug_helper.h +} + +} // namespace openfhe +} // namespace heir +} // namespace mlir diff --git a/tests/Examples/openfhe/ckks/loop_support/BUILD b/tests/Examples/openfhe/ckks/loop_support/BUILD index ab0eb9e5ca..9de9ba5bc8 100644 --- a/tests/Examples/openfhe/ckks/loop_support/BUILD +++ b/tests/Examples/openfhe/ckks/loop_support/BUILD @@ -13,12 +13,12 @@ openfhe_end_to_end_test( "--scheme-to-openfhe=insert-debug-handler-calls=true", ], heir_translate_flags = [ - "--openfhe-debug-helper-include-path=tests/Examples/openfhe/ckks/loop_support/debug_helper.h", + "--openfhe-debug-helper-include-path=tests/Examples/openfhe/ckks/debug_helper.h", ], mlir_src = "@heir//tests/Examples/common:loop.mlir", test_src = "loop_test.cpp", deps = [ - ":debug_helper", + "@heir//tests/Examples/openfhe/ckks:debug_helper", ], ) diff --git a/tests/Examples/openfhe/ckks/loop_support/loop_test.cpp b/tests/Examples/openfhe/ckks/loop_support/loop_test.cpp index 6de76c784b..523d64320e 100644 --- a/tests/Examples/openfhe/ckks/loop_support/loop_test.cpp +++ b/tests/Examples/openfhe/ckks/loop_support/loop_test.cpp @@ -26,7 +26,8 @@ CryptoContextT override_crypto_context() { params.SetScalingModSize(55); params.SetFirstModSize(60); params.SetRingDim(2048); - params.SetBatchSize(1024); + // Set batch size to 8 to match the test data size and optimize bootstrapping. + params.SetBatchSize(8); params.SetSecurityLevel(HEStd_NotSet); CryptoContextT cc = GenCryptoContext(params); cc->Enable(PKE); @@ -39,8 +40,8 @@ CryptoContextT override_crypto_context() { CryptoContextT override_configure_crypto_context(CryptoContextT cc, PrivateKeyT sk) { cc->EvalMultKeyGen(sk); - cc->EvalBootstrapSetup({3, 3}); - cc->EvalBootstrapKeyGen(sk, 1024); + cc->EvalBootstrapSetup({3, 3}, {0, 0}, 8); + cc->EvalBootstrapKeyGen(sk, 8); return cc; }