Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ struct ConvertRlweOpBootstrap : public OpConversionPattern<BootstrapOp> {
op, LattigoBootstrapOp::create(
rewriter, op.getLoc(),
this->typeConverter->convertType(op.getOutput().getType()),
evaluator, adaptor.getInput()));
evaluator, adaptor.getInput(), adaptor.getInput()));
return success();
}
};
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Lattigo/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:SideEffectInterfaces",
],
)

Expand Down Expand Up @@ -94,6 +95,7 @@ cc_library(
"@heir//lib/Utils/Tablegen:InPlaceOpInterface",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
],
)
Expand Down
48 changes: 30 additions & 18 deletions lib/Dialect/Lattigo/IR/LattigoBGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include "LattigoDialect.td"
include "LattigoTypes.td"
include "lib/Dialect/HEIRInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

class Lattigo_BGVOp<string mnemonic, list<Trait> traits = []> :
Lattigo_Op<"bgv." # mnemonic, traits> {
Expand All @@ -13,15 +14,18 @@ class Lattigo_BGVOp<string mnemonic, list<Trait> traits = []> :
// This operation cannot be marked as Pure because it represents a buffer allocation.
// If marked as Pure, CSE will deduplicate them, causing multiple in-place operations
// (like encode) to share the same buffer and overwrite each other.
def Lattigo_BGVNewPlaintextOp : Lattigo_BGVOp<"new_plaintext"> {
def Lattigo_BGVNewPlaintextOp : Lattigo_BGVOp<"new_plaintext", [
MemoryEffects<[MemAllocAt<0, FullEffect>]>
]> {
let summary = "Create a new plaintext in the Lattigo BGV dialect";
let description = [{
This operation creates a new plaintext value in the Lattigo BGV dialect.
}];
let arguments = (ins
Lattigo_BGVParameter:$params
);
let results = (outs Lattigo_RLWEPlaintext:$plaintext);
let results = (outs Res<Lattigo_RLWEPlaintext, "",
[MemAllocAt<0, FullEffect>]>:$plaintext);
}

def Lattigo_BGVNewParametersFromLiteralOp : Lattigo_BGVOp<"new_parameters_from_literal", [Pure]> {
Expand All @@ -46,7 +50,11 @@ def Lattigo_BGVNewEncoderOp : Lattigo_BGVOp<"new_encoder", [Pure]> {
let results = (outs Lattigo_BGVEncoder:$encoder);
}

def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [InPlaceOpInterface, PlaintextEncodeOpInterface]> {
def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [
InPlaceOpInterface,
PlaintextEncodeOpInterface,
DefaultMemWriteTrait,
]> {
let summary = "Encode a plaintext value in the Lattigo BGV dialect";
let description = [{
This operation encodes a plaintext value using the specified encoder in the Lattigo BGV dialect.
Expand All @@ -58,15 +66,14 @@ def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [InPlaceOpInterface, Plaintext
let arguments = (ins
Lattigo_BGVEncoder:$encoder,
TensorOrMemRef<[AnyInteger]>:$value,
Lattigo_RLWEPlaintext:$plaintext,
Arg<Lattigo_RLWEPlaintext, "", [FullMemWrite]>:$plaintext,
DefaultValuedAttr<I64Attr, "1">:$scale
);
let results = (outs Lattigo_RLWEPlaintext:$encoded);

let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }";
}

def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "decoded"]>]> {
def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "decoded"]>, DefaultMemReadTrait]> {
let summary = "Decode a plaintext value in the Lattigo BGV dialect";
let description = [{
This operation decodes a plaintext value using the specified encoder in the Lattigo BGV dialect.
Expand All @@ -75,7 +82,7 @@ def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "deco
}];
let arguments = (ins
Lattigo_BGVEncoder:$encoder,
Lattigo_RLWEPlaintext:$plaintext,
Arg<Lattigo_RLWEPlaintext, "", [FullMemRead]>:$plaintext,
TensorOrMemRef<[AnyInteger]>:$value
// Rely on Lattigo internal tracking of scale for decoding.
// DefaultValuedAttr<I64Attr, "1">:$scale
Expand Down Expand Up @@ -110,11 +117,11 @@ def Lattigo_BGVNewEvaluatorOp : Lattigo_BGVOp<"new_evaluator", [Pure]> {
// ciphertext arithmetic op

class Lattigo_BGVBinaryOp<string mnemonic, list<Trait> traits = []> :
Lattigo_BGVOp<mnemonic, traits # [Pure]> {
Lattigo_BGVOp<mnemonic, traits # [Pure, DefaultMemReadTrait]> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertextOrPlaintext:$rhs
Arg<Lattigo_RLWECiphertextOrPlaintext, "", [FullMemRead]>:$rhs
);
let results = (outs Lattigo_RLWECiphertext:$output);
}
Expand All @@ -141,17 +148,19 @@ def Lattigo_BGVMulNewOp : Lattigo_BGVBinaryOp<"mul_new", [IncreasesMulDepthOpInt
}

class Lattigo_BGVBinaryInPlaceOp<string mnemonic, list<Trait> traits = []> :
Lattigo_BGVOp<mnemonic, traits # [InPlaceOpInterface]> {
Lattigo_BGVOp<mnemonic, traits # [
InPlaceOpInterface,
DefaultMemWriteTrait,
]> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertextOrPlaintext:$rhs,
Arg<Lattigo_RLWECiphertextOrPlaintext, "lattigo ciphertext or plaintext", [FullMemRead]>:$rhs,
// Lattigo API is like bgv.Add(lhs, rhs, out) but for MLIR we need to
// satisfy the SSA form, so we still have a separate output.
Lattigo_RLWECiphertext:$inplace
Arg<Lattigo_RLWECiphertextOrPlaintext, "lattigo ciphertext or plaintext", [FullMemWrite]>:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInPlaceOperandIndex() { return 3; }";
}

Expand Down Expand Up @@ -241,12 +250,15 @@ def Lattigo_BGVRotateRowsNewOp : Lattigo_BGVUnaryOp<"rotate_rows_new"> {
}

class Lattigo_BGVUnaryInPlaceOp<string mnemonic, list<Trait> traits = []> :
Lattigo_BGVOp<mnemonic, traits # [InPlaceOpInterface]> {
Lattigo_BGVOp<mnemonic, traits # [
InPlaceOpInterface,
DefaultMemWriteTrait,
]> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
// see BinaryInPlaceOp above
Lattigo_RLWECiphertext:$inplace
Arg<Lattigo_RLWECiphertext, "", [FullMemWrite]>:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

