Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 72 additions & 34 deletions lib/Dialect/LWE/Transforms/AddDebugPort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <cassert>
#include <string>
#include <utility>

#include "lib/Dialect/Debug/IR/DebugOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
Expand Down Expand Up @@ -63,9 +62,39 @@ FailureOr<Type> getPrivateKeyType(func::FuncOp op) {
return lwePrivateKeyType;
}

void populateDebugFuncCache(ModuleOp module,
llvm::DenseMap<Type, func::FuncOp>& typeToDebugFunc,
llvm::DenseSet<StringRef>& debugFuncNames) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (!funcOp.isExternal()) continue;
if (!funcOp.getName().starts_with("__heir_debug_")) continue;
if (funcOp.getArgumentTypes().size() != 2) continue;
if (!funcOp.getResultTypes().empty()) continue;

typeToDebugFunc[funcOp.getFunctionType()] = funcOp;
debugFuncNames.insert(funcOp.getName());
}
}

static bool isAlreadyDebugged(Value value,
const llvm::DenseSet<StringRef>& debugFuncNames) {
for (auto& use : value.getUses()) {
Operation* user = use.getOwner();
if (isa<debug::ValidateOp>(user)) {
return true;
}
if (auto callOp = dyn_cast<func::CallOp>(user)) {
if (debugFuncNames.contains(callOp.getCallee())) {
return true;
}
}
}
return false;
}

func::FuncOp getOrCreateExternalDebugFunc(
ModuleOp module, Type lwePrivateKeyType, Type valueType,
llvm::DenseMap<Type, func::FuncOp>& typeToDebugFunc) {
ModuleOp module, SymbolTable& symbolTable, Type lwePrivateKeyType,
Type valueType, llvm::DenseMap<Type, func::FuncOp>& typeToDebugFunc) {
auto* context = module.getContext();
auto debugFuncType =
FunctionType::get(context, {lwePrivateKeyType, valueType}, {});
Expand All @@ -75,28 +104,29 @@ func::FuncOp getOrCreateExternalDebugFunc(
return it->second;
}

int counter = typeToDebugFunc.size();
std::string funcName = "__heir_debug_" + std::to_string(counter);
unsigned uniquingCounter = typeToDebugFunc.size();
SmallString<128> funcName = SymbolTable::generateSymbolName<128>(
"__heir_debug",
[&](StringRef name) { return symbolTable.lookup(name) != nullptr; },
uniquingCounter);

// Assert that this name is not already in use.
assert(!module.lookupSymbol<func::FuncOp>(funcName) &&
"Symbol already exists");

ImplicitLocOpBuilder b =
ImplicitLocOpBuilder::atBlockBegin(module.getLoc(), module.getBody());
auto funcOp = func::FuncOp::create(b, funcName, debugFuncType);
auto funcOp = func::FuncOp::create(module.getLoc(), funcName, debugFuncType);
// required for external func call
funcOp.setPrivate();

symbolTable.insert(funcOp, module.getBody()->begin());

typeToDebugFunc[debugFuncType] = funcOp;
return funcOp;
}

void insertValidationOps(func::FuncOp op) {
void insertValidationOps(func::FuncOp op,
const llvm::DenseSet<StringRef>& debugFuncNames) {
int count = 0;
auto insertValidate = [&](Value value, OpBuilder& b) {
Type valueType = value.getType();
if (isa<LWECiphertextType>(getElementTypeOrSelf(valueType))) {
if (isAlreadyDebugged(value, debugFuncNames)) return;
debug::ValidateOp::create(b, value.getLoc(), value,
"heir_debug_" + std::to_string(count++),
nullptr);
Expand All @@ -121,8 +151,8 @@ void insertValidationOps(func::FuncOp op) {
}

LogicalResult lowerValidationOps(
func::FuncOp op, Value privateKey, int messageSize,
llvm::DenseMap<Type, func::FuncOp>& typeToDebugFunc) {
func::FuncOp op, SymbolTable& symbolTable, Value privateKey,
int messageSize, llvm::DenseMap<Type, func::FuncOp>& typeToDebugFunc) {
auto module = op->getParentOfType<ModuleOp>();
Type lwePrivateKeyType = privateKey.getType();

Expand All @@ -141,10 +171,10 @@ LogicalResult lowerValidationOps(
attrs.push_back(b.getNamedAttr(
"message.size", b.getStringAttr(std::to_string(messageSize))));

auto debugFunc = getOrCreateExternalDebugFunc(module, lwePrivateKeyType,
valueType, typeToDebugFunc);
auto callOp =
b.create<func::CallOp>(debugFunc, ArrayRef<Value>{privateKey, value});
auto debugFunc = getOrCreateExternalDebugFunc(
module, symbolTable, lwePrivateKeyType, valueType, typeToDebugFunc);
auto callOp = func::CallOp::create(b, b.getLoc(), debugFunc,
ArrayRef<Value>{privateKey, value});
callOp->setDialectAttrs(attrs);

validateOp.erase();
Expand All @@ -166,6 +196,10 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
ModuleOp module = cast<ModuleOp>(getOperation());
SymbolTable symbolTable(module);

llvm::DenseMap<Type, func::FuncOp> typeToDebugFunc;
llvm::DenseSet<StringRef> debugFuncNames;
populateDebugFuncCache(module, typeToDebugFunc, debugFuncNames);

SmallVector<func::FuncOp, 16> worklist;
llvm::DenseMap<func::FuncOp, Type> funcToKeyType;
if (failed(identifyInitialTargets(module, symbolTable, funcToKeyType,
Expand All @@ -183,7 +217,7 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {

if (insertDebugAfterEveryOp) {
for (auto& [func, _] : funcToKeyType) {
insertValidationOps(func);
insertValidationOps(func, debugFuncNames);
}
}

Expand All @@ -198,8 +232,8 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
return;
}

llvm::DenseMap<Type, func::FuncOp> typeToDebugFunc;
if (failed(lowerAllValidationOps(module, funcToKeyType, typeToDebugFunc))) {
if (failed(lowerAllValidationOps(module, symbolTable, funcToKeyType,
typeToDebugFunc))) {
signalPassFailure();
return;
}
Expand Down Expand Up @@ -229,18 +263,21 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
}

if (entryFunc) {
auto type = getPrivateKeyType(entryFunc);
if (succeeded(type)) {
bool shouldProcess =
containsAnyOperations<debug::ValidateOp>(entryFunc) ||
insertDebugAfterEveryOp;

if (shouldProcess) {
auto type = getPrivateKeyType(entryFunc);
if (failed(type)) {
entryFunc.emitError(
"Cannot infer LWE private key type for entry function");
return failure();
}
funcToKeyType[entryFunc] = *type;
worklist.push_back(entryFunc);
return success();
}

if (containsAnyOperations<debug::ValidateOp>(entryFunc)) {
entryFunc.emitError(
"Cannot infer LWE private key type for entry function");
return failure();
}
return success();
}

for (auto funcOp : module.getOps<func::FuncOp>()) {
Expand Down Expand Up @@ -404,7 +441,8 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
/// \param typePairToInt Map to track generated debug function names.
/// \return success() if successful, failure() otherwise.
LogicalResult lowerAllValidationOps(
ModuleOp module, const llvm::DenseMap<func::FuncOp, Type>& funcToKeyType,
ModuleOp module, SymbolTable& symbolTable,
const llvm::DenseMap<func::FuncOp, Type>& funcToKeyType,
llvm::DenseMap<Type, func::FuncOp>& typeToDebugFunc) {
for (auto funcOp : module.getOps<func::FuncOp>()) {
if (funcOp.isExternal()) continue;
Expand All @@ -428,8 +466,8 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
}

if (privateKey) {
if (failed(lowerValidationOps(funcOp, privateKey, messageSize,
typeToDebugFunc))) {
if (failed(lowerValidationOps(funcOp, symbolTable, privateKey,
messageSize, typeToDebugFunc))) {
funcOp.emitError("failed to lower validation ops");
return failure();
}
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/LWE/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ cc_library(
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect:FuncUtils",
"@heir//lib/Dialect/Debug/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Utils",
Expand Down
35 changes: 19 additions & 16 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,9 @@ void mlirToPlaintextPipelineBuilder(OpPassManager& pm,
mlirToRLWEPipelineOptions.ciphertextDegree = options.plaintextSize;
mlirToSecretArithmeticPipelineBuilder(pm, mlirToRLWEPipelineOptions);

if (options.debug) {
// Insert debug handler calls
secret::SecretAddDebugPortOptions debugOptions;
debugOptions.insertDebugAfterEveryOp = true;
pm.addPass(secret::createSecretAddDebugPort(debugOptions));
}
// Insert debug handler calls and/or lower debug.validate
pm.addPass(secret::createSecretAddDebugPort(secret::SecretAddDebugPortOptions{
.insertDebugAfterEveryOp = options.debug}));

pm.addPass(secret::createSecretDistributeGeneric());
pm.addPass(createCanonicalizerPass());
Expand Down Expand Up @@ -411,9 +408,15 @@ void mlirToRLWEPipeline(OpPassManager& pm,
exit(EXIT_FAILURE);
}

// Lower debug.validate ops to function calls with private key
pm.addPass(lwe::createAddDebugPort(
lwe::AddDebugPortOptions{.messageSize = (int)options.ciphertextDegree,
.insertDebugAfterEveryOp = options.debug}));

pm.addPass(createForwardInsertToExtract());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addPass(createSymbolDCEPass());

// TODO(#2554): skip this pass if the backend supports trivial encryption
pm.addPass(lwe::createImplementTrivialEncryptionAsAddition());
Expand Down Expand Up @@ -459,11 +462,11 @@ BackendPipelineBuilder toOpenFhePipelineBuilder() {
pm.addPass(ckks::createCKKSToLWE());

// insert debug handler calls
if (options.debug) {
lwe::AddDebugPortOptions addDebugPortOptions;
addDebugPortOptions.entryFunction = options.entryFunction;
pm.addPass(lwe::createAddDebugPort(addDebugPortOptions));
}
lwe::AddDebugPortOptions addDebugPortOptions{
.entryFunction = options.entryFunction,
.insertDebugAfterEveryOp = options.debug,
};
pm.addPass(lwe::createAddDebugPort(addDebugPortOptions));

// Convert LWE (and scheme-specific CKKS/BGV ops) to OpenFHE
pm.addPass(lwe::createLWEToOpenfhe());
Expand Down Expand Up @@ -501,11 +504,11 @@ BackendPipelineBuilder toLattigoPipelineBuilder() {
pm.addPass(ckks::createCKKSToLWE());

// insert debug handler calls
if (options.debug) {
lwe::AddDebugPortOptions addDebugPortOptions;
addDebugPortOptions.entryFunction = options.entryFunction;
pm.addPass(lwe::createAddDebugPort(addDebugPortOptions));
}
lwe::AddDebugPortOptions addDebugPortOptions{
.entryFunction = options.entryFunction,
.insertDebugAfterEveryOp = options.debug,
};
pm.addPass(lwe::createAddDebugPort(addDebugPortOptions));

// Convert LWE (and scheme-specific BGV ops) to Lattigo
pm.addPass(lwe::createLWEToLattigo());
Expand Down
4 changes: 4 additions & 0 deletions lib/Pipelines/ArithmeticPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ struct MlirToRLWEPipelineOptions : public LoopOptions {
llvm::cl::desc(
"The level budget excluding levels required for bootstrap"),
llvm::cl::init(10)};
PassOptions::Option<bool> debug{
*this, "debug",
llvm::cl::desc("Insert debug ports after every secret operation."),
llvm::cl::init(false)};
PassOptions::Option<std::string> plaintextExecutionResultFileName{
*this, "plaintext-execution-result-file-name",
llvm::cl::desc("File name to import execution result from (c.f. --secret-"
Expand Down
2 changes: 1 addition & 1 deletion lib/Pipelines/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ cc_library(
"@heir//lib/Dialect/Debug/Transforms:ValidateNames",
"@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial",
"@heir//lib/Dialect/Secret/Conversions/SecretToCGGI",
"@heir//lib/Dialect/Secret/Transforms:AddDebugPort",
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
"@heir//lib/Transforms/AddClientInterface",
"@heir//lib/Transforms/BooleanVectorizer",
Expand Down Expand Up @@ -102,7 +103,6 @@ cc_library(
":PipelineRegistration",
"@heir//lib/Dialect/BGV/Conversions/BGVToLWE",
"@heir//lib/Dialect/CKKS/Transforms:CKKSToLWE",
"@heir//lib/Dialect/Debug/Transforms",
"@heir//lib/Dialect/Debug/Transforms:ValidateNames",
"@heir//lib/Dialect/LWE/Conversions/LWEToLattigo",
"@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe",
Expand Down
8 changes: 5 additions & 3 deletions lib/Pipelines/BooleanPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
#include "lib/Dialect/CGGI/Conversions/CGGIToSCIFRBool/CGGIToSCIFRBool.h"
#include "lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.h"
#include "lib/Dialect/CGGI/Conversions/CGGIToTfheRustBool/CGGIToTfheRustBool.h"
#include "lib/Dialect/Debug/Transforms/Passes.h"
#include "lib/Dialect/Debug/Transforms/ValidateNames.h"
#include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h"
#include "lib/Dialect/Secret/Transforms/AddDebugPort.h"
#include "lib/Dialect/Secret/Transforms/DistributeGeneric.h"
#include "lib/Pipelines/PipelineRegistration.h"
#include "lib/Transforms/BooleanVectorizer/BooleanVectorizer.h"
#include "lib/Transforms/FoldConstantTensors/FoldConstantTensors.h"
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h"
Expand All @@ -23,7 +22,6 @@
#include "lib/Transforms/MemrefToArith/MemrefToArith.h"
#include "lib/Transforms/Secretize/Passes.h"
#include "lib/Transforms/TensorLinalgToAffineLoops/TensorLinalgToAffineLoops.h"
#include "lib/Transforms/UnusedMemRef/UnusedMemRef.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Transforms/Passes.h" // from @llvm-project
Expand Down Expand Up @@ -64,6 +62,8 @@ void mlirToCGGIPipeline(OpPassManager& pm,
const std::string& yosysFilesPath,
const std::string& abcPath) {
pm.addPass(debug::createDebugValidateNames());
pm.addPass(secret::createSecretAddDebugPort(secret::SecretAddDebugPortOptions{
.insertDebugAfterEveryOp = options.debug}));
pm.addPass(createConvertTensorToLinalgPass());
pm.addPass(createLinalgGeneralizeNamedOpsPass());

Expand Down Expand Up @@ -154,6 +154,8 @@ CGGIPipelineBuilder mlirToCGGIPipelineBuilder() {
void mlirToCGGIPipeline(OpPassManager& pm,
const MLIRToCGGIPipelineOptions& options) {
pm.addPass(debug::createDebugValidateNames());
pm.addPass(secret::createSecretAddDebugPort(secret::SecretAddDebugPortOptions{
.insertDebugAfterEveryOp = options.debug}));
// Bufferize
::mlir::heir::oneShotBufferize(pm);

Expand Down
8 changes: 8 additions & 0 deletions lib/Pipelines/BooleanPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ enum DataType { Bool, Integer };
#ifndef HEIR_NO_YOSYS
// If Yosys is enabled, also add all yosys optimizer pipeline options.
struct MLIRToCGGIPipelineOptions : public YosysOptimizerPipelineOptions {
PassOptions::Option<bool> debug{
*this, "debug",
llvm::cl::desc("Insert debug ports after every secret operation."),
llvm::cl::init(false)};
PassOptions::Option<enum DataType> dataType{
*this, "data-type",
llvm::cl::desc("Data type to use for arithmetization, yosys must be "
Expand All @@ -44,6 +48,10 @@ void mlirToCGGIPipeline(OpPassManager& pm,
#else
struct MLIRToCGGIPipelineOptions
: public PassPipelineOptions<MLIRToCGGIPipelineOptions> {
PassOptions::Option<bool> debug{
*this, "debug",
llvm::cl::desc("Insert debug ports after every secret operation."),
llvm::cl::init(false)};
PassOptions::Option<enum DataType> dataType{
*this, "data-type",
llvm::cl::desc("Data type to use for arithmetization."),
Expand Down
Loading
Loading