Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoCKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,17 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInplaceOp<"rotate"> {
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_CKKSRotateHoistedNewOp : Lattigo_CKKSOp<"rotate_hoisted_new"> {
let summary = "Rotate a ciphertext by a batch of offsets in the Lattigo CKKS dialect";
let description = [{
This operation rotates a ciphertext value by a batch of offsets in the Lattigo CKKS dialect.
}];
let arguments = (ins
Lattigo_CKKSEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
DenseI64ArrayAttr:$offsets
);
let results = (outs Lattigo_RotatedCiphertextList:$output);
}

#endif // LIB_DIALECT_LATTIGO_IR_LATTIGOCKKSOPS_TD_
28 changes: 15 additions & 13 deletions lib/Dialect/Lattigo/IR/LattigoCKKSTypes.td
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
#ifndef LIB_DIALECT_LATTIGO_IR_LATTIGOCKKSTYPES_TD_
#define LIB_DIALECT_LATTIGO_IR_LATTIGOCKKSTYPES_TD_

include "LattigoAttributes.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/AttrTypeBase.td"
include "LattigoTypes.td"

class Lattigo_CKKSType<string name, string typeMnemonic>
: Lattigo_Type<"CKKS" # name, "ckks." # typeMnemonic> {
}

// CKKSParameter type definition
def Lattigo_CKKSParameter : Lattigo_CKKSType<"Parameter", "parameter"> {
let description = [{
This type represents the parameters for the CKKS encryption scheme.
}];
let asmName = "param";
}

// CKKSEvaluator type definition
def Lattigo_CKKSEncoder : Lattigo_CKKSType<"Encoder", "encoder"> {
let summary = "Lattigo CKKS encoder type";
let description = [{
This type represents an encoder for the Lattigo CKKS encryption scheme.
}];
let asmName = "encoder";
}

def Lattigo_CKKSEvaluator : Lattigo_CKKSType<"Evaluator", "evaluator"> {
let summary = "Lattigo CKKS evaluator type";
let description = [{
This type represents the evaluator for the CKKS encryption scheme.
This type represents an evaluator for the Lattigo CKKS encryption scheme.
}];
let asmName = "evaluator";
}

// CKKSEncoder type definition
def Lattigo_CKKSEncoder : Lattigo_CKKSType<"Encoder", "encoder"> {
def Lattigo_CKKSRotatedCiphertextMap : Lattigo_CKKSType<"RotatedCiphertextMap", "rotated_ct_map"> {
let summary = "A map from rotation amount to rotated ciphertext.";
let description = [{
This type represents the encoder for the CKKS encryption scheme.
A map from rotation amount to rotated ciphertext.
}];
let asmName = "encoder";
let asmName = "rotated_ct_map";
}



#endif // LIB_DIALECT_LATTIGO_IR_LATTIGOCKKSTYPES_TD_
14 changes: 14 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoRLWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,18 @@ def Lattigo_RLWENegateOp : Lattigo_RLWEOp<"negate", [InplaceOpInterface]> {
let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
}

def Lattigo_RLWELookupRotatedOp : Lattigo_RLWEOp<"lookup_rotated"> {
let summary = "Lookup a rotated ciphertext from a batch of rotated ciphertexts in the Lattigo CKKS dialect";
let description = [{
This operation looks up a rotated ciphertext from the output of a
bulk rotation op.
}];
let arguments = (ins
Lattigo_RotatedCiphertextList:$input,
Builtin_IntegerAttr:$offset
);
let results = (outs Lattigo_RLWECiphertext:$output);
}


#endif // LIB_DIALECT_LATTIGO_IR_LATTIGORLWEOPS_TD_
7 changes: 7 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoRLWETypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ def Lattigo_RLWECiphertext : Lattigo_RLWEType<"Ciphertext", "ciphertext"> {
let asmName = "ct";
}

def Lattigo_RotatedCiphertextList : Lattigo_RLWEType<"RotatedCiphertextList", "rotated_ciphertext_map"> {
let description = [{
This type represents a list of rotated ciphertexts from a bulk rotation op.
}];
let asmName = "rotated_ct_map";
}

def Lattigo_RLWECiphertextOrPlaintext : AnyTypeOf<[Lattigo_RLWECiphertext, Lattigo_RLWEPlaintext]>;

// common interface for RLWE
Expand Down
22 changes: 18 additions & 4 deletions lib/Dialect/Lattigo/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)
package(default_applicable_licenses = ["@heir//:license"])

cc_library(
name = "Transforms",
hdrs = ["Passes.h"],
deps = [
":AllocToInplace",
":ConfigureCryptoContext",
":HoistRotations",
":pass_inc_gen",
"@heir//lib/Dialect/Lattigo/IR:Dialect",
],
Expand Down Expand Up @@ -53,6 +51,22 @@ cc_library(
],
)

cc_library(
name = "HoistRotations",
srcs = ["HoistRotations.cpp"],
hdrs = ["HoistRotations.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/Lattigo/IR:Dialect",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)

add_heir_transforms(
header_filename = "Passes.h.inc",
pass_name = "Lattigo",
Expand Down
109 changes: 109 additions & 0 deletions lib/Dialect/Lattigo/Transforms/HoistRotations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include "lib/Dialect/Lattigo/Transforms/HoistRotations.h"

#include <cstdint>
#include <memory>

#include "lib/Dialect/Lattigo/IR/LattigoOps.h"
#include "lib/Dialect/Lattigo/IR/LattigoTypes.h"
#include "lib/Utils/ConversionUtils.h"
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/DebugLog.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project

#define DEBUG_TYPE "lattigo-hoist-rotations"

namespace mlir {
namespace heir {
namespace lattigo {

#define GEN_PASS_DEF_LATTIGOHOISTROTATIONS
#include "lib/Dialect/Lattigo/Transforms/Passes.h.inc"

struct LattigoHoistRotations
: impl::LattigoHoistRotationsBase<LattigoHoistRotations> {
using LattigoHoistRotationsBase::LattigoHoistRotationsBase;

void runOnOperation() override {
getOperation().walk([&](func::FuncOp op) -> WalkResult {
// CKKS
auto ckksResult = getArgOfType<lattigo::CKKSEvaluatorType>(op);
if (succeeded(ckksResult)) {
Value evaluator = ckksResult.value();
processFunc<lattigo::CKKSRotateNewOp, lattigo::CKKSRotateHoistedNewOp,
lattigo::RLWELookupRotatedOp>(op, evaluator);
return WalkResult::advance();
}

// BGV/BFV have different ops, still need to add them.

LDBG() << "Skipping func with no lattigo evaluator arg: "
<< op.getSymName() << "\n";
return WalkResult::advance();
});
}

private:
template <typename RotateOp, typename HoistedRotateOp, typename LookupOp>
void processFunc(func::FuncOp funcOp, Value evaluator) {
IRRewriter builder(funcOp->getContext());
llvm::DenseMap<Value, llvm::SmallVector<RotateOp>> ciphertextToRotateOps;
llvm::DenseMap<Value, llvm::DenseSet<int64_t>>
ciphertextToDistinctRotations;

funcOp->walk([&](RotateOp op) {
ciphertextToRotateOps[op.getInput()].push_back(op);
ciphertextToDistinctRotations[op.getInput()].insert(
op.getOffset().getValue().getSExtValue());
});

for (auto const& [ciphertext, rots] : ciphertextToDistinctRotations) {
if (rots.size() < 2) {
continue;
}
LLVM_DEBUG(llvm::dbgs() << "Found ciphertext with " << rots.size()
<< " distinct rotations: " << ciphertext << "\n");

if (auto* definingOp = ciphertext.getDefiningOp()) {
builder.setInsertionPointAfter(definingOp);
} else {
builder.setInsertionPointToStart(
cast<BlockArgument>(ciphertext).getOwner());
}

SmallVector<int64_t> offsets;
for (int64_t rot : rots) {
offsets.push_back(rot);
}

auto hoistedRotateOp = HoistedRotateOp::create(
builder, ciphertext.getLoc(),
lattigo::RLWERotatedCiphertextListType::get(builder.getContext()),
evaluator, ciphertext, builder.getDenseI64ArrayAttr(offsets));

for (RotateOp op : ciphertextToRotateOps[ciphertext]) {
builder.setInsertionPoint(op);
auto lookupOp =
LookupOp::create(builder, op.getLoc(), op.getType(),
hoistedRotateOp.getResult(), op.getOffset());
builder.replaceOp(op, lookupOp.getResult());
}
}
}
};

std::unique_ptr<Pass> createLattigoHoistRotations() {
return std::make_unique<LattigoHoistRotations>();
}

} // namespace lattigo
} // namespace heir
} // namespace mlir
21 changes: 21 additions & 0 deletions lib/Dialect/Lattigo/Transforms/HoistRotations.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef LIB_DIALECT_LATTIGO_TRANSFORMS_HOISTROTATIONS_H_
#define LIB_DIALECT_LATTIGO_TRANSFORMS_HOISTROTATIONS_H_

#include <memory>

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace lattigo {

#define GEN_PASS_DECL_HOISTROTATIONS
#include "lib/Dialect/Lattigo/Transforms/Passes.h.inc"

std::unique_ptr<Pass> createLattigoHoistRotations();

} // namespace lattigo
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_LATTIGO_TRANSFORMS_HOISTROTATIONS_H_
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "lib/Dialect/Lattigo/IR/LattigoDialect.h"
#include "lib/Dialect/Lattigo/Transforms/AllocToInplace.h"
#include "lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.h"
#include "lib/Dialect/Lattigo/Transforms/HoistRotations.h"

namespace mlir {
namespace heir {
Expand Down
10 changes: 10 additions & 0 deletions lib/Dialect/Lattigo/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,14 @@ def ConfigureCryptoContext : Pass<"lattigo-configure-crypto-context"> {
];
}

def LattigoHoistRotations : Pass<"lattigo-hoist-rotations", "::mlir::func::FuncOp"> {
let summary = "Hoist multiple rotations of the same ciphertext into a single operation.";
let description = [{
This pass identifies when a ciphertext is rotated by multiple different
shifts, and replaces the `Rotate` ops with `RotateHoistedNew`
followed by `LookupRotated`.
}];
let constructor = "mlir::heir::lattigo::createLattigoHoistRotations()";
}

#endif // LIB_DIALECT_LATTIGO_TRANSFORMS_PASSES_TD_
22 changes: 22 additions & 0 deletions tests/Dialect/Lattigo/Transforms/hoist_rotations.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

// RUN: heir-opt --lattigo-hoist-rotations %s | FileCheck %s

!evaluator = !lattigo.ckks.evaluator
!ct = !lattigo.rlwe.ciphertext

module {
func.func @simple_sum(%evaluator: !evaluator, %ct: !ct) -> !ct {
// CHECK: lattigo.ckks.rotate_hoisted_new
// CHECK-COUNT-4: lattigo.rlwe.lookup_rotated
// CHECK-NOT: lattigo.ckks.rotate_new
%ct_0 = lattigo.ckks.rotate_new %evaluator, %ct {offset = 16} : (!evaluator, !ct) -> !ct
%ct_1 = lattigo.ckks.add_new %evaluator, %ct, %ct_0 : (!evaluator, !ct, !ct) -> !ct
%ct_2 = lattigo.ckks.rotate_new %evaluator, %ct {offset = 8} : (!evaluator, !ct) -> !ct
%ct_3 = lattigo.ckks.add_new %evaluator, %ct_1, %ct_2 : (!evaluator, !ct, !ct) -> !ct
%ct_4 = lattigo.ckks.rotate_new %evaluator, %ct {offset = 5} : (!evaluator, !ct) -> !ct
%ct_5 = lattigo.ckks.add_new %evaluator, %ct_3, %ct_4 : (!evaluator, !ct, !ct) -> !ct
%ct_6 = lattigo.ckks.rotate_new %evaluator, %ct {offset = 12} : (!evaluator, !ct) -> !ct
%ct_7 = lattigo.ckks.add_new %evaluator, %ct_5, %ct_6 : (!evaluator, !ct, !ct) -> !ct
return %ct_7 : !ct
}
}
Loading