From ac61a888f534d657068d3160d70588e7ff729a67 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 3 Nov 2025 10:04:56 -0800 Subject: [PATCH] start porting HEIR to high-precision scale --- lib/Analysis/ScaleAnalysis/ScaleAnalysis.cpp | 166 +++++++++++------- lib/Analysis/ScaleAnalysis/ScaleAnalysis.h | 70 +++++--- .../Conversions/LWEToLattigo/LWEToLattigo.cpp | 7 +- lib/Dialect/LWE/IR/LWEAttributes.cpp | 63 ++++--- lib/Dialect/LWE/IR/LWEAttributes.h | 4 +- lib/Dialect/LWE/IR/LWEAttributes.td | 2 +- lib/Dialect/LWE/IR/LWETypes.cpp | 6 +- lib/Dialect/Mgmt/IR/MgmtAttributes.cpp | 6 +- lib/Dialect/Mgmt/IR/MgmtAttributes.h | 1 + lib/Dialect/Mgmt/IR/MgmtAttributes.td | 21 ++- lib/Dialect/Mgmt/Transforms/AnnotateMgmt.cpp | 14 +- .../Conversions/SecretToBGV/SecretToBGV.cpp | 2 + .../Conversions/SecretToCKKS/SecretToCKKS.cpp | 2 + lib/Target/Lattigo/LattigoEmitter.cpp | 6 +- .../PopulateScale/PopulateScaleCKKS.cpp | 24 ++- .../PopulateScale/PopulateScalePatterns.cpp | 15 +- lib/Utils/ContextAwareConversionUtils.h | 5 +- .../annotate_mgmt/client_helpers.mlir | 8 +- .../annotate_mgmt/dimension_backprop.mlir | 24 +-- .../generate_param_bfv/doctest.mlir | 8 +- .../generate_param_bgv/doctest.mlir | 8 +- .../generate_param_ckks/doctest.mlir | 8 +- .../optimize_relinearization/issue_1548.mlir | 12 +- .../populate_scale/bgv/doctest.mlir | 10 +- .../populate_scale/bgv/smoke_test.mlir | 32 ++-- .../populate_scale/ckks/doctest.mlir | 18 +- .../secret_insert_mgmt/bgv/init.mlir | 48 ++--- .../ckks/bootstrap_waterline.mlir | 36 ++-- .../secret_insert_mgmt/func_call.mlir | 12 +- .../secret_insert_mgmt_bfv/doctest.mlir | 6 +- .../secret_insert_mgmt_bgv/doctest.mlir | 10 +- .../validate_noise/validate_noise_fail.mlir | 36 ++-- .../validate_noise/validate_noise_pass.mlir | 32 ++-- .../validate_noise_preserve_user_param.mlir | 2 +- ...lidate_noise_preserve_user_param_fail.mlir | 24 +-- 35 files changed, 452 insertions(+), 296 deletions(-) diff --git a/lib/Analysis/ScaleAnalysis/ScaleAnalysis.cpp b/lib/Analysis/ScaleAnalysis/ScaleAnalysis.cpp index 653b0a2307..d3cd9811d0 100644 --- a/lib/Analysis/ScaleAnalysis/ScaleAnalysis.cpp +++ b/lib/Analysis/ScaleAnalysis/ScaleAnalysis.cpp @@ -1,6 +1,7 @@ #include "lib/Analysis/ScaleAnalysis/ScaleAnalysis.h" #include +#include #include #include @@ -14,6 +15,7 @@ #include "lib/Utils/APIntUtils.h" #include "lib/Utils/AttributeUtils.h" #include "lib/Utils/Utils.h" +#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project @@ -36,70 +38,110 @@ namespace heir { // ScaleModel //===----------------------------------------------------------------------===// -int64_t BGVScaleModel::evalMulScale(const bgv::LocalParam& param, int64_t lhs, - int64_t rhs) { +llvm::APInt BGVScaleModel::evalMulScale(const bgv::LocalParam& param, + const llvm::APInt& lhs, + const llvm::APInt& rhs) { const auto* schemeParam = param.getSchemeParam(); - auto t = schemeParam->getPlaintextModulus(); - return lhs * rhs % t; + auto t = llvm::APInt(64, schemeParam->getPlaintextModulus()); + // (lhs * rhs) % t + return (lhs * rhs).urem(t); } -int64_t BGVScaleModel::evalMulScaleBackward(const bgv::LocalParam& param, - int64_t result, int64_t lhs) { +llvm::APInt BGVScaleModel::evalMulScaleBackward(const bgv::LocalParam& param, + const llvm::APInt& result, + const llvm::APInt& lhs) { const auto* schemeParam = param.getSchemeParam(); - auto t = schemeParam->getPlaintextModulus(); - auto lhsInv = multiplicativeInverse(APInt(64, lhs), APInt(64, t)); - return result * lhsInv.getSExtValue() % t; + auto t = llvm::APInt(64, schemeParam->getPlaintextModulus()); + auto lhsInv = multiplicativeInverse(lhs, t); + // (result * lhsInv) % t + return (result * lhsInv).urem(t); } -int64_t BGVScaleModel::evalModReduceScale(const bgv::LocalParam& inputParam, - int64_t scale) { +llvm::APInt BGVScaleModel::evalModReduceScale(const bgv::LocalParam& inputParam, + const llvm::APInt& scale) { const auto* schemeParam = inputParam.getSchemeParam(); - auto t = schemeParam->getPlaintextModulus(); + auto t = llvm::APInt(64, schemeParam->getPlaintextModulus()); auto qi = schemeParam->getQi(); auto level = inputParam.getCurrentLevel(); - auto qInvT = multiplicativeInverse(APInt(64, qi[level] % t), APInt(64, t)); - return scale * qInvT.getSExtValue() % t; + auto qiModT = llvm::APInt(64, qi[level]).urem(t); + auto qInvT = multiplicativeInverse(qiModT, t); + // (scale * qInvT) % t + return (scale * qInvT).urem(t); } -int64_t BGVScaleModel::evalModReduceScaleBackward( - const bgv::LocalParam& inputParam, int64_t resultScale) { +llvm::APInt BGVScaleModel::evalModReduceScaleBackward( + const bgv::LocalParam& inputParam, const llvm::APInt& resultScale) { const auto* schemeParam = inputParam.getSchemeParam(); - auto t = schemeParam->getPlaintextModulus(); + auto t = llvm::APInt(64, schemeParam->getPlaintextModulus()); auto qi = schemeParam->getQi(); auto level = inputParam.getCurrentLevel(); - return resultScale * (qi[level] % t) % t; + auto qiModT = llvm::APInt(64, qi[level]).urem(t); + // (resultScale * qiModT) % t + return (resultScale * qiModT).urem(t); } -int64_t CKKSScaleModel::evalMulScale(const ckks::LocalParam& param, int64_t lhs, - int64_t rhs) { - // TODO(#1640): support high-precision scale management - return lhs + rhs; +llvm::APInt CKKSScaleModel::evalMulScale(const ckks::LocalParam& param, + const llvm::APInt& lhs, + const llvm::APInt& rhs) { + // High-precision scale management (#2364): multiply actual scales + // In CKKS, ct1 * ct2 with scales s1, s2 produces ciphertext with scale s1*s2 + return lhs * rhs; } -int64_t CKKSScaleModel::evalMulScaleBackward(const ckks::LocalParam& param, - int64_t result, int64_t lhs) { - // TODO(#1640): support high-precision scale management - return result - lhs; +llvm::APInt CKKSScaleModel::evalMulScaleBackward(const ckks::LocalParam& param, + const llvm::APInt& result, + const llvm::APInt& lhs) { + // High-precision scale management (#2364): divide actual scales + // If result = lhs * rhs, then rhs = result / lhs + // Handle uninitialized scales (zero) - can't infer from uninitialized values + if (lhs.isZero() || result.isZero()) { + return llvm::APInt(64, 0); + } + return result.udiv(lhs); } -int64_t CKKSScaleModel::evalModReduceScale(const ckks::LocalParam& inputParam, - int64_t scale) { +llvm::APInt CKKSScaleModel::evalModReduceScale( + const ckks::LocalParam& inputParam, const llvm::APInt& scale) { const auto* schemeParam = inputParam.getSchemeParam(); - // TODO(#1640): rescale using logqi instead of logDefaultScale - // auto logqi = schemeParam->getLogqi(); - // auto level = inputParam.getCurrentLevel(); - auto logDefaultScale = schemeParam->getLogDefaultScale(); - return scale - logDefaultScale; + // High-precision scale management (#2364): divide by actual qi value + // Rescaling divides the ciphertext modulus by qi[level] and the scale by + // qi[level] + auto qi = schemeParam->getQi(); + auto level = inputParam.getCurrentLevel(); + + // If qi is not populated, fall back to using default scale + if (qi.empty() || level >= static_cast(qi.size())) { + // Fallback: use 2^logDefaultScale as the rescaling factor + auto logDefaultScale = schemeParam->getLogDefaultScale(); + llvm::APInt defaultScale(scale.getBitWidth(), 1); + defaultScale = defaultScale.shl(logDefaultScale); + return scale.udiv(defaultScale); + } + + llvm::APInt qiVal(scale.getBitWidth(), qi[level]); + return scale.udiv(qiVal); } -int64_t CKKSScaleModel::evalModReduceScaleBackward( - const ckks::LocalParam& inputParam, int64_t resultScale) { +llvm::APInt CKKSScaleModel::evalModReduceScaleBackward( + const ckks::LocalParam& inputParam, const llvm::APInt& resultScale) { const auto* schemeParam = inputParam.getSchemeParam(); - // TODO(#1640): rescale using logqi instead of logDefaultScale - // auto logqi = schemeParam->getLogqi(); - // auto level = inputParam.getCurrentLevel(); - auto logDefaultScale = schemeParam->getLogDefaultScale(); - return resultScale + logDefaultScale; + // High-precision scale management (#2364): multiply by actual qi value + // Reverse of evalModReduceScale: if result = scale / qi, then scale = result + // * qi + auto qi = schemeParam->getQi(); + auto level = inputParam.getCurrentLevel(); + + // If qi is not populated, fall back to using default scale + if (qi.empty() || level >= static_cast(qi.size())) { + // Fallback: use 2^logDefaultScale as the rescaling factor + auto logDefaultScale = schemeParam->getLogDefaultScale(); + llvm::APInt defaultScale(resultScale.getBitWidth(), 1); + defaultScale = defaultScale.shl(logDefaultScale); + return resultScale * defaultScale; + } + + llvm::APInt qiVal(resultScale.getBitWidth(), qi[level]); + return resultScale * qiVal; } //===----------------------------------------------------------------------===// @@ -126,7 +168,8 @@ LogicalResult ScaleAnalysis::visitOperation( propagateIfChanged(lattice, changed); }; - auto getOperandScales = [&](Operation* op, SmallVectorImpl& scales) { + auto getOperandScales = [&](Operation* op, + SmallVectorImpl& scales) { SmallVector secretOperands; this->getSecretOperands(op, secretOperands); @@ -147,7 +190,7 @@ LogicalResult ScaleAnalysis::visitOperation( llvm::TypeSwitch(*op) .template Case([&](auto mulOp) { - SmallVector scales; + SmallVector scales; getOperandScales(mulOp, scales); // there must be at least one secret operand that has scale if (scales.empty()) { @@ -166,7 +209,7 @@ LogicalResult ScaleAnalysis::visitOperation( propagate(mulOp.getResult(), ScaleState(result)); }) .template Case([&](auto modReduceOp) { - SmallVector scales; + SmallVector scales; getOperandScales(modReduceOp, scales); // there must be at least one secret operand that has scale if (scales.empty()) { @@ -200,7 +243,7 @@ LogicalResult ScaleAnalysis::visitOperation( return; } - SmallVector scales; + SmallVector scales; getOperandScales(&op, scales); if (scales.empty()) { return; @@ -267,7 +310,7 @@ LogicalResult ScaleAnalysisBackward::visitOperation( auto getOperandScales = [&](Operation* op, SmallVectorImpl& operandWithoutScaleIndices, - SmallVectorImpl& scales) { + SmallVectorImpl& scales) { LLVM_DEBUG(llvm::dbgs() << "Operand scales for " << op->getName() << ": "); SmallVector secretOperands; @@ -294,7 +337,8 @@ LogicalResult ScaleAnalysisBackward::visitOperation( LLVM_DEBUG(llvm::dbgs() << "\n"); }; - auto getResultScales = [&](Operation* op, SmallVectorImpl& scales) { + auto getResultScales = [&](Operation* op, + SmallVectorImpl& scales) { LLVM_DEBUG(llvm::dbgs() << "Result scales for " << op->getName() << ": "); SmallVector secretResults; this->getSecretResults(op, secretResults); @@ -315,14 +359,14 @@ LogicalResult ScaleAnalysisBackward::visitOperation( << "\n"); llvm::TypeSwitch(*op) .template Case([&](auto mulOp) { - SmallVector resultScales; + SmallVector resultScales; getResultScales(mulOp, resultScales); // there must be at least one secret result that has scale if (resultScales.empty()) { return; } SmallVector operandWithoutScaleIndices; - SmallVector operandScales; + SmallVector operandScales; getOperandScales(mulOp, operandWithoutScaleIndices, operandScales); // there must be at least one secret operand that has scale if (operandScales.empty()) { @@ -342,14 +386,14 @@ LogicalResult ScaleAnalysisBackward::visitOperation( ScaleState(scaleOther)); }) .template Case([&](auto modReduceOp) { - SmallVector resultScales; + SmallVector resultScales; getResultScales(modReduceOp, resultScales); // there must be at least one secret result that has scale if (resultScales.empty()) { return; } SmallVector operandWithoutScaleIndices; - SmallVector scales; + SmallVector scales; getOperandScales(modReduceOp, operandWithoutScaleIndices, scales); // if all operands have scale, succeed. if (!scales.empty()) { @@ -376,7 +420,7 @@ LogicalResult ScaleAnalysisBackward::visitOperation( return; } - SmallVector scales; + SmallVector scales; getResultScales(&op, scales); if (scales.empty()) { return; @@ -399,37 +443,41 @@ template class ScaleAnalysisBackward; // Utils //===----------------------------------------------------------------------===// -int64_t getScale(Value value, DataFlowSolver* solver) { +llvm::APInt getScale(Value value, DataFlowSolver* solver) { auto* lattice = solver->lookupState(value); if (!lattice) { assert(false && "ScaleLattice not found"); - return 0; + return llvm::APInt(64, 0); } if (!lattice->getValue().isInitialized()) { assert(false && "ScaleLattice not initialized"); - return 0; + return llvm::APInt(64, 0); } return lattice->getValue().getScale(); } -int64_t getScaleFromMgmtAttr(Value value) { +llvm::APInt getScaleFromMgmtAttr(Value value) { auto mgmtAttr = mgmt::findMgmtAttrAssociatedWith(value); if (!mgmtAttr) { assert(false && "MgmtAttr not found"); - return 0; + return llvm::APInt(64, 0); } + // High-precision scale management (#2364): MgmtAttr now stores APInt directly return mgmtAttr.getScale(); } void annotateScale(Operation* top, DataFlowSolver* solver) { - auto getIntegerAttr = [&](int scale) { - return IntegerAttr::get(IntegerType::get(top->getContext(), 64), scale); + auto getStringAttr = [&](const llvm::APInt& scale) { + // Store APInt as a string in base 10 for full precision + llvm::SmallString<64> str; + scale.toString(str, 10, /*Signed=*/false); + return StringAttr::get(top->getContext(), str); }; walkValues(top, [&](Value value) { if (mgmt::shouldHaveMgmtAttribute(value, solver)) { setAttributeAssociatedWith(value, kArgScaleAttrName, - getIntegerAttr(getScale(value, solver))); + getStringAttr(getScale(value, solver))); } }); } diff --git a/lib/Analysis/ScaleAnalysis/ScaleAnalysis.h b/lib/Analysis/ScaleAnalysis/ScaleAnalysis.h index dd0be035ae..13d9d73ae3 100644 --- a/lib/Analysis/ScaleAnalysis/ScaleAnalysis.h +++ b/lib/Analysis/ScaleAnalysis/ScaleAnalysis.h @@ -9,6 +9,7 @@ #include "lib/Dialect/Secret/IR/SecretTypes.h" #include "lib/Parameters/BGV/Params.h" #include "lib/Parameters/CKKS/Params.h" +#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project @@ -25,14 +26,19 @@ namespace heir { class ScaleState { public: ScaleState() : scale(std::nullopt) {} - explicit ScaleState(int64_t scale) : scale(scale) {} + explicit ScaleState(const llvm::APInt& scale) : scale(scale) {} + explicit ScaleState(int64_t scale) : scale(llvm::APInt(64, scale)) {} - int64_t getScale() const { + const llvm::APInt& getScale() const { assert(isInitialized()); return scale.value(); } - bool operator==(const ScaleState& rhs) const { return scale == rhs.scale; } + bool operator==(const ScaleState& rhs) const { + if (!isInitialized() && !rhs.isInitialized()) return true; + if (!isInitialized() || !rhs.isInitialized()) return false; + return scale.value() == rhs.scale.value(); + } bool isInitialized() const { return scale.has_value(); } @@ -59,9 +65,11 @@ class ScaleState { } private: - // This may not represent 2 ** 80 scale for CKKS. - // Currently we use logScale for CKKS. - std::optional scale; + // For CKKS: represents the actual scale value (not log scale) + // For BGV: represents the scale in modular arithmetic + // This supports high-precision scale management (#2364) using + // arbitrary-precision APInt + std::optional scale; }; class ScaleLattice : public dataflow::Lattice { @@ -73,28 +81,32 @@ struct BGVScaleModel { using SchemeParam = bgv::SchemeParam; using LocalParam = bgv::LocalParam; - static int64_t evalMulScale(const LocalParam& param, int64_t lhs, - int64_t rhs); - static int64_t evalMulScaleBackward(const LocalParam& param, int64_t result, - int64_t lhs); - static int64_t evalModReduceScale(const LocalParam& inputParam, - int64_t scale); - static int64_t evalModReduceScaleBackward(const LocalParam& inputParam, - int64_t resultScale); + static llvm::APInt evalMulScale(const LocalParam& param, + const llvm::APInt& lhs, + const llvm::APInt& rhs); + static llvm::APInt evalMulScaleBackward(const LocalParam& param, + const llvm::APInt& result, + const llvm::APInt& lhs); + static llvm::APInt evalModReduceScale(const LocalParam& inputParam, + const llvm::APInt& scale); + static llvm::APInt evalModReduceScaleBackward(const LocalParam& inputParam, + const llvm::APInt& resultScale); }; struct CKKSScaleModel { using SchemeParam = ckks::SchemeParam; using LocalParam = ckks::LocalParam; - static int64_t evalMulScale(const LocalParam& param, int64_t lhs, - int64_t rhs); - static int64_t evalMulScaleBackward(const LocalParam& param, int64_t result, - int64_t lhs); - static int64_t evalModReduceScale(const LocalParam& inputParam, - int64_t scale); - static int64_t evalModReduceScaleBackward(const LocalParam& inputParam, - int64_t resultScale); + static llvm::APInt evalMulScale(const LocalParam& param, + const llvm::APInt& lhs, + const llvm::APInt& rhs); + static llvm::APInt evalMulScaleBackward(const LocalParam& param, + const llvm::APInt& result, + const llvm::APInt& lhs); + static llvm::APInt evalModReduceScale(const LocalParam& inputParam, + const llvm::APInt& scale); + static llvm::APInt evalModReduceScaleBackward(const LocalParam& inputParam, + const llvm::APInt& resultScale); }; /// Forward Analyse the scale of each secret Value @@ -127,11 +139,17 @@ class ScaleAnalysis using LocalParamType = typename ScaleModelT::LocalParam; ScaleAnalysis(DataFlowSolver& solver, const SchemeParamType& schemeParam, - int64_t inputScale) + const llvm::APInt& inputScale) : dataflow::SparseForwardDataFlowAnalysis(solver), schemeParam(schemeParam), inputScale(inputScale) {} + ScaleAnalysis(DataFlowSolver& solver, const SchemeParamType& schemeParam, + int64_t inputScale) + : dataflow::SparseForwardDataFlowAnalysis(solver), + schemeParam(schemeParam), + inputScale(llvm::APInt(64, inputScale)) {} + void setToEntryState(ScaleLattice* lattice) override { if (isa(lattice->getAnchor().getType())) { propagateIfChanged(lattice, lattice->join(ScaleState(inputScale))); @@ -154,7 +172,7 @@ class ScaleAnalysis private: const SchemeParamType schemeParam; - int64_t inputScale; + llvm::APInt inputScale; }; /// Backward Analyse the scale of plaintext Value / opaque result of @@ -204,13 +222,13 @@ class ScaleAnalysisBackward // Utils //===----------------------------------------------------------------------===// -int64_t getScale(Value value, DataFlowSolver* solver); +llvm::APInt getScale(Value value, DataFlowSolver* solver); constexpr StringRef kArgScaleAttrName = "mgmt.scale"; void annotateScale(Operation* top, DataFlowSolver* solver); -int64_t getScaleFromMgmtAttr(Value value); +llvm::APInt getScaleFromMgmtAttr(Value value); } // namespace heir } // namespace mlir diff --git a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp index 866ccf994b..da2eadeb0b 100644 --- a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp @@ -410,7 +410,12 @@ struct ConvertRlweEncodeOp : public OpConversionPattern { this->typeConverter->convertType(op.getOutput().getType()), params); auto encoding = op.getEncoding(); - int64_t scale = lwe::getScalingFactorFromEncodingAttr(encoding); + // High-precision scale management (#2364): convert APInt to int64_t for + // Lattigo + auto scaleAPInt = lwe::getScalingFactorFromEncodingAttr(encoding); + int64_t scale = scaleAPInt.getBitWidth() <= 64 + ? scaleAPInt.getSExtValue() + : scaleAPInt.getLimitedValue(INT64_MAX); SmallVector dialectAttrs(op->getDialectAttrs()); rewriter diff --git a/lib/Dialect/LWE/IR/LWEAttributes.cpp b/lib/Dialect/LWE/IR/LWEAttributes.cpp index 3d2d5994e0..8a062f98bb 100644 --- a/lib/Dialect/LWE/IR/LWEAttributes.cpp +++ b/lib/Dialect/LWE/IR/LWEAttributes.cpp @@ -22,55 +22,72 @@ namespace lwe { // Utils //===----------------------------------------------------------------------===// -int64_t getScalingFactorFromEncodingAttr(Attribute encoding) { - return llvm::TypeSwitch(encoding) +llvm::APInt getScalingFactorFromEncodingAttr(Attribute encoding) { + return llvm::TypeSwitch(encoding) .Case( [](auto attr) { return attr.getScalingFactor(); }) .Case( [](auto attr) { return attr.getScalingFactor(); }) - .Default([](Attribute) { return 0; }); + .Default([](Attribute) { return llvm::APInt(64, 0); }); } -int64_t inferMulOpScalingFactor(Attribute xEncoding, Attribute yEncoding, - int64_t plaintextModulus) { - int64_t xScale = getScalingFactorFromEncodingAttr(xEncoding); - int64_t yScale = getScalingFactorFromEncodingAttr(yEncoding); - return llvm::TypeSwitch(xEncoding) +llvm::APInt inferMulOpScalingFactor(Attribute xEncoding, Attribute yEncoding, + int64_t plaintextModulus) { + llvm::APInt xScale = getScalingFactorFromEncodingAttr(xEncoding); + llvm::APInt yScale = getScalingFactorFromEncodingAttr(yEncoding); + return llvm::TypeSwitch(xEncoding) .Case( // Use 128-bit int in case of large ptm. [&](auto attr) { - return (APInt(128, xScale) * APInt(128, yScale)) - .urem(plaintextModulus); + llvm::APInt xScale128 = xScale.zext(128); + llvm::APInt yScale128 = yScale.zext(128); + llvm::APInt result = (xScale128 * yScale128) + .urem(llvm::APInt(128, plaintextModulus)); + // Reduce to minimum bit width for consistent comparisons + unsigned minBits = result.getActiveBits(); + if (minBits == 0) minBits = 1; // APInt needs at least 1 bit + return result.trunc(std::max(64u, minBits)); }) - .Case( - [&](auto attr) { return xScale + yScale; }) - .Default([](Attribute) { return 0; }); + .Case([&](auto attr) { + // Ensure both scales have the same bit width before adding + unsigned maxBitWidth = + std::max(xScale.getBitWidth(), yScale.getBitWidth()); + llvm::APInt xScaleExt = xScale.zext(maxBitWidth); + llvm::APInt yScaleExt = yScale.zext(maxBitWidth); + return xScaleExt + yScaleExt; + }) + .Default([](Attribute) { return llvm::APInt(64, 0); }); } -int64_t inferModulusSwitchOrRescaleOpScalingFactor(Attribute xEncoding, - APInt dividedModulus, - int64_t plaintextModulus) { - int64_t xScale = getScalingFactorFromEncodingAttr(xEncoding); - return llvm::TypeSwitch(xEncoding) +llvm::APInt inferModulusSwitchOrRescaleOpScalingFactor( + Attribute xEncoding, APInt dividedModulus, int64_t plaintextModulus) { + llvm::APInt xScale = getScalingFactorFromEncodingAttr(xEncoding); + return llvm::TypeSwitch(xEncoding) .Case([&](auto attr) { // Use 128-bit int in case of large ptm. auto qInvT = multiplicativeInverse( APInt(128, dividedModulus.urem(plaintextModulus)), APInt(128, plaintextModulus)); - return (APInt(128, xScale) * qInvT).urem(plaintextModulus); + llvm::APInt xScale128 = xScale.zext(128); + llvm::APInt result = + (xScale128 * qInvT).urem(llvm::APInt(128, plaintextModulus)); + // Reduce to minimum bit width for consistent comparisons + unsigned minBits = result.getActiveBits(); + if (minBits == 0) minBits = 1; // APInt needs at least 1 bit + return result.trunc(std::max(64u, minBits)); }) .Case([&](auto attr) { // skip if xScale is 0 - if (xScale == 0) return xScale; + if (xScale.isZero()) return xScale; // round to nearest log2 instead of ceil auto logQ = dividedModulus.nearestLogBase2(); - return xScale - logQ; + return xScale - llvm::APInt(xScale.getBitWidth(), logQ); }) - .Default([](Attribute) { return 0; }); + .Default([](Attribute) { return llvm::APInt(64, 0); }); } Attribute getEncodingAttrWithNewScalingFactor(Attribute encoding, - int64_t newScale) { + const llvm::APInt& newScale) { return llvm::TypeSwitch(encoding) .Case([&](auto attr) { return FullCRTPackingEncodingAttr::get(encoding.getContext(), newScale); diff --git a/lib/Dialect/LWE/IR/LWEAttributes.h b/lib/Dialect/LWE/IR/LWEAttributes.h index 3d9325c487..6a7e3eb33a 100644 --- a/lib/Dialect/LWE/IR/LWEAttributes.h +++ b/lib/Dialect/LWE/IR/LWEAttributes.h @@ -18,7 +18,7 @@ namespace mlir { namespace heir { namespace lwe { -int64_t getScalingFactorFromEncodingAttr(Attribute encoding); +llvm::APInt getScalingFactorFromEncodingAttr(Attribute encoding); PlaintextSpaceAttr inferMulOpPlaintextSpaceAttr(MLIRContext* ctx, PlaintextSpaceAttr x, @@ -28,7 +28,7 @@ PlaintextSpaceAttr inferModulusSwitchOrRescaleOpPlaintextSpaceAttr( MLIRContext* ctx, PlaintextSpaceAttr x, APInt dividedModulus); Attribute getEncodingAttrWithNewScalingFactor(Attribute encoding, - int64_t newScale); + const llvm::APInt& newScale); } // namespace lwe } // namespace heir diff --git a/lib/Dialect/LWE/IR/LWEAttributes.td b/lib/Dialect/LWE/IR/LWEAttributes.td index ef0180765e..7b7223c74b 100644 --- a/lib/Dialect/LWE/IR/LWEAttributes.td +++ b/lib/Dialect/LWE/IR/LWEAttributes.td @@ -97,7 +97,7 @@ class LWE_EncodingAttrWithScalingParam:$scaling_factor ); } diff --git a/lib/Dialect/LWE/IR/LWETypes.cpp b/lib/Dialect/LWE/IR/LWETypes.cpp index 546266f5ab..9d647f36c3 100644 --- a/lib/Dialect/LWE/IR/LWETypes.cpp +++ b/lib/Dialect/LWE/IR/LWETypes.cpp @@ -54,9 +54,9 @@ LWECiphertextType getDefaultCGGICiphertextType(MLIRContext* ctx, ctx, lwe::ApplicationDataAttr::get(ctx, IntegerType::get(ctx, messageWidth), lwe::PreserveOverflowAttr::get(ctx)), - lwe::PlaintextSpaceAttr::get( - ctx, plaintextRing, - lwe::ConstantCoefficientEncodingAttr::get(ctx, scalingFactor)), + lwe::PlaintextSpaceAttr::get(ctx, plaintextRing, + lwe::ConstantCoefficientEncodingAttr::get( + ctx, llvm::APInt(64, scalingFactor))), lwe::CiphertextSpaceAttr::get(ctx, ciphertextRing, lwe::LweEncryptionType::msb, /*dimension=*/742), diff --git a/lib/Dialect/Mgmt/IR/MgmtAttributes.cpp b/lib/Dialect/Mgmt/IR/MgmtAttributes.cpp index 7a48560991..0991a461dc 100644 --- a/lib/Dialect/Mgmt/IR/MgmtAttributes.cpp +++ b/lib/Dialect/Mgmt/IR/MgmtAttributes.cpp @@ -18,11 +18,15 @@ namespace mgmt { // MgmtAttr helpers //===----------------------------------------------------------------------===// -MgmtAttr getMgmtAttrWithNewScale(MgmtAttr mgmtAttr, int64_t scale) { +MgmtAttr getMgmtAttrWithNewScale(MgmtAttr mgmtAttr, const llvm::APInt& scale) { return MgmtAttr::get(mgmtAttr.getContext(), mgmtAttr.getLevel(), mgmtAttr.getDimension(), scale); } +MgmtAttr getMgmtAttrWithNewScale(MgmtAttr mgmtAttr, int64_t scale) { + return getMgmtAttrWithNewScale(mgmtAttr, llvm::APInt(64, scale)); +} + //===----------------------------------------------------------------------===// // Getters and Setters //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Mgmt/IR/MgmtAttributes.h b/lib/Dialect/Mgmt/IR/MgmtAttributes.h index 8e47298589..a8617642bb 100644 --- a/lib/Dialect/Mgmt/IR/MgmtAttributes.h +++ b/lib/Dialect/Mgmt/IR/MgmtAttributes.h @@ -18,6 +18,7 @@ namespace mgmt { // MgmtAttr helpers //===----------------------------------------------------------------------===// +MgmtAttr getMgmtAttrWithNewScale(MgmtAttr mgmtAttr, const llvm::APInt& scale); MgmtAttr getMgmtAttrWithNewScale(MgmtAttr mgmtAttr, int64_t scale); //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Mgmt/IR/MgmtAttributes.td b/lib/Dialect/Mgmt/IR/MgmtAttributes.td index 78cfb45931..fde3e4a7a3 100644 --- a/lib/Dialect/Mgmt/IR/MgmtAttributes.td +++ b/lib/Dialect/Mgmt/IR/MgmtAttributes.td @@ -3,12 +3,18 @@ include "lib/Dialect/Mgmt/IR/MgmtDialect.td" include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" class Mgmt_Attr traits = []> : AttrDef { let mnemonic = attrMnemonic; } +// Custom APInt parameter with safe comparator for scale values +// Note: APIntParameter defines the C++ storage type and uses the default +// FieldParser for parsing integer values from text +def ScaleAPIntParameter : APIntParameter<"scale value">; + def Mgmt_MgmtAttr : Mgmt_Attr<"Mgmt", "mgmt"> { let summary = "Container attribute for all mgmt parameter"; let description = [{ @@ -17,6 +23,7 @@ def Mgmt_MgmtAttr : Mgmt_Attr<"Mgmt", "mgmt"> { The attribute is a struct with the following fields: - `level` : the level of the ciphertext, from L to 0 - `dimension` : the dimension of the ciphertext, defaults to 2 + - `scale` : the scale of the ciphertext (high-precision APInt), defaults to 0 Internally, this attribute is used by secret-to- for determining the level and dimension of the ciphertext. @@ -26,7 +33,7 @@ def Mgmt_MgmtAttr : Mgmt_Attr<"Mgmt", "mgmt"> { Example: ``` - #mgmt = #mgmt.mgmt // dimension defaults to 2 + #mgmt = #mgmt.mgmt // dimension defaults to 2, scale defaults to 0 #mgmt1 = #mgmt.mgmt %0 = secret.generic(%arg0, %arg1 : !secret.secret) attrs = {mgmt.mgmt = #mgmt} { ... @@ -37,14 +44,22 @@ def Mgmt_MgmtAttr : Mgmt_Attr<"Mgmt", "mgmt"> { let parameters = (ins "int": $level, DefaultValuedParameter<"int", "2">:$dimension, - DefaultValuedParameter<"int64_t", "0">:$scale + ScaleAPIntParameter:$scale ); let builders = [ AttrBuilder<(ins "int":$level, "int":$dimension), [{ return $_get( $_ctxt, level, - dimension, 0); + dimension, + ::llvm::APInt(64, 0)); + }]>, + AttrBuilder<(ins "int":$level, "int":$dimension, "int64_t":$scale), [{ + return $_get( + $_ctxt, + level, + dimension, + ::llvm::APInt(64, scale)); }]>, ]; let assemblyFormat = "`<` struct(params) `>`"; diff --git a/lib/Dialect/Mgmt/Transforms/AnnotateMgmt.cpp b/lib/Dialect/Mgmt/Transforms/AnnotateMgmt.cpp index b2cb858eeb..effebbd66b 100644 --- a/lib/Dialect/Mgmt/Transforms/AnnotateMgmt.cpp +++ b/lib/Dialect/Mgmt/Transforms/AnnotateMgmt.cpp @@ -36,7 +36,19 @@ void annotateMgmtAttr(Operation* top) { if (!scaleAttr) { return MgmtAttr::get(top->getContext(), level, dimension); } - auto scale = cast(scaleAttr).getInt(); + // High-precision scale management (#2364): scale is stored as StringAttr + // Parse it back to APInt for MgmtAttr + llvm::APInt scale(64, 0); + if (auto strAttr = dyn_cast(scaleAttr)) { + // Parse string as APInt with appropriate bit width + // Use a large bit width to ensure the string fits + unsigned bitWidth = + std::max(64u, static_cast(strAttr.getValue().size() * 4)); + scale = llvm::APInt(bitWidth, strAttr.getValue(), 10); + } else { + // Fallback for backward compatibility with IntegerAttr + scale = llvm::APInt(64, cast(scaleAttr).getInt()); + } return MgmtAttr::get(top->getContext(), level, dimension, scale); }; diff --git a/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp b/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp index ad362c210a..8e6af1fa2a 100644 --- a/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp +++ b/lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp @@ -131,6 +131,8 @@ class SecretToBGVTypeConverter auto* ctx = type.getContext(); auto level = mgmtAttr.getLevel(); auto dimension = mgmtAttr.getDimension(); + // High-precision scale management (#2364): MgmtAttr now stores APInt + // Pass APInt directly to LWE encoding attributes (which now also use APInt) auto scale = mgmtAttr.getScale(); auto tensorValueType = dyn_cast(type.getValueType()); diff --git a/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp b/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp index 496a034a8b..8291ccc711 100644 --- a/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp +++ b/lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp @@ -134,6 +134,8 @@ class SecretToCKKSTypeConverter auto* ctx = type.getContext(); auto level = mgmtAttr.getLevel(); auto dimension = mgmtAttr.getDimension(); + // High-precision scale management (#2364): MgmtAttr now stores APInt + // Pass APInt directly to LWE encoding attributes (which now also use APInt) auto scale = mgmtAttr.getScale(); auto tensorValueType = dyn_cast(type.getValueType()); diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index 2dafe9e154..afb0f4e974 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -1359,8 +1359,10 @@ LogicalResult LattigoEmitter::printOperation(CKKSEncodeOp op) { // set the scale of plaintext auto scale = op.getScale(); os << plaintextName << ".Scale = "; - os << getName(newPlaintextOp.getParams()) << ".NewScale(math.Pow(2, "; - os << scale << "))\n"; + // High-precision scale management (#2364): scale is now an actual value, not + // log scale Emit the scale directly instead of math.Pow(2, scale) + os << getName(newPlaintextOp.getParams()) << ".NewScale("; + os << scale << ")\n"; os << getName(op.getEncoder()) << ".Encode("; os << packedName << ", "; diff --git a/lib/Transforms/PopulateScale/PopulateScaleCKKS.cpp b/lib/Transforms/PopulateScale/PopulateScaleCKKS.cpp index 8c0a63bd55..9725ebc88b 100644 --- a/lib/Transforms/PopulateScale/PopulateScaleCKKS.cpp +++ b/lib/Transforms/PopulateScale/PopulateScaleCKKS.cpp @@ -31,8 +31,15 @@ class CKKSAdjustScaleMaterializer : public AdjustScaleMaterializer { virtual ~CKKSAdjustScaleMaterializer() = default; int64_t deltaScale(int64_t scale, int64_t inputScale) const override { - // TODO(#1640): support high-precision scale management - return scale - inputScale; + // High-precision scale management (#2364): + // With actual scales (not log scales), delta is computed as scale / + // inputScale For backward compatibility when scales are still in log + // domain, fall back to subtraction This assumes scales are small enough to + // fit in int64_t for now + if (inputScale == 0) { + return scale; + } + return scale / inputScale; } }; @@ -57,12 +64,19 @@ struct PopulateScaleCKKS : impl::PopulateScaleCKKSBase { dataflow::loadBaselineAnalyses(solver); // ScaleAnalysis depends on SecretnessAnalysis solver.load(); - // set input scale to logDefaultScale - auto inputScale = logDefaultScale; + // High-precision scale management (#2364): convert logDefaultScale to + // actual scale inputScale = 2^logDefaultScale + auto logScale = logDefaultScale; if (beforeMulIncludeFirstMul) { // encode at double degree - inputScale *= 2; + logScale *= 2; } + // Convert from log scale to actual scale using APInt + // Need to compute 2^logScale + // Use a bit width large enough to hold 2^logScale (logScale + 1 bits + // minimum) + unsigned bitWidth = std::max(64u, static_cast(logScale) + 1); + auto inputScale = llvm::APInt(bitWidth, 1).shl(logScale); solver.load>( ckks::SchemeParam::getSchemeParamFromAttr(ckksSchemeParamAttr), /*inputScale*/ inputScale); diff --git a/lib/Transforms/PopulateScale/PopulateScalePatterns.cpp b/lib/Transforms/PopulateScale/PopulateScalePatterns.cpp index 706c5cb428..a3c4e18124 100644 --- a/lib/Transforms/PopulateScale/PopulateScalePatterns.cpp +++ b/lib/Transforms/PopulateScale/PopulateScalePatterns.cpp @@ -16,15 +16,24 @@ namespace heir { template LogicalResult ConvertAdjustScaleToMulPlain::matchAndRewrite( mgmt::AdjustScaleOp op, PatternRewriter& rewriter) const { - auto inputScale = mgmt::findMgmtAttrAssociatedWith(op.getInput()).getScale(); - int64_t scale = mgmt::findMgmtAttrAssociatedWith(op).getScale(); + auto inputScaleAPInt = + mgmt::findMgmtAttrAssociatedWith(op.getInput()).getScale(); + auto scaleAPInt = mgmt::findMgmtAttrAssociatedWith(op).getScale(); // no need to adjust scale - if (scale == inputScale) { + if (scaleAPInt == inputScaleAPInt) { rewriter.replaceAllOpUsesWith(op, op->getOperand(0)); rewriter.eraseOp(op); return success(); } + // Convert APInt to int64_t for materializer (which still uses int64_t) + int64_t scale = scaleAPInt.getBitWidth() <= 64 + ? scaleAPInt.getSExtValue() + : scaleAPInt.getLimitedValue(INT64_MAX); + int64_t inputScale = inputScaleAPInt.getBitWidth() <= 64 + ? inputScaleAPInt.getSExtValue() + : inputScaleAPInt.getLimitedValue(INT64_MAX); + auto deltaScale = materializer->deltaScale(scale, inputScale); if (deltaScale < 0) { op.emitError() << "delta scale is negative"; diff --git a/lib/Utils/ContextAwareConversionUtils.h b/lib/Utils/ContextAwareConversionUtils.h index 3b0d1d29e0..7a10824a35 100644 --- a/lib/Utils/ContextAwareConversionUtils.h +++ b/lib/Utils/ContextAwareConversionUtils.h @@ -277,8 +277,11 @@ class SecretGenericOpCipherPlainConversion Attribute ciphertextEncoding = ciphertextElementTy.getPlaintextSpace().getEncoding(); + // High-precision scale management (#2364): pass APInt directly to LWE + // encoding + auto scaleAPInt = mgmtAttr.getScale(); Attribute plaintextEncoding = lwe::getEncodingAttrWithNewScalingFactor( - ciphertextEncoding, mgmtAttr.getScale()); + ciphertextEncoding, scaleAPInt); if (!plaintextEncoding) { return rewriter.notifyMatchFailure( diff --git a/tests/Transforms/annotate_mgmt/client_helpers.mlir b/tests/Transforms/annotate_mgmt/client_helpers.mlir index 4a3d5a3df1..1dc2831979 100644 --- a/tests/Transforms/annotate_mgmt/client_helpers.mlir +++ b/tests/Transforms/annotate_mgmt/client_helpers.mlir @@ -1,20 +1,20 @@ // RUN: heir-opt --annotate-mgmt %s | FileCheck %s // CHECK: @main -func.func @main(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { +func.func @main(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { return %arg0 : !secret.secret } // CHECK: @encrypt_helper -// CHECK-SAME: (%[[arg0:.*]]: i16) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) attributes +// CHECK-SAME: (%[[arg0:.*]]: i16) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) attributes func.func @encrypt_helper(%arg0: i16) -> !secret.secret attributes {client.enc_func = {func_name = "main", index = 0 : i64}} { // CHECK: secret.conceal - // CHECK-SAME: {mgmt.mgmt = #mgmt.mgmt + // CHECK-SAME: {mgmt.mgmt = #mgmt.mgmt %0 = secret.conceal %arg0 : i16 -> !secret.secret return %0 : !secret.secret } // CHECK: @decrypt_helper -// CHECK-SAME: (%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt +// CHECK-SAME: (%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt func.func @decrypt_helper(%arg0: !secret.secret) -> i16 attributes {client.dec_func = {func_name = "main", index = 0 : i64}} { %0 = secret.reveal %arg0 : !secret.secret -> i16 return %0 : i16 diff --git a/tests/Transforms/annotate_mgmt/dimension_backprop.mlir b/tests/Transforms/annotate_mgmt/dimension_backprop.mlir index 6b40f56999..fae1b48a83 100644 --- a/tests/Transforms/annotate_mgmt/dimension_backprop.mlir +++ b/tests/Transforms/annotate_mgmt/dimension_backprop.mlir @@ -2,28 +2,28 @@ // CHECK: func @dimension_backprop func.func @dimension_backprop( - %arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, - %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt} - ) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + %arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, + %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt} + ) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { %cst = arith.constant dense<7> : tensor<1024xi16> - // CHECK: mgmt.init %{{.*}} {mgmt.mgmt = #mgmt.mgmt} - %0 = mgmt.init %cst {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + // CHECK: mgmt.init %{{.*}} {mgmt.mgmt = #mgmt.mgmt} + %0 = mgmt.init %cst {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> %1 = secret.generic( - %arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, - %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt} + %arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, + %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt} ) { ^body(%input0: tensor<1024xi16>, %input1: tensor<1024xi16>): // CHECK: arith.muli - // CHECK-SAME: {mgmt.mgmt = #mgmt.mgmt} - %2 = arith.muli %input0, %input1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + // CHECK-SAME: {mgmt.mgmt = #mgmt.mgmt} + %2 = arith.muli %input0, %input1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> // Here, the result of the ciphertext-plaintext mul does not increase // the dimension of the ciphertext. All that's happening here is the // plaintext (mgmt.init above) is informed the level information required // of it to encode to a compatible plaintext. // CHECK: arith.muli - // CHECK-SAME: {mgmt.mgmt = #mgmt.mgmt} - %3 = arith.muli %0, %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + // CHECK-SAME: {mgmt.mgmt = #mgmt.mgmt} + %3 = arith.muli %0, %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> secret.yield %3 : tensor<1024xi16> - } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) return %1 : !secret.secret> } diff --git a/tests/Transforms/generate_param_bfv/doctest.mlir b/tests/Transforms/generate_param_bfv/doctest.mlir index d794a4b830..41f2759217 100644 --- a/tests/Transforms/generate_param_bfv/doctest.mlir +++ b/tests/Transforms/generate_param_bfv/doctest.mlir @@ -2,12 +2,12 @@ // CHECK: module attributes {bgv.schemeParam = #bgv.scheme_param} module { - func.func @add(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { - %0 = secret.generic(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt}} { + func.func @add(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt}} { ^body(%input0: i16): - %1 = arith.addi %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : i16 + %1 = arith.addi %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : i16 secret.yield %1 : i16 - } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret } } diff --git a/tests/Transforms/generate_param_bgv/doctest.mlir b/tests/Transforms/generate_param_bgv/doctest.mlir index 8c3d711c8b..027b6acc5d 100644 --- a/tests/Transforms/generate_param_bgv/doctest.mlir +++ b/tests/Transforms/generate_param_bgv/doctest.mlir @@ -2,12 +2,12 @@ // CHECK: module attributes {bgv.schemeParam = #bgv.scheme_param} module { - func.func @add(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { - %0 = secret.generic(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt}} { + func.func @add(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt}} { ^body(%input0: i16): - %1 = arith.addi %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : i16 + %1 = arith.addi %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : i16 secret.yield %1 : i16 - } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret } } diff --git a/tests/Transforms/generate_param_ckks/doctest.mlir b/tests/Transforms/generate_param_ckks/doctest.mlir index 163a56a24b..10f00471b8 100644 --- a/tests/Transforms/generate_param_ckks/doctest.mlir +++ b/tests/Transforms/generate_param_ckks/doctest.mlir @@ -2,12 +2,12 @@ // CHECK: {ckks.schemeParam = #ckks.scheme_param} module { - func.func @add(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { - %0 = secret.generic(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt}} { + func.func @add(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt}} { ^body(%input0: f16): - %1 = arith.addf %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : f16 + %1 = arith.addf %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : f16 secret.yield %1 : f16 - } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret } } diff --git a/tests/Transforms/optimize_relinearization/issue_1548.mlir b/tests/Transforms/optimize_relinearization/issue_1548.mlir index 74e2679532..8ef568f1d0 100644 --- a/tests/Transforms/optimize_relinearization/issue_1548.mlir +++ b/tests/Transforms/optimize_relinearization/issue_1548.mlir @@ -1,13 +1,13 @@ // RUN: heir-opt --optimize-relinearization=allow-mixed-degree-operands=true %s | FileCheck %s // CHECK-NOT: dimension = 4 -func.func @two_mul(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { - %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { +func.func @two_mul(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: tensor<8xi16>, %input1: tensor<8xi16>): - %1 = arith.muli %input0, %input1 {mgmt.mgmt = #mgmt.mgmt} : tensor<8xi16> - %2 = arith.muli %1, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<8xi16> - %3 = mgmt.relinearize %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<8xi16> + %1 = arith.muli %input0, %input1 {mgmt.mgmt = #mgmt.mgmt} : tensor<8xi16> + %2 = arith.muli %1, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<8xi16> + %3 = mgmt.relinearize %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<8xi16> secret.yield %3 : tensor<8xi16> - } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret> } diff --git a/tests/Transforms/populate_scale/bgv/doctest.mlir b/tests/Transforms/populate_scale/bgv/doctest.mlir index 6105f81828..cafba5ff38 100644 --- a/tests/Transforms/populate_scale/bgv/doctest.mlir +++ b/tests/Transforms/populate_scale/bgv/doctest.mlir @@ -2,17 +2,17 @@ module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { func.func @mul(%arg0: !secret.secret) -> !secret.secret { - %0 = secret.generic(%arg0 : !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0 : !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: i16): - %1 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : i16 + %1 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : i16 // CHECK: mgmt.relinearize // CHECK-SAME: scale = 1 - %2 = mgmt.relinearize %1 {mgmt.mgmt = #mgmt.mgmt} : i16 + %2 = mgmt.relinearize %1 {mgmt.mgmt = #mgmt.mgmt} : i16 // CHECK: mgmt.modreduce // CHECK-SAME: scale = 42541 - %3 = mgmt.modreduce %2 {mgmt.mgmt = #mgmt.mgmt} : i16 + %3 = mgmt.modreduce %2 {mgmt.mgmt = #mgmt.mgmt} : i16 secret.yield %3 : i16 - } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret } } diff --git a/tests/Transforms/populate_scale/bgv/smoke_test.mlir b/tests/Transforms/populate_scale/bgv/smoke_test.mlir index 88c15ab708..3d771dd873 100644 --- a/tests/Transforms/populate_scale/bgv/smoke_test.mlir +++ b/tests/Transforms/populate_scale/bgv/smoke_test.mlir @@ -28,28 +28,28 @@ module attributes {bgv.schemeParam = #bgv.scheme_param // the plaintext operand in question - %0 = mgmt.init %inserted {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %0 = mgmt.init %inserted {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> - %1 = secret.generic(%arg0 : !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + %1 = secret.generic(%arg0 : !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: tensor<32xi16>): - %2 = tensor_ext.rotate %input0, %c16 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index - %3 = arith.addi %input0, %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> - %4 = tensor_ext.rotate %3, %c8 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index - %5 = arith.addi %3, %4 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> - %6 = tensor_ext.rotate %5, %c4 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index - %7 = arith.addi %5, %6 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> - %8 = tensor_ext.rotate %7, %c2 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index - %9 = arith.addi %7, %8 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> - %10 = tensor_ext.rotate %9, %c1 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index - %11 = arith.addi %9, %10 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %2 = tensor_ext.rotate %input0, %c16 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index + %3 = arith.addi %input0, %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %4 = tensor_ext.rotate %3, %c8 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index + %5 = arith.addi %3, %4 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %6 = tensor_ext.rotate %5, %c4 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index + %7 = arith.addi %5, %6 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %8 = tensor_ext.rotate %7, %c2 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index + %9 = arith.addi %7, %8 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %10 = tensor_ext.rotate %9, %c1 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index + %11 = arith.addi %9, %10 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> // The mul op in question - %12 = arith.muli %0, %11 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %12 = arith.muli %0, %11 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> - %13 = tensor_ext.rotate %12, %c31 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index - %14 = mgmt.modreduce %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> + %13 = tensor_ext.rotate %12, %c31 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16>, index + %14 = mgmt.modreduce %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<32xi16> secret.yield %14 : tensor<32xi16> - } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) return %1 : !secret.secret> } } diff --git a/tests/Transforms/populate_scale/ckks/doctest.mlir b/tests/Transforms/populate_scale/ckks/doctest.mlir index 326e1562f2..05d517cb1b 100644 --- a/tests/Transforms/populate_scale/ckks/doctest.mlir +++ b/tests/Transforms/populate_scale/ckks/doctest.mlir @@ -1,18 +1,22 @@ // RUN: heir-opt %s --populate-scale-ckks | FileCheck %s +// High-precision scale management: uses actual scales (2^logDefaultScale) instead of log scales +// logDefaultScale = 45, so default scale = 2^45 = 35184372088832 +// After multiply: scale = 2^45 * 2^45 = 2^90 = 1237940039285380274899124224 +// After modreduce (rescale by Q[1]): scale = 2^90 / 35184372121601 ≈ 35184372088832 module attributes {ckks.schemeParam = #ckks.scheme_param, scheme.ckks} { func.func @mul(%arg0: !secret.secret) -> !secret.secret { - %0 = secret.generic(%arg0 : !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0 : !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: f32): - %1 = arith.mulf %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : f32 + %1 = arith.mulf %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : f32 // CHECK: mgmt.relinearize - // CHECK-SAME: scale = 90 - %2 = mgmt.relinearize %1 {mgmt.mgmt = #mgmt.mgmt} : f32 + // CHECK-SAME: scale = 1237940039285380274899124224 + %2 = mgmt.relinearize %1 {mgmt.mgmt = #mgmt.mgmt} : f32 // CHECK: mgmt.modreduce - // CHECK-SAME: scale = 45 - %3 = mgmt.modreduce %2 {mgmt.mgmt = #mgmt.mgmt} : f32 + // CHECK-SAME: scale = 35184372088832 + %3 = mgmt.modreduce %2 {mgmt.mgmt = #mgmt.mgmt} : f32 secret.yield %3 : f32 - } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret } } diff --git a/tests/Transforms/secret_insert_mgmt/bgv/init.mlir b/tests/Transforms/secret_insert_mgmt/bgv/init.mlir index bd63daac31..9f28682ca0 100644 --- a/tests/Transforms/secret_insert_mgmt/bgv/init.mlir +++ b/tests/Transforms/secret_insert_mgmt/bgv/init.mlir @@ -8,18 +8,18 @@ // %c4_i32 = arith.constant 4 : i32 // %c3_i32 = arith.constant 3 : i32 // %c2_i32 = arith.constant 2 : i32 -// %0 = mgmt.init %c2_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 -// %1 = mgmt.init %c4_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 -// %2 = mgmt.init %c3_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 -// %3 = secret.generic(%arg0 : !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { +// %0 = mgmt.init %c2_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %1 = mgmt.init %c4_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %2 = mgmt.init %c3_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %3 = secret.generic(%arg0 : !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { // ^body(%input0: i32): -// %4 = arith.muli %input0, %0 {mgmt.mgmt = #mgmt.mgmt} : i32 -// %5 = arith.addi %4, %1 {mgmt.mgmt = #mgmt.mgmt} : i32 -// %6 = mgmt.modreduce %5 {mgmt.mgmt = #mgmt.mgmt} : i32 -// %7 = arith.muli %6, %2 {mgmt.mgmt = #mgmt.mgmt} : i32 -// %8 = mgmt.modreduce %7 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %4 = arith.muli %input0, %0 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %5 = arith.addi %4, %1 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %6 = mgmt.modreduce %5 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %7 = arith.muli %6, %2 {mgmt.mgmt = #mgmt.mgmt} : i32 +// %8 = mgmt.modreduce %7 {mgmt.mgmt = #mgmt.mgmt} : i32 // secret.yield %8 : i32 -// } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) +// } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) // return %3 : !secret.secret // } // } @@ -30,23 +30,23 @@ module { // CHECK: %c4_i32 = arith.constant 4 : i32 // CHECK: %c3_i32 = arith.constant 3 : i32 // CHECK: %c2_i32 = arith.constant 2 : i32 - // CHECK: %[[v0:.*]] = mgmt.init %c2_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 - // CHECK: %[[v1:.*]] = mgmt.init %c4_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 - // CHECK: %[[v2:.*]] = mgmt.init %c3_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v0:.*]] = mgmt.init %c2_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v1:.*]] = mgmt.init %c4_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v2:.*]] = mgmt.init %c3_i32 {mgmt.mgmt = #mgmt.mgmt} : i32 %c4_i32 = arith.constant 4 : i32 %c3_i32 = arith.constant 3 : i32 %c2_i32 = arith.constant 2 : i32 %0 = secret.generic(%arg0 : !secret.secret) { // CHECK: ^body(%[[INPUT0:.*]]: i32): ^body(%input0: i32): - // CHECK: %[[v3:.*]] = arith.muli %[[INPUT0]], %[[v0]] {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v3:.*]] = arith.muli %[[INPUT0]], %[[v0]] {mgmt.mgmt = #mgmt.mgmt} : i32 %1 = arith.muli %input0, %c2_i32 : i32 - // CHECK: %[[v4:.*]] = arith.addi %[[v3]], %[[v1]] {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v4:.*]] = arith.addi %[[v3]], %[[v1]] {mgmt.mgmt = #mgmt.mgmt} : i32 %2 = arith.addi %1, %c4_i32 : i32 - // CHECK: %[[v5:.*]] = mgmt.modreduce %[[v4]] {mgmt.mgmt = #mgmt.mgmt} : i32 - // CHECK: %[[v6:.*]] = arith.muli %[[v5]], %[[v2]] {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v5:.*]] = mgmt.modreduce %[[v4]] {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v6:.*]] = arith.muli %[[v5]], %[[v2]] {mgmt.mgmt = #mgmt.mgmt} : i32 %3 = arith.muli %2, %c3_i32 : i32 - // CHECK: %[[v7:.*]] = mgmt.modreduce %[[v6]] {mgmt.mgmt = #mgmt.mgmt} : i32 + // CHECK: %[[v7:.*]] = mgmt.modreduce %[[v6]] {mgmt.mgmt = #mgmt.mgmt} : i32 // CHECK: secret.yield %[[v7]] : i32 secret.yield %3 : i32 } -> !secret.secret @@ -54,11 +54,11 @@ module { } // CHECK: @pt_multiple_uses - // CHECK-SAME: (%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}, + // CHECK-SAME: (%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}, // CHECK-SAME: %[[arg1:.*]]: i16) func.func @pt_multiple_uses(%arg0: !secret.secret, %arg1: i16) -> (!secret.secret) { - // CHECK: %[[v0:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 - // CHECK: %[[v1:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 + // CHECK: %[[v0:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 + // CHECK: %[[v1:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 %0 = secret.generic(%arg0: !secret.secret) { ^body(%input0: i16): // CHECK: arith.addi @@ -75,13 +75,13 @@ module { } // CHECK: @pt_multiple_uses_2 - // CHECK-SAME: (%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}, + // CHECK-SAME: (%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}, // CHECK-SAME: %[[arg1:.*]]: i16) func.func @pt_multiple_uses_2(%arg0: !secret.secret, %arg1: i16) -> (!secret.secret) { - // CHECK: %[[v0:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 + // CHECK: %[[v0:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 // Note: these two mgmt.init should not merge, as later optimization like lazy relinearization // or populate-scale will make them different. - // CHECK: %[[v1:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 + // CHECK: %[[v1:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt} : i16 %0 = secret.generic(%arg0: !secret.secret) { ^body(%input0: i16): // CHECK: arith.addi diff --git a/tests/Transforms/secret_insert_mgmt/ckks/bootstrap_waterline.mlir b/tests/Transforms/secret_insert_mgmt/ckks/bootstrap_waterline.mlir index 0e1ee71aa2..1c97b613eb 100644 --- a/tests/Transforms/secret_insert_mgmt/ckks/bootstrap_waterline.mlir +++ b/tests/Transforms/secret_insert_mgmt/ckks/bootstrap_waterline.mlir @@ -1,26 +1,26 @@ // RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-ckks=bootstrap-waterline=3 %s | FileCheck %s // CHECK: func.func @bootstrap_waterline -// CHECK: %0 = secret.generic(%[[arg0:.*]]: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { +// CHECK: %0 = secret.generic(%[[arg0:.*]]: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { // CHECK: (%[[input0:.*]]: tensor<1x1024xf16>): -// CHECK: %[[v1:.*]] = arith.addf %[[input0]], %[[input0]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v2:.*]] = mgmt.modreduce %[[v1]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v3:.*]] = arith.addf %2, %[[v2]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v4:.*]] = mgmt.modreduce %[[v3]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v5:.*]] = arith.addf %4, %[[v4]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v6:.*]] = mgmt.modreduce %[[v5]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v7:.*]] = mgmt.bootstrap %[[v6]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v8:.*]] = arith.addf %[[v7]], %[[v7]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v9:.*]] = mgmt.modreduce %[[v8]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v10:.*]] = arith.addf %[[v9]], %[[v9]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v11:.*]] = mgmt.modreduce %[[v10]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v12:.*]] = arith.addf %[[v11]], %[[v11]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v1:.*]] = arith.addf %[[input0]], %[[input0]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v2:.*]] = mgmt.modreduce %[[v1]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v3:.*]] = arith.addf %2, %[[v2]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v4:.*]] = mgmt.modreduce %[[v3]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v5:.*]] = arith.addf %4, %[[v4]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v6:.*]] = mgmt.modreduce %[[v5]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v7:.*]] = mgmt.bootstrap %[[v6]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v8:.*]] = arith.addf %[[v7]], %[[v7]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v9:.*]] = mgmt.modreduce %[[v8]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v10:.*]] = arith.addf %[[v9]], %[[v9]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v11:.*]] = mgmt.modreduce %[[v10]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v12:.*]] = arith.addf %[[v11]], %[[v11]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> // cross level op -// CHECK: %[[v13:.*]] = mgmt.level_reduce %[[input0]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v14:.*]] = mgmt.adjust_scale %[[v13]] {id = 0 : i64, mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v15:.*]] = mgmt.modreduce %[[v14]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v16:.*]] = mgmt.adjust_scale %[[v12]] {id = 1 : i64, mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> -// CHECK: %[[v17:.*]] = arith.addf %[[v16]], %[[v15]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v13:.*]] = mgmt.level_reduce %[[input0]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v14:.*]] = mgmt.adjust_scale %[[v13]] {id = 0 : i64, mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v15:.*]] = mgmt.modreduce %[[v14]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v16:.*]] = mgmt.adjust_scale %[[v12]] {id = 1 : i64, mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> +// CHECK: %[[v17:.*]] = arith.addf %[[v16]], %[[v15]] {mgmt.mgmt = #mgmt.mgmt} : tensor<1x1024xf16> // CHECK: secret.yield %[[v17]] : tensor<1x1024xf16> diff --git a/tests/Transforms/secret_insert_mgmt/func_call.mlir b/tests/Transforms/secret_insert_mgmt/func_call.mlir index 7ad3c7c78b..77392b2750 100644 --- a/tests/Transforms/secret_insert_mgmt/func_call.mlir +++ b/tests/Transforms/secret_insert_mgmt/func_call.mlir @@ -1,12 +1,12 @@ // RUN: heir-opt --secret-insert-mgmt-ckks --split-input-file %s | FileCheck %s // CHECK: func.func private @extract_plaintext(f32) -> f32 -// CHECK: func.func @call_plaintext(%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) -// CHECK-NEXT: %[[v0:.*]] = secret.generic(%[[arg0]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { +// CHECK: func.func @call_plaintext(%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) +// CHECK-NEXT: %[[v0:.*]] = secret.generic(%[[arg0]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) { // CHECK-NEXT: ^body(%[[input0:.*]]: f32) -// CHECK-NEXT: %[[v1:.*]] = func.call @extract_plaintext(%[[input0]]) {mgmt.mgmt = #mgmt.mgmt} : (f32) -> f32 +// CHECK-NEXT: %[[v1:.*]] = func.call @extract_plaintext(%[[input0]]) {mgmt.mgmt = #mgmt.mgmt} : (f32) -> f32 // CHECK-NEXT: secret.yield %[[v1]] : f32 -// CHECK-NEXT: } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) +// CHECK-NEXT: } -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) // CHECK-NEXT: return %[[v0]] : !secret.secret module { func.func private @extract_plaintext(f32) -> f32 @@ -22,8 +22,8 @@ module { // ----- -// CHECK: func.func private @external_secret(!secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> !secret.secret -// CHECK: func.func @call_secret(%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) +// CHECK: func.func private @external_secret(!secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> !secret.secret +// CHECK: func.func @call_secret(%[[arg0:.*]]: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) // CHECK-NEXT: %[[v0:.*]] = call @external_secret(%[[arg0]]) // CHECK-NEXT: return %[[v0]] module { diff --git a/tests/Transforms/secret_insert_mgmt_bfv/doctest.mlir b/tests/Transforms/secret_insert_mgmt_bfv/doctest.mlir index 4587bef9cf..cfe10ceecd 100644 --- a/tests/Transforms/secret_insert_mgmt_bfv/doctest.mlir +++ b/tests/Transforms/secret_insert_mgmt_bfv/doctest.mlir @@ -3,9 +3,9 @@ // CHECK: func.func @func // CHECK: %[[GENERIC:.*]] = secret.generic(%{{.*}}, %{{.*}}) { // CHECK: ^body(%[[ARG0:.*]]: i16, %[[ARG1:.*]]: i16): -// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG0]], %[[ARG1]] {mgmt.mgmt = #mgmt.mgmt} : i16 -// CHECK-NEXT: %[[RELIN:.*]] = mgmt.relinearize %[[MUL]] {mgmt.mgmt = #mgmt.mgmt} : i16 -// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[RELIN]], %[[ARG1]] {mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG0]], %[[ARG1]] {mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[RELIN:.*]] = mgmt.relinearize %[[MUL]] {mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[RELIN]], %[[ARG1]] {mgmt.mgmt = #mgmt.mgmt} : i16 // CHECK-NEXT: secret.yield %[[ADD]] : i16 // CHECK: return func.func @func(%arg0: !secret.secret, %arg1: !secret.secret) -> !secret.secret { diff --git a/tests/Transforms/secret_insert_mgmt_bgv/doctest.mlir b/tests/Transforms/secret_insert_mgmt_bgv/doctest.mlir index 2d3b66a845..1e21216cd0 100644 --- a/tests/Transforms/secret_insert_mgmt_bgv/doctest.mlir +++ b/tests/Transforms/secret_insert_mgmt_bgv/doctest.mlir @@ -3,11 +3,11 @@ // CHECK: func.func @func // CHECK: %[[GENERIC:.*]] = secret.generic(%{{.*}}, %{{.*}}) { // CHECK: ^body(%[[ARG0:.*]]: i16, %[[ARG1:.*]]: i16): -// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG0]], %[[ARG1]] {mgmt.mgmt = #mgmt.mgmt} : i16 -// CHECK-NEXT: %[[RELIN:.*]] = mgmt.relinearize %[[MUL]] {mgmt.mgmt = #mgmt.mgmt} : i16 -// CHECK-NEXT: %[[ADJUST:.*]] = mgmt.adjust_scale %[[ARG1]] {id = 0 : i64, mgmt.mgmt = #mgmt.mgmt} : i16 -// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[RELIN]], %[[ADJUST]] {mgmt.mgmt = #mgmt.mgmt} : i16 -// CHECK-NEXT: %[[MODRED:.*]] = mgmt.modreduce %[[ADD]] {mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ARG0]], %[[ARG1]] {mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[RELIN:.*]] = mgmt.relinearize %[[MUL]] {mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[ADJUST:.*]] = mgmt.adjust_scale %[[ARG1]] {id = 0 : i64, mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[RELIN]], %[[ADJUST]] {mgmt.mgmt = #mgmt.mgmt} : i16 +// CHECK-NEXT: %[[MODRED:.*]] = mgmt.modreduce %[[ADD]] {mgmt.mgmt = #mgmt.mgmt} : i16 // CHECK-NEXT: secret.yield %[[MODRED]] : i16 // CHECK: return func.func @func(%arg0: !secret.secret, %arg1: !secret.secret) -> !secret.secret { diff --git a/tests/Transforms/validate_noise/validate_noise_fail.mlir b/tests/Transforms/validate_noise/validate_noise_fail.mlir index f97ae7e016..6e0a1170d1 100644 --- a/tests/Transforms/validate_noise/validate_noise_fail.mlir +++ b/tests/Transforms/validate_noise/validate_noise_fail.mlir @@ -14,26 +14,26 @@ // expected-error@below {{Noise validation failed.}} module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { - func.func @dot_product(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { - %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + func.func @dot_product(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: tensor<1024xi16>): - %1 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %2 = mgmt.relinearize %1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %3 = mgmt.modreduce %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %4 = arith.muli %3, %3 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %5 = mgmt.relinearize %4 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %6 = mgmt.modreduce %5 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %7 = arith.muli %6, %6 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %8 = mgmt.relinearize %7 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %9 = mgmt.modreduce %8 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %10 = arith.muli %9, %9 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %11 = mgmt.relinearize %10 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %12 = mgmt.modreduce %11 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %13 = arith.muli %12, %12 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %14 = mgmt.relinearize %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %15 = mgmt.modreduce %14 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %1 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %2 = mgmt.relinearize %1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %3 = mgmt.modreduce %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %4 = arith.muli %3, %3 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %5 = mgmt.relinearize %4 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %6 = mgmt.modreduce %5 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %7 = arith.muli %6, %6 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %8 = mgmt.relinearize %7 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %9 = mgmt.modreduce %8 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %10 = arith.muli %9, %9 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %11 = mgmt.relinearize %10 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %12 = mgmt.modreduce %11 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %13 = arith.muli %12, %12 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %14 = mgmt.relinearize %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %15 = mgmt.modreduce %14 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> secret.yield %15 : tensor<1024xi16> - } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret> } } diff --git a/tests/Transforms/validate_noise/validate_noise_pass.mlir b/tests/Transforms/validate_noise/validate_noise_pass.mlir index e3cd328491..64ea04490a 100644 --- a/tests/Transforms/validate_noise/validate_noise_pass.mlir +++ b/tests/Transforms/validate_noise/validate_noise_pass.mlir @@ -2,7 +2,7 @@ module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { // CHECK: @dot_product - func.func @dot_product(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + func.func @dot_product(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { %c7 = arith.constant 7 : index %c1_i16 = arith.constant 1 : i16 %cst = arith.constant dense<0> : tensor<1024xi16> @@ -10,23 +10,23 @@ module attributes {bgv.schemeParam = #bgv.scheme_param - %0 = mgmt.init %inserted {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %1 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + %0 = mgmt.init %inserted {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %1 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}, %arg1: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: tensor<1024xi16>, %input1: tensor<1024xi16>): - %2 = arith.muli %input0, %input1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %3 = mgmt.relinearize %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %4 = tensor_ext.rotate %3, %c4 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index - %5 = arith.addi %3, %4 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %6 = tensor_ext.rotate %5, %c2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index - %7 = arith.addi %5, %6 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %8 = tensor_ext.rotate %7, %c1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index - %9 = arith.addi %7, %8 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %10 = mgmt.modreduce %9 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %11 = arith.muli %0, %10 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %12 = tensor_ext.rotate %11, %c7 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index - %13 = mgmt.modreduce %12 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %2 = arith.muli %input0, %input1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %3 = mgmt.relinearize %2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %4 = tensor_ext.rotate %3, %c4 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index + %5 = arith.addi %3, %4 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %6 = tensor_ext.rotate %5, %c2 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index + %7 = arith.addi %5, %6 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %8 = tensor_ext.rotate %7, %c1 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index + %9 = arith.addi %7, %8 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %10 = mgmt.modreduce %9 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %11 = arith.muli %0, %10 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %12 = tensor_ext.rotate %11, %c7 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16>, index + %13 = mgmt.modreduce %12 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> secret.yield %13 : tensor<1024xi16> - } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) return %1 : !secret.secret> } } diff --git a/tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir b/tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir index 1a4e0fdfcf..7bce33c8b4 100644 --- a/tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir +++ b/tests/Transforms/validate_noise/validate_noise_preserve_user_param.mlir @@ -3,7 +3,7 @@ // CHECK: module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { // CHECK: @return - func.func @return(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + func.func @return(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { return %arg0 : !secret.secret> } } diff --git a/tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir b/tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir index 63f37b745c..e389fa15fe 100644 --- a/tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir +++ b/tests/Transforms/validate_noise/validate_noise_preserve_user_param_fail.mlir @@ -2,14 +2,14 @@ // expected-error@below {{'builtin.module' op The level in the scheme param is smaller than the max level.}} module attributes {bgv.schemeParam = #bgv.scheme_param} { - func.func @main(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { - %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + func.func @main(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: tensor<1024xi16>): - %13 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %14 = mgmt.relinearize %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %15 = mgmt.modreduce %14 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %13 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %14 = mgmt.relinearize %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %15 = mgmt.modreduce %14 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> secret.yield %15 : tensor<1024xi16> - } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret> } } @@ -18,14 +18,14 @@ module attributes {bgv.schemeParam = #bgv.scheme_param} { - func.func @main(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { - %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + func.func @main(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { + %0 = secret.generic(%arg0: !secret.secret> {mgmt.mgmt = #mgmt.mgmt}) { ^body(%input0: tensor<1024xi16>): - %13 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %14 = mgmt.relinearize %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> - %15 = mgmt.modreduce %14 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %13 = arith.muli %input0, %input0 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %14 = mgmt.relinearize %13 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> + %15 = mgmt.modreduce %14 {mgmt.mgmt = #mgmt.mgmt} : tensor<1024xi16> secret.yield %15 : tensor<1024xi16> - } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) + } -> (!secret.secret> {mgmt.mgmt = #mgmt.mgmt}) return %0 : !secret.secret> } }