Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(
"@heir//lib/Dialect/Lattigo/IR:Dialect",
"@heir//lib/Dialect/Preprocessing/Conversions:Util",
"@heir//lib/Dialect/Preprocessing/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "lib/Dialect/Preprocessing/Conversions/Util.h"
#include "lib/Dialect/Preprocessing/IR/PreprocessingDialect.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/Utils.h"
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
Expand All @@ -31,6 +32,10 @@ struct PreprocessingToLattigo
void runOnOperation() override {
ModuleOp module = getOperation();

if (!containsDialects<PreprocessingDialect>(module)) {
return;
}

PreprocessingStorageLayoutAnalysis analysis(module);
if (!analysis.isValid()) {
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cc_library(
"@heir//lib/Analysis/PreprocessingStorageLayoutAnalysis",
"@heir//lib/Dialect/Preprocessing/Conversions:Util",
"@heir//lib/Dialect/Preprocessing/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineAnalysis",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "lib/Dialect/Preprocessing/IR/PreprocessingOps.h"
#include "lib/Dialect/Preprocessing/IR/PreprocessingTypes.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/Utils.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project
Expand Down Expand Up @@ -214,6 +215,10 @@ struct PreprocessingToMemref
void runOnOperation() override {
ModuleOp module = getOperation();

if (!containsDialects<PreprocessingDialect>(module)) {
return;
}

PreprocessingStorageLayoutAnalysis analysis(module);
if (!analysis.isValid()) {
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(
"@heir//lib/Dialect/Openfhe/IR:Dialect",
"@heir//lib/Dialect/Preprocessing/Conversions:Util",
"@heir//lib/Dialect/Preprocessing/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "lib/Dialect/Preprocessing/Conversions/Util.h"
#include "lib/Dialect/Preprocessing/IR/PreprocessingDialect.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/Utils.h"
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
Expand All @@ -31,6 +32,10 @@ struct PreprocessingToOpenfhe
void runOnOperation() override {
ModuleOp module = getOperation();

if (!containsDialects<PreprocessingDialect>(module)) {
return;
}

PreprocessingStorageLayoutAnalysis analysis(module);
if (!analysis.isValid()) {
signalPassFailure();
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Preprocessing/IR/PreprocessingOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Preprocessing_Op<string mnemonic, list<Trait> traits = []> :
let cppNamespace = "::mlir::heir::preprocessing";
}

def Preprocessing_EmptyOp : Preprocessing_Op<"empty"> {
def Preprocessing_EmptyOp : Preprocessing_Op<"empty", [MemoryEffectsOpInterface]> {
let summary = "Allocate an empty preprocessing storage";
let description = [{
This op creates a new `preprocessing.storage` value with an undetermined size.
Expand All @@ -23,7 +23,7 @@ def Preprocessing_EmptyOp : Preprocessing_Op<"empty"> {
let assemblyFormat = "attr-dict `:` type($storage)";
}

def Preprocessing_StoreOp : Preprocessing_Op<"store"> {
def Preprocessing_StoreOp : Preprocessing_Op<"store", [MemoryEffectsOpInterface]> {
let summary = "Store a value to preprocessing storage at given indices and site_id";
let description = [{
This op stores a value into a `preprocessing.storage`, while tracking:
Expand All @@ -49,7 +49,7 @@ def Preprocessing_StoreOp : Preprocessing_Op<"store"> {
let assemblyFormat = "$value `,` $storage `[` $indices `]` `site` $site_id `<` $element_type `>` attr-dict `:` type($value) `,` type($storage)";
}

def Preprocessing_LoadOp : Preprocessing_Op<"load"> {
def Preprocessing_LoadOp : Preprocessing_Op<"load", [MemoryEffectsOpInterface]> {
let summary = "Load a value from preprocessing storage at given indices and site_id";
let description = [{
This op loads a value from a `preprocessing.storage`, while tracking:
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Preprocessing/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cc_library(
":pass_inc_gen",
"@heir//lib/Dialect/Preprocessing/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
Expand Down
49 changes: 28 additions & 21 deletions lib/Dialect/Preprocessing/Transforms/ValidatePreprocessing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

#include "lib/Dialect/Preprocessing/IR/PreprocessingOps.h"
#include "lib/Dialect/Preprocessing/IR/PreprocessingTypes.h"
#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand All @@ -25,30 +26,36 @@ struct ValidatePreprocessing
using ValidatePreprocessingBase::ValidatePreprocessingBase;

void runOnOperation() override {
SmallVector<EmptyOp, 2> emptyOps;
Operation* module = getOperation();

bool hasMultipleEmpties = false;
module->walk([&](func::FuncOp funcOp) {
SmallVector<EmptyOp> emptyOps(funcOp.getOps<EmptyOp>());
if (emptyOps.size() > 1) {
for (size_t i = 1; i < emptyOps.size(); ++i) {
auto diag =
emptyOps[i]->emitOpError()
<< "more than one preprocessing.empty allocation in function";
diag.attachNote(emptyOps[0]->getLoc()) << "previous allocation here";
}
hasMultipleEmpties = true;
}
});
if (hasMultipleEmpties) {
signalPassFailure();
}

DenseMap<uint32_t, SmallVector<StoreOp, 2>> storesBySite;
DenseMap<uint32_t, SmallVector<LoadOp, 2>> loadsBySite;

getOperation()->walk([&](Operation* op) {
if (auto emptyOp = dyn_cast<EmptyOp>(op)) {
emptyOps.push_back(emptyOp);
} else if (auto storeOp = dyn_cast<StoreOp>(op)) {
module->walk([&](Operation* op) {
if (auto storeOp = dyn_cast<StoreOp>(op)) {
storesBySite[storeOp.getSiteId()].push_back(storeOp);
} else if (auto loadOp = dyn_cast<LoadOp>(op)) {
loadsBySite[loadOp.getSiteId()].push_back(loadOp);
}
});

if (emptyOps.size() > 1) {
for (size_t i = 1; i < emptyOps.size(); ++i) {
auto diag = emptyOps[i]->emitOpError()
<< "more than one preprocessing.empty allocation in module";
diag.attachNote(emptyOps[0]->getLoc()) << "previous allocation here";
}
signalPassFailure();
// Do NOT return, continue to site pairing checks
}

std::set<uint32_t> siteIds;
for (const auto& [siteId, stores] : storesBySite) {
siteIds.insert(siteId);
Expand Down
16 changes: 12 additions & 4 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
#include "lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.h"
#include "lib/Dialect/Openfhe/Transforms/CountAddAndKeySwitch.h"
#include "lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.h"
#include "lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo/PreprocessingToLattigo.h"
#include "lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe/PreprocessingToOpenfhe.h"
#include "lib/Dialect/Preprocessing/Transforms/ValidatePreprocessing.h"
#include "lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h"
#include "lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.h"
#include "lib/Dialect/Secret/Conversions/SecretToModArith/SecretToModArith.h"
Expand Down Expand Up @@ -142,6 +145,8 @@ void cleanupAfterLowerAssignLayout(OpPassManager& pm) {
pm.addPass(createCSEPass());
pm.addPass(createRemoveDeadValuesPass());
pm.addPass(createSymbolDCEPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
}

// Implement layout conversions as shift networks
Expand Down Expand Up @@ -423,9 +428,10 @@ void mlirToRLWEPipeline(OpPassManager& pm,
pm.addPass(lwe::createImplementTrivialEncryptionAsAddition());

// Add a __preprocessed helper for offline pre-packing of plaintexts
auto splitPreprocessingOptions = SplitPreprocessingOptions{};
splitPreprocessingOptions.maxReturnValues = options.splitPreprocessing;
pm.addPass(createSplitPreprocessing(splitPreprocessingOptions));
if (options.enableSplitPreprocessing) {
pm.addPass(createSplitPreprocessing());
pm.addPass(preprocessing::createValidatePreprocessing());
}

ElementwiseToAffineOptions elementwiseOptions;
elementwiseOptions.convertDialects = {"ckks", "bgv", "lwe"};
Expand Down Expand Up @@ -471,6 +477,7 @@ BackendPipelineBuilder toOpenFhePipelineBuilder() {

// Convert LWE (and scheme-specific CKKS/BGV ops) to OpenFHE
pm.addPass(lwe::createLWEToOpenfhe());
pm.addPass(preprocessing::createPreprocessingToOpenfhe());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

Expand Down Expand Up @@ -513,6 +520,7 @@ BackendPipelineBuilder toLattigoPipelineBuilder() {

// Convert LWE (and scheme-specific BGV ops) to Lattigo
pm.addPass(lwe::createLWEToLattigo());
pm.addPass(preprocessing::createPreprocessingToLattigo());

// Convert Alloc Ops to InPlace Ops
pm.addPass(lattigo::createAllocToInPlace());
Expand Down Expand Up @@ -567,7 +575,7 @@ void torchLinalgToCkksBuilder(OpPassManager& manager,
suboptions.ckksBootstrapWaterline = options.ckksBootstrapWaterline;
suboptions.scalingModBits = options.scalingModBits;
suboptions.firstModBits = options.firstModBits;
suboptions.splitPreprocessing = options.splitPreprocessing;
suboptions.enableSplitPreprocessing = options.enableSplitPreprocessing;
suboptions.experimentalDisableLoopUnroll =
options.experimentalDisableLoopUnroll;
suboptions.usePublicKey = options.usePublicKey;
Expand Down
10 changes: 5 additions & 5 deletions lib/Pipelines/ArithmeticPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ struct MlirToRLWEPipelineOptions : public LoopOptions {
llvm::cl::desc("File name to import execution result from (c.f. --secret-"
"import-execution-result)"),
llvm::cl::init("")};
PassOptions::Option<int> splitPreprocessing{
*this, "split-preprocessing",
llvm::cl::desc("Split preprocessing into separate function with N return "
"values (default to no split)"),
llvm::cl::init(16)};
PassOptions::Option<bool> enableSplitPreprocessing{
*this, "enable-split-preprocessing",
llvm::cl::desc(
"Split server-side plaintext preprocessing into a separate function"),
llvm::cl::init(true)};
};

struct PlaintextBackendOptions
Expand Down
3 changes: 3 additions & 0 deletions lib/Pipelines/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ cc_library(
"@heir//lib/Dialect/Openfhe/Transforms:ConfigureCryptoContext",
"@heir//lib/Dialect/Openfhe/Transforms:CountAddAndKeySwitch",
"@heir//lib/Dialect/Openfhe/Transforms:FastRotationPrecompute",
"@heir//lib/Dialect/Preprocessing/Conversions/PreprocessingToLattigo",
"@heir//lib/Dialect/Preprocessing/Conversions/PreprocessingToOpenfhe",
"@heir//lib/Dialect/Preprocessing/Transforms",
"@heir//lib/Dialect/Secret/Conversions/SecretToBGV",
"@heir//lib/Dialect/Secret/Conversions/SecretToCGGI",
"@heir//lib/Dialect/Secret/Conversions/SecretToCKKS",
Expand Down
1 change: 0 additions & 1 deletion lib/Target/Lattigo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ cc_library(
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
Expand Down
25 changes: 17 additions & 8 deletions lib/Target/Lattigo/LattigoEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,15 @@ LogicalResult LattigoEmitter::printOperation(func::FuncOp funcOp) {
// name and arg
std::string funcName = funcOp.getName().str();
if (funcOp->hasAttr(kClientPackFuncAttrName)) {
if (!funcName.empty() && funcFilter && funcFilter(funcOp) &&
std::islower(funcName[0])) {
// Upper case this only if we are emitting it in a filtered file so it
// needs exporting.
funcName[0] = std::toupper(funcName[0]);
if (!funcName.empty() && funcFilter && funcFilter(funcOp)) {
while (funcName[0] == '_') {
funcName.erase(0, 1);
}
if (std::islower(funcName[0])) {
// Upper case this only if we are emitting it in a filtered file so it
// needs exporting.
funcName[0] = std::toupper(funcName[0]);
}
}
}
os << "func " << funcName << "(";
Expand Down Expand Up @@ -286,11 +290,16 @@ LogicalResult LattigoEmitter::printOperation(func::CallOp op) {
auto moduleOp = op->getParentOfType<ModuleOp>();
auto calleeOp = moduleOp.lookupSymbol<func::FuncOp>(callee);
std::string calleeName = canonicalizeDebugPort(callee).str();
if (calleeOp && funcFilter && !funcFilter(calleeOp)) {
if (calleeOp->hasAttr(kClientPackFuncAttrName)) {
if (!calleeName.empty() && std::islower(calleeName[0])) {
if (calleeOp && calleeOp->hasAttr(kClientPackFuncAttrName)) {
if (funcFilter && !calleeName.empty()) {
while (calleeName[0] == '_') {
calleeName.erase(0, 1);
}
if (std::islower(calleeName[0])) {
calleeName[0] = std::toupper(calleeName[0]);
}
}
if (funcFilter && !funcFilter(calleeOp)) {
calleeName = packageName + "_utils." + calleeName;
extraImportsUsed = true;
}
Expand Down
8 changes: 3 additions & 5 deletions lib/Transforms/SplitPreprocessing/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,19 @@ cc_library(
hdrs = ["SplitPreprocessing.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect:HEIRInterfaces",
"@heir//lib/Dialect:ModuleAttributes",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:SecretPatterns",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Dialect/Preprocessing/IR:Dialect",
"@heir//lib/Utils:AttributeUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
)
Expand Down
Loading
Loading