diff --git a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp index 088d7afd06..593329be4a 100644 --- a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp @@ -483,7 +483,7 @@ struct ConvertRlweOpBootstrap : public OpConversionPattern { op, LattigoBootstrapOp::create( rewriter, op.getLoc(), this->typeConverter->convertType(op.getOutput().getType()), - evaluator, adaptor.getInput())); + evaluator, adaptor.getInput(), adaptor.getInput())); return success(); } }; diff --git a/lib/Dialect/Lattigo/IR/BUILD b/lib/Dialect/Lattigo/IR/BUILD index 890890cbd7..5a1cfa2fdd 100644 --- a/lib/Dialect/Lattigo/IR/BUILD +++ b/lib/Dialect/Lattigo/IR/BUILD @@ -33,6 +33,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SideEffectInterfaces", ], ) @@ -94,6 +95,7 @@ cc_library( "@heir//lib/Utils/Tablegen:InPlaceOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], ) diff --git a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td index 7488eb37a7..673157fb7a 100644 --- a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td @@ -5,6 +5,7 @@ include "LattigoDialect.td" include "LattigoTypes.td" include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" class Lattigo_BGVOp traits = []> : Lattigo_Op<"bgv." # mnemonic, traits> { @@ -13,7 +14,9 @@ class Lattigo_BGVOp traits = []> : // This operation cannot be marked as Pure because it represents a buffer allocation. // If marked as Pure, CSE will deduplicate them, causing multiple in-place operations // (like encode) to share the same buffer and overwrite each other. -def Lattigo_BGVNewPlaintextOp : Lattigo_BGVOp<"new_plaintext"> { +def Lattigo_BGVNewPlaintextOp : Lattigo_BGVOp<"new_plaintext", [ + MemoryEffects<[MemAllocAt<0, FullEffect>]> +]> { let summary = "Create a new plaintext in the Lattigo BGV dialect"; let description = [{ This operation creates a new plaintext value in the Lattigo BGV dialect. @@ -21,7 +24,8 @@ def Lattigo_BGVNewPlaintextOp : Lattigo_BGVOp<"new_plaintext"> { let arguments = (ins Lattigo_BGVParameter:$params ); - let results = (outs Lattigo_RLWEPlaintext:$plaintext); + let results = (outs Res]>:$plaintext); } def Lattigo_BGVNewParametersFromLiteralOp : Lattigo_BGVOp<"new_parameters_from_literal", [Pure]> { @@ -46,7 +50,11 @@ def Lattigo_BGVNewEncoderOp : Lattigo_BGVOp<"new_encoder", [Pure]> { let results = (outs Lattigo_BGVEncoder:$encoder); } -def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [InPlaceOpInterface, PlaintextEncodeOpInterface]> { +def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [ + InPlaceOpInterface, + PlaintextEncodeOpInterface, + DefaultMemWriteTrait, +]> { let summary = "Encode a plaintext value in the Lattigo BGV dialect"; let description = [{ This operation encodes a plaintext value using the specified encoder in the Lattigo BGV dialect. @@ -58,15 +66,14 @@ def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [InPlaceOpInterface, Plaintext let arguments = (ins Lattigo_BGVEncoder:$encoder, TensorOrMemRef<[AnyInteger]>:$value, - Lattigo_RLWEPlaintext:$plaintext, + Arg:$plaintext, DefaultValuedAttr:$scale ); let results = (outs Lattigo_RLWEPlaintext:$encoded); - let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }"; } -def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "decoded"]>]> { +def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "decoded"]>, DefaultMemReadTrait]> { let summary = "Decode a plaintext value in the Lattigo BGV dialect"; let description = [{ This operation decodes a plaintext value using the specified encoder in the Lattigo BGV dialect. @@ -75,7 +82,7 @@ def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "deco }]; let arguments = (ins Lattigo_BGVEncoder:$encoder, - Lattigo_RLWEPlaintext:$plaintext, + Arg:$plaintext, TensorOrMemRef<[AnyInteger]>:$value // Rely on Lattigo internal tracking of scale for decoding. // DefaultValuedAttr:$scale @@ -110,11 +117,11 @@ def Lattigo_BGVNewEvaluatorOp : Lattigo_BGVOp<"new_evaluator", [Pure]> { // ciphertext arithmetic op class Lattigo_BGVBinaryOp traits = []> : - Lattigo_BGVOp { + Lattigo_BGVOp { let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, - Lattigo_RLWECiphertextOrPlaintext:$rhs + Arg:$rhs ); let results = (outs Lattigo_RLWECiphertext:$output); } @@ -141,17 +148,19 @@ def Lattigo_BGVMulNewOp : Lattigo_BGVBinaryOp<"mul_new", [IncreasesMulDepthOpInt } class Lattigo_BGVBinaryInPlaceOp traits = []> : - Lattigo_BGVOp { + Lattigo_BGVOp { let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, - Lattigo_RLWECiphertextOrPlaintext:$rhs, + Arg:$rhs, // Lattigo API is like bgv.Add(lhs, rhs, out) but for MLIR we need to // satisfy the SSA form, so we still have a separate output. - Lattigo_RLWECiphertext:$inplace + Arg:$inplace ); let results = (outs Lattigo_RLWECiphertext:$output); - let extraClassDeclaration = "int getInPlaceOperandIndex() { return 3; }"; } @@ -241,12 +250,15 @@ def Lattigo_BGVRotateRowsNewOp : Lattigo_BGVUnaryOp<"rotate_rows_new"> { } class Lattigo_BGVUnaryInPlaceOp traits = []> : - Lattigo_BGVOp { + Lattigo_BGVOp { let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$input, // see BinaryInPlaceOp above - Lattigo_RLWECiphertext:$inplace + Arg:$inplace ); let results = (outs Lattigo_RLWECiphertext:$output); @@ -276,7 +288,7 @@ def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale", [ } def Lattigo_BGVRotateColumnsOp : Lattigo_BGVUnaryInPlaceOp<"rotate_columns", [ - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, DefaultMemWriteTrait, ]> { let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect"; let description = [{ @@ -291,8 +303,8 @@ def Lattigo_BGVRotateColumnsOp : Lattigo_BGVUnaryInPlaceOp<"rotate_columns", [ }]; let arguments = (ins Lattigo_BGVEvaluator:$evaluator, - Lattigo_RLWECiphertext:$input, - Lattigo_RLWECiphertext:$inplace, + Arg:$input, + Arg:$inplace, Optional:$dynamic_shift, OptionalAttr:$static_shift ); diff --git a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td index c6c8ae7dc9..92cfd9e6f6 100644 --- a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td @@ -15,7 +15,9 @@ class Lattigo_CKKSOp traits = []> : // This operation cannot be marked as Pure because it represents a buffer allocation. // If marked as Pure, CSE will deduplicate them, causing multiple in-place operations // (like encode) to share the same buffer and overwrite each other. -def Lattigo_CKKSNewPlaintextOp : Lattigo_CKKSOp<"new_plaintext"> { +def Lattigo_CKKSNewPlaintextOp : Lattigo_CKKSOp<"new_plaintext", [ + MemoryEffects<[MemAllocAt<0, FullEffect>]> +]> { let summary = "Create a new plaintext in the Lattigo CKKS dialect"; let description = [{ This operation creates a new plaintext value in the Lattigo CKKS dialect. @@ -23,7 +25,8 @@ def Lattigo_CKKSNewPlaintextOp : Lattigo_CKKSOp<"new_plaintext"> { let arguments = (ins Lattigo_CKKSParameter:$params ); - let results = (outs Lattigo_RLWEPlaintext:$plaintext); + let results = (outs Res]>:$plaintext); } def Lattigo_CKKSNewParametersFromLiteralOp : Lattigo_CKKSOp<"new_parameters_from_literal", [Pure]> { @@ -48,7 +51,11 @@ def Lattigo_CKKSNewEncoderOp : Lattigo_CKKSOp<"new_encoder", [Pure]> { let results = (outs Lattigo_CKKSEncoder:$encoder); } -def Lattigo_CKKSEncodeOp : Lattigo_CKKSOp<"encode", [InPlaceOpInterface, PlaintextEncodeOpInterface]> { +def Lattigo_CKKSEncodeOp : Lattigo_CKKSOp<"encode", [ + InPlaceOpInterface, + PlaintextEncodeOpInterface, + DefaultMemWriteTrait, +]> { let summary = "Encode a plaintext value in the Lattigo CKKS dialect"; let description = [{ This operation encodes a plaintext value using the specified encoder in the Lattigo CKKS dialect. @@ -60,7 +67,7 @@ def Lattigo_CKKSEncodeOp : Lattigo_CKKSOp<"encode", [InPlaceOpInterface, Plainte let arguments = (ins Lattigo_CKKSEncoder:$encoder, TensorOrMemRef<[AnyFloat, AnyComplex]>:$value, - Lattigo_RLWEPlaintext:$plaintext, + Arg:$plaintext, DefaultValuedAttr:$scale ); let results = (outs Lattigo_RLWEPlaintext:$encoded); @@ -68,7 +75,7 @@ def Lattigo_CKKSEncodeOp : Lattigo_CKKSOp<"encode", [InPlaceOpInterface, Plainte let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }"; } -def Lattigo_CKKSDecodeOp : Lattigo_CKKSOp<"decode", [AllTypesMatch<["value", "decoded"]>]> { +def Lattigo_CKKSDecodeOp : Lattigo_CKKSOp<"decode", [AllTypesMatch<["value", "decoded"]>, DefaultMemReadTrait]> { let summary = "Decode a plaintext value in the Lattigo CKKS dialect"; let description = [{ This operation decodes a plaintext value using the specified encoder in the Lattigo CKKS dialect. @@ -77,7 +84,7 @@ def Lattigo_CKKSDecodeOp : Lattigo_CKKSOp<"decode", [AllTypesMatch<["value", "de }]; let arguments = (ins Lattigo_CKKSEncoder:$encoder, - Lattigo_RLWEPlaintext:$plaintext, + Arg:$plaintext, TensorOrMemRef<[AnyFloat, AnyComplex]>:$value ); // although bgv.Decode is also an inplace operation as bgv.Encode, as there are post-processing @@ -144,11 +151,11 @@ def Lattigo_CKKSNewBootstrappingEvaluatorOp : Lattigo_CKKSOp<"new_bootstrapping_ // ciphertext arithmetic op class Lattigo_CKKSBinaryOp traits = []> : - Lattigo_CKKSOp { + Lattigo_CKKSOp { let arguments = (ins Lattigo_CKKSEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, - Lattigo_RLWECiphertextOrPlaintext:$rhs + Arg:$rhs ); let results = (outs Lattigo_RLWECiphertext:$output); } @@ -175,17 +182,19 @@ def Lattigo_CKKSMulNewOp : Lattigo_CKKSBinaryOp<"mul_new", [IncreasesMulDepthOpI } class Lattigo_CKKSBinaryInPlaceOp traits = []> : - Lattigo_CKKSOp { + Lattigo_CKKSOp { let arguments = (ins Lattigo_CKKSEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, - Lattigo_RLWECiphertextOrPlaintext:$rhs, + Arg:$rhs, // Lattigo API is like bgv.Add(lhs, rhs, out) but for MLIR we need to // satisfy the SSA form, so we still have a separate output. - Lattigo_RLWECiphertext:$inplace + Arg:$inplace ); let results = (outs Lattigo_RLWECiphertext:$output); - let extraClassDeclaration = "int getInPlaceOperandIndex() { return 3; }"; } @@ -269,15 +278,17 @@ def Lattigo_CKKSRotateNewOp : Lattigo_CKKSOp<"rotate_new", [ } class Lattigo_CKKSUnaryInPlaceOp traits = []> : - Lattigo_CKKSOp { + Lattigo_CKKSOp { let arguments = (ins Lattigo_CKKSEvaluator:$evaluator, Lattigo_RLWECiphertext:$input, // see BinaryInPlaceOp above - Lattigo_RLWECiphertext:$inplace + Arg:$inplace ); let results = (outs Lattigo_RLWECiphertext:$output); - let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }"; } @@ -304,7 +315,8 @@ def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale", [ } def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [ - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods, + DefaultMemWriteTrait, ]> { let summary = "Rotate slots of a ciphertext in the Lattigo CKKS dialect"; let description = [{ @@ -321,9 +333,9 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [ }]; let arguments = (ins Lattigo_CKKSEvaluator:$evaluator, - Lattigo_RLWECiphertext:$input, + Arg:$input, // see BinaryInPlaceOp above - Lattigo_RLWECiphertext:$inplace, + Arg:$inplace, Optional:$dynamic_shift, OptionalAttr:$static_shift ); @@ -331,7 +343,10 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [ let hasVerifier = 1; } -def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> { +def Lattigo_CKKSBootstrapOp : Lattigo_CKKSOp<"bootstrap", [ + InPlaceOpInterface, + DefaultMemWriteTrait, +]> { let summary = "Bootstrap a ciphertext in the Lattigo CKKS dialect"; let description = [{ Bootstraps a ciphertext value in the Lattigo CKKS dialect. @@ -345,9 +360,12 @@ def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> { let arguments = (ins Lattigo_CKKSBootstrappingEvaluator:$evaluator, - Lattigo_RLWECiphertext:$input + Lattigo_RLWECiphertext:$input, + Arg:$inplace ); let results = (outs Lattigo_RLWECiphertext:$output); + + let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }"; } def Lattigo_CKKSLinearTransformOp : Lattigo_CKKSOp<"linear_transform", [ diff --git a/lib/Dialect/Lattigo/IR/LattigoOps.h b/lib/Dialect/Lattigo/IR/LattigoOps.h index 3897ffe5ee..3751a07893 100644 --- a/lib/Dialect/Lattigo/IR/LattigoOps.h +++ b/lib/Dialect/Lattigo/IR/LattigoOps.h @@ -8,6 +8,7 @@ #include "lib/Utils/Tablegen/InPlaceOpInterface.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project // IWYU pragma: end_keep #define GET_OP_CLASSES diff --git a/lib/Dialect/Lattigo/IR/LattigoOps.td b/lib/Dialect/Lattigo/IR/LattigoOps.td index cbd5c95ee8..d557fc9011 100644 --- a/lib/Dialect/Lattigo/IR/LattigoOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoOps.td @@ -15,6 +15,11 @@ class Lattigo_Op traits = []> : }]; } +def FullMemRead : MemReadAt<0, FullEffect>; +def FullMemWrite : MemWriteAt<0, FullEffect>; +def DefaultMemReadTrait : MemoryEffects<[FullMemRead]>; +def DefaultMemWriteTrait : MemoryEffects<[FullMemWrite]>; + include "LattigoBGVOps.td" include "LattigoCKKSOps.td" include "LattigoRLWEOps.td" diff --git a/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td b/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td index 155327329f..602c25566d 100644 --- a/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoRLWEOps.td @@ -133,8 +133,11 @@ def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new", let results = (outs Lattigo_RLWECiphertext:$output); } -def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", - [InPlaceOpInterface, DeclareOpInterfaceMethods]> { +def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [ + InPlaceOpInterface, + DeclareOpInterfaceMethods, + DefaultMemWriteTrait, +]> { let summary = "Drop level of a ciphertext"; let description = [{ This operation drops the level of a ciphertext @@ -144,12 +147,11 @@ def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", }]; let arguments = (ins Lattigo_RLWEEvaluator:$evaluator, - Lattigo_RLWECiphertext:$input, - Lattigo_RLWECiphertext:$inplace, + Arg:$input, + Arg:$inplace, DefaultValuedAttr:$levelToDrop ); let results = (outs Lattigo_RLWECiphertext:$output); - let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }"; } @@ -162,7 +164,10 @@ def Lattigo_RLWENegateNewOp : Lattigo_RLWEOp<"negate_new", [Pure]> { let results = (outs Lattigo_RLWECiphertext:$output); } -def Lattigo_RLWENegateOp : Lattigo_RLWEOp<"negate", [InPlaceOpInterface]> { +def Lattigo_RLWENegateOp : Lattigo_RLWEOp<"negate", [ + InPlaceOpInterface, + DefaultMemWriteTrait, +]> { let summary = "Negate of a ciphertext"; let description = [{ This operation negates a ciphertext @@ -172,11 +177,10 @@ def Lattigo_RLWENegateOp : Lattigo_RLWEOp<"negate", [InPlaceOpInterface]> { }]; let arguments = (ins Lattigo_RLWEEvaluator:$evaluator, - Lattigo_RLWECiphertext:$input, - Lattigo_RLWECiphertext:$inplace + Arg:$input, + Arg:$inplace ); let results = (outs Lattigo_RLWECiphertext:$output); - let extraClassDeclaration = "int getInPlaceOperandIndex() { return 2; }"; } diff --git a/tests/Dialect/Lattigo/IR/memory_effects.mlir b/tests/Dialect/Lattigo/IR/memory_effects.mlir new file mode 100644 index 0000000000..64ae94183a --- /dev/null +++ b/tests/Dialect/Lattigo/IR/memory_effects.mlir @@ -0,0 +1,51 @@ +// RUN: heir-opt --cse %s | FileCheck %s + +!pk = !lattigo.rlwe.public_key +!eval_key_set = !lattigo.rlwe.evaluation_key_set +!ct = !lattigo.rlwe.ciphertext +!pt = !lattigo.rlwe.plaintext + +!bgv_evaluator = !lattigo.bgv.evaluator +!bgv_params = !lattigo.bgv.parameter + +!ckks_evaluator = !lattigo.ckks.bootstrapping_evaluator + +#paramsLiteral = #lattigo.bgv.parameters_literal< + logN = 14, + logQ = [56, 55, 55], + logP = [55], + plaintextModulus = 0x3ee0001 +> + +module { + // CHECK-LABEL: @test_bgv_new_plaintext_cse + func.func @test_bgv_new_plaintext_cse() { + %params = lattigo.bgv.new_parameters_from_literal {paramsLiteral = #paramsLiteral} : () -> !bgv_params + // Allocations should not be deduplicated. + // CHECK: lattigo.bgv.new_plaintext + // CHECK: lattigo.bgv.new_plaintext + %pt1 = lattigo.bgv.new_plaintext %params : (!bgv_params) -> !pt + %pt2 = lattigo.bgv.new_plaintext %params : (!bgv_params) -> !pt + return + } + + // CHECK-LABEL: @test_bgv_add_cse + func.func @test_bgv_add_cse(%evaluator: !bgv_evaluator, %lhs: !ct, %rhs: !ct) { + // In-place operations should not be deduplicated because they mutate their inplace operand. + // CHECK: lattigo.bgv.add + // CHECK: lattigo.bgv.add + %output1 = lattigo.bgv.add %evaluator, %lhs, %rhs, %lhs : (!bgv_evaluator, !ct, !ct, !ct) -> !ct + %output2 = lattigo.bgv.add %evaluator, %lhs, %rhs, %lhs : (!bgv_evaluator, !ct, !ct, !ct) -> !ct + return + } + + // CHECK-LABEL: @test_ckks_bootstrap_cse + func.func @test_ckks_bootstrap_cse(%evaluator: !ckks_evaluator, %ct: !ct) { + // In-place operations should not be deduplicated. + // CHECK: lattigo.ckks.bootstrap + // CHECK: lattigo.ckks.bootstrap + %output1 = lattigo.ckks.bootstrap %evaluator, %ct, %ct : (!ckks_evaluator, !ct, !ct) -> !ct + %output2 = lattigo.ckks.bootstrap %evaluator, %ct, %ct : (!ckks_evaluator, !ct, !ct) -> !ct + return + } +}