Expand Down Expand Up @@ -276,7 +288,7 @@ def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale", [
}

def Lattigo_BGVRotateColumnsOp : Lattigo_BGVUnaryInPlaceOp<"rotate_columns", [
DeclareOpInterfaceMethods<RotationOpInterface>
DeclareOpInterfaceMethods<RotationOpInterface>, DefaultMemWriteTrait,
]> {
let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect";
let description = [{
Expand All @@ -291,8 +303,8 @@ def Lattigo_BGVRotateColumnsOp : Lattigo_BGVUnaryInPlaceOp<"rotate_columns", [
}];
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
Lattigo_RLWECiphertext:$inplace,
Arg<Lattigo_RLWECiphertext, "", [FullMemRead]>:$input,
Arg<Lattigo_RLWECiphertext, "", [FullMemWrite]>:$inplace,
Optional<AnySignlessIntegerOrIndex>:$dynamic_shift,
OptionalAttr<Builtin_IntegerAttr>:$static_shift
);
Expand Down
58 changes: 38 additions & 20 deletions lib/Dialect/Lattigo/IR/LattigoCKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ class Lattigo_CKKSOp<string mnemonic, list<Trait> traits = []> :
// This operation cannot be marked as Pure because it represents a buffer allocation.
// If marked as Pure, CSE will deduplicate them, causing multiple in-place operations
// (like encode) to share the same buffer and overwrite each other.
def Lattigo_CKKSNewPlaintextOp : Lattigo_CKKSOp<"new_plaintext"> {
def Lattigo_CKKSNewPlaintextOp : Lattigo_CKKSOp<"new_plaintext", [
MemoryEffects<[MemAllocAt<0, FullEffect>]>
]> {
let summary = "Create a new plaintext in the Lattigo CKKS dialect";
let description = [{
This operation creates a new plaintext value in the Lattigo CKKS dialect.
}];
let arguments = (ins
Lattigo_CKKSParameter:$params
);
let results = (outs Lattigo_RLWEPlaintext:$plaintext);
let results = (outs Res<Lattigo_RLWEPlaintext, "",
[MemAllocAt<0, FullEffect>]>:$plaintext);
}

def Lattigo_CKKSNewParametersFromLiteralOp : Lattigo_CKKSOp<"new_parameters_from_literal", [Pure]> {
Expand All @@ -48,7 +51,11 @@ def Lattigo_CKKSNewEncoderOp : Lattigo_CKKSOp<"new_encoder", [Pure]> {
let results = (outs Lattigo_CKKSEncoder:$encoder);
}

def Lattigo_CKKSEncodeOp : Lattigo_CKKSOp<"encode", [InPlaceOpInterface, PlaintextEncodeOpInterface]> {
def Lattigo_CKKSEncodeOp : Lattigo_CKKSOp<"encode", [
InPlaceOpInterface,
PlaintextEncodeOpInterface,
DefaultMemWriteTrait,
]> {
let summary = "Encode a plaintext value in the Lattigo CKKS dialect";
let description = [{
This operation encodes a plaintext value using the specified encoder in the Lattigo CKKS dialect.
Expand All @@ -60,15 +67,15 @@ def Lattigo_CKKSEncodeOp : Lattigo_CKKSOp<"encode", [InPlaceOpInterface, Plainte
let arguments = (ins
Lattigo_CKKSEncoder:$encoder,
TensorOrMemRef<[AnyFloat, AnyComplex]>:$value,
Lattigo_RLWEPlaintext:$plaintext,
Arg<Lattigo_RLWEPlaintext, "", [FullMemWrite]>:$plaintext,
DefaultValuedAttr<I64Attr, "1">:$scale
);
let results = (outs Lattigo_RLWEPlaintext:$encoded);

let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }";
}

def Lattigo_CKKSDecodeOp : Lattigo_CKKSOp<"decode", [AllTypesMatch<["value", "decoded"]>]> {
def Lattigo_CKKSDecodeOp : Lattigo_CKKSOp<"decode", [AllTypesMatch<["value", "decoded"]>, DefaultMemReadTrait]> {
let summary = "Decode a plaintext value in the Lattigo CKKS dialect";
let description = [{
This operation decodes a plaintext value using the specified encoder in the Lattigo CKKS dialect.
Expand All @@ -77,7 +84,7 @@ def Lattigo_CKKSDecodeOp : Lattigo_CKKSOp<"decode", [AllTypesMatch<["value", "de
}];
let arguments = (ins
Lattigo_CKKSEncoder:$encoder,
Lattigo_RLWEPlaintext:$plaintext,
Arg<Lattigo_RLWEPlaintext, "", [FullMemRead]>:$plaintext,
TensorOrMemRef<[AnyFloat, AnyComplex]>:$value
);
// although bgv.Decode is also an inplace operation as bgv.Encode, as there are post-processing
Expand Down Expand Up @@ -144,11 +151,11 @@ def Lattigo_CKKSNewBootstrappingEvaluatorOp : Lattigo_CKKSOp<"new_bootstrapping_
// ciphertext arithmetic op

class Lattigo_CKKSBinaryOp<string mnemonic, list<Trait> traits = []> :
Lattigo_CKKSOp<mnemonic, traits # [Pure]> {
Lattigo_CKKSOp<mnemonic, traits # [Pure, DefaultMemReadTrait]> {
let arguments = (ins
Lattigo_CKKSEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertextOrPlaintext:$rhs
Arg<Lattigo_RLWECiphertextOrPlaintext, "", [FullMemRead]>:$rhs
);
let results = (outs Lattigo_RLWECiphertext:$output);
}
Expand All @@ -175,17 +182,19 @@ def Lattigo_CKKSMulNewOp : Lattigo_CKKSBinaryOp<"mul_new", [IncreasesMulDepthOpI
}

class Lattigo_CKKSBinaryInPlaceOp<string mnemonic, list<Trait> traits = []> :
Lattigo_CKKSOp<mnemonic, traits # [InPlaceOpInterface]> {
Lattigo_CKKSOp<mnemonic, traits # [
InPlaceOpInterface,
DefaultMemWriteTrait,
]> {
let arguments = (ins
Lattigo_CKKSEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertextOrPlaintext:$rhs,
Arg<Lattigo_RLWECiphertextOrPlaintext, "lattigo ciphertext or plaintext", [FullMemRead]>:$rhs,
// Lattigo API is like bgv.Add(lhs, rhs, out) but for MLIR we need to
// satisfy the SSA form, so we still have a separate output.
Lattigo_RLWECiphertext:$inplace
Arg<Lattigo_RLWECiphertextOrPlaintext, "lattigo ciphertext or plaintext", [FullMemWrite]>:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInPlaceOperandIndex() { return 3; }";
}

Expand Down Expand Up @@ -269,15 +278,17 @@ def Lattigo_CKKSRotateNewOp : Lattigo_CKKSOp<"rotate_new", [
}

class Lattigo_CKKSUnaryInPlaceOp<string mnemonic, list<Trait> traits = []> :
Lattigo_CKKSOp<mnemonic, traits # [InPlaceOpInterface]> {
Lattigo_CKKSOp<mnemonic, traits # [
InPlaceOpInterface,
DefaultMemWriteTrait,
]> {
let arguments = (ins
Lattigo_CKKSEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
// see BinaryInPlaceOp above
Lattigo_RLWECiphertext:$inplace
Arg<Lattigo_RLWECiphertext, "", [FullMemWrite]>:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }";
}

Expand All @@ -304,7 +315,8 @@ def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale", [
}

def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [
DeclareOpInterfaceMethods<RotationOpInterface>
DeclareOpInterfaceMethods<RotationOpInterface>,
DefaultMemWriteTrait,
]> {
let summary = "Rotate slots of a ciphertext in the Lattigo CKKS dialect";
let description = [{
Expand All @@ -321,17 +333,20 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [
}];
let arguments = (ins
Lattigo_CKKSEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
Arg<Lattigo_RLWECiphertext, "", [FullMemRead]>:$input,
// see BinaryInPlaceOp above
Lattigo_RLWECiphertext:$inplace,
Arg<Lattigo_RLWECiphertext, "", [FullMemWrite]>:$inplace,
Optional<AnySignlessIntegerOrIndex>:$dynamic_shift,
OptionalAttr<Builtin_IntegerAttr>:$static_shift
);
let results = (outs Lattigo_RLWECiphertext:$output);
let hasVerifier = 1;
}

def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> {
def Lattigo_CKKSBootstrapOp : Lattigo_CKKSOp<"bootstrap", [
InPlaceOpInterface,
DefaultMemWriteTrait,
]> {
let summary = "Bootstrap a ciphertext in the Lattigo CKKS dialect";
let description = [{
Bootstraps a ciphertext value in the Lattigo CKKS dialect.
Expand All @@ -345,9 +360,12 @@ def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> {

let arguments = (ins
Lattigo_CKKSBootstrappingEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input
Lattigo_RLWECiphertext:$input,
Arg<Lattigo_RLWECiphertext, "", [FullMemWrite]>:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }";
}

def Lattigo_CKKSLinearTransformOp : Lattigo_CKKSOp<"linear_transform", [
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/IR/LattigoOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "lib/Utils/Tablegen/InPlaceOpInterface.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
// IWYU pragma: end_keep

#define GET_OP_CLASSES
Expand Down
5 changes: 5 additions & 0 deletions lib/Dialect/Lattigo/IR/LattigoOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class Lattigo_Op<string mnemonic, list<Trait> traits = []> :
}];
}

def FullMemRead : MemReadAt<0, FullEffect>;
def FullMemWrite : MemWriteAt<0, FullEffect>;
def DefaultMemReadTrait : MemoryEffects<[FullMemRead]>;
def DefaultMemWriteTrait : MemoryEffects<[FullMemWrite]>;

include "LattigoBGVOps.td"
include "LattigoCKKSOps.td"
include "LattigoRLWEOps.td"
Expand Down
Loading
Loading