diff --git a/lib/Parameters/CKKS/Params.cpp b/lib/Parameters/CKKS/Params.cpp index 4363c7aa8a..afcc63c9e6 100644 --- a/lib/Parameters/CKKS/Params.cpp +++ b/lib/Parameters/CKKS/Params.cpp @@ -5,20 +5,230 @@ #include #include "lib/Parameters/RLWEParams.h" +#include "lib/Parameters/RLWESecurityParams.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "src/core/include/openfhecore.h" // from @openfhe namespace mlir { namespace heir { namespace ckks { +/// By original we mean the method in RNS-CKKS implementation +/// Corresponds to FIXED* in OpenFHE +static std::vector moduliQGenerationOpenFHEFixed(int logFirstMod, + int logDefaultScale, + int numLevel, + int ringDim) { + auto cyclOrder = ringDim * 2; + std::vector moduliQ(numLevel); + lbcrypto::NativeInteger q = + lbcrypto::FirstPrime(logDefaultScale, cyclOrder); + moduliQ[numLevel - 1] = q.ConvertToInt(); + + auto maxPrime{q}; + auto minPrime{q}; + + auto qPrev = q; + auto qNext = q; + if (numLevel > 2) { + for (size_t i = numLevel - 2, cnt = 0; i >= 1; --i, ++cnt) { + if ((cnt % 2) == 0) { + qPrev = PreviousPrime(qPrev, cyclOrder); + moduliQ[i] = qPrev.ConvertToInt(); + } else { + qNext = NextPrime(qNext, cyclOrder); + moduliQ[i] = qNext.ConvertToInt(); + } + + if (moduliQ[i] > maxPrime) + maxPrime = moduliQ[i]; + else if (moduliQ[i] < minPrime) + minPrime = moduliQ[i]; + } + } + + if (logFirstMod == logDefaultScale) { // this requires dcrtBits < 60 + moduliQ[0] = + lbcrypto::NextPrime(maxPrime, cyclOrder).ConvertToInt(); + } else { + moduliQ[0] = lbcrypto::LastPrime(logFirstMod, cyclOrder) + .ConvertToInt(); + + // find if the value of moduliQ[0] is already in the vector starting with + // moduliQ[1] and if there is, then get another prime for moduliQ[0] + const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]); + if (pos != moduliQ.end()) { + moduliQ[0] = lbcrypto::NextPrime(maxPrime, cyclOrder) + .ConvertToInt(); + } + } + return moduliQ; +} + +/// See "Reduced Error" paper https://eprint.iacr.org/2020/1118 +/// Corresponds to FLEXIBLE* in OpenFHE +static std::vector moduliQGenerationReducedError(int logFirstMod, + int logDefaultScale, + int numLevel, + int ringDim) { + auto cyclOrder = ringDim * 2; + std::vector moduliQ(numLevel); + lbcrypto::NativeInteger q = + lbcrypto::FirstPrime(logDefaultScale, cyclOrder); + moduliQ[numLevel - 1] = q.ConvertToInt(); + + auto maxPrime{q}; + auto minPrime{q}; + + if (numLevel > 2) { + for (size_t i = numLevel - 2, cnt = 0; i >= 1; --i, ++cnt) { + // Comments from OpenFHE ckksrns-parametergeneration.cpp + /* Scaling factors in FLEXIBLEAUTO are a bit fragile, + * in the sense that once one scaling factor gets far enough from the + * original scaling factor, subsequent level scaling factors quickly + * diverge to either 0 or infinity. To mitigate this problem to a certain + * extend, we have a special prime selection process in place. The goal is + * to maintain the scaling factor of all levels as close to the original + * scale factor of level 0 as possible. + */ + double sf = static_cast(moduliQ[numLevel - 1]); + for (size_t i = numLevel - 2, cnt = 0; i >= 1; --i, ++cnt) { + sf = pow(sf, 2) / static_cast(moduliQ[i + 1]); + NativeInteger sfInt = std::llround(sf); + NativeInteger sfRem = sfInt.Mod(cyclOrder); + bool hasSameMod = true; + if ((cnt % 2) == 0) { + NativeInteger qPrev = + sfInt - NativeInteger(cyclOrder) - sfRem + NativeInteger(1); + while (hasSameMod) { + hasSameMod = false; + qPrev = lbcrypto::PreviousPrime(qPrev, cyclOrder); + for (size_t j = i + 1; j < numLevel; j++) { + if (qPrev == moduliQ[j]) { + hasSameMod = true; + break; + } + } + } + moduliQ[i] = qPrev.ConvertToInt(); + } else { + NativeInteger qNext = + sfInt + NativeInteger(cyclOrder) - sfRem + NativeInteger(1); + while (hasSameMod) { + hasSameMod = false; + qNext = lbcrypto::NextPrime(qNext, cyclOrder); + for (size_t j = i + 1; j < numLevel; j++) { + if (qNext == moduliQ[j]) { + hasSameMod = true; + break; + } + } + } + moduliQ[i] = qNext.ConvertToInt(); + } + if (moduliQ[i] > maxPrime) + maxPrime = moduliQ[i]; + else if (moduliQ[i] < minPrime) + minPrime = moduliQ[i]; + } + } + } + + if (logFirstMod == logDefaultScale) { // this requires dcrtBits < 60 + moduliQ[0] = + lbcrypto::NextPrime(maxPrime, cyclOrder) + .ConvertToInt(); + } else { + moduliQ[0] = + lbcrypto::LastPrime(logFirstMod, cyclOrder) + .ConvertToInt(); + + // find if the value of moduliQ[0] is already in the vector starting with + // moduliQ[1] and if there is, then get another prime for moduliQ[0] + const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]); + if (pos != moduliQ.end()) { + moduliQ[0] = + lbcrypto::NextPrime(maxPrime, cyclOrder) + .ConvertToInt(); + } + } + return moduliQ; +} + +// numScaleMod is L SchemeParam SchemeParam::getConcreteSchemeParam( - std::vector logqi, int logDefaultScale, int slotNumber, - bool usePublicKey, bool encryptionTechniqueExtended) { + int logFirstMod, int logDefaultScale, int numScaleMod, int slotNumber, + bool usePublicKey, bool encryptionTechniqueExtended, bool reducedError) { // CKKS slot number = ringDim / 2 - return SchemeParam(RLWESchemeParam::getConcreteRLWESchemeParam( - std::move(logqi), 2 * slotNumber, usePublicKey, - encryptionTechniqueExtended), - logDefaultScale); + auto minRingDim = 2 * slotNumber; + + auto dnum = computeDnum(numScaleMod); + + // q0 + (q1 + ... + qL) = firstModBits + scalingModBits * L + double logQ = logFirstMod + logDefaultScale * numScaleMod; + // pi can be large + auto sizePi = 60; + + // make P > Q / dnum + auto logP = ceil(logQ / dnum); + auto numPi = ceil(logP / sizePi); + // update logP + logP = numPi * sizePi; + + auto logPQ = logQ + logP; + + // ringDim will change if newLogPQ is too large + auto ringDim = computeRingDim(logPQ, minRingDim); + bool redo = false; + std::vector qiImpl; + std::vector piImpl; + do { + redo = false; + qiImpl.clear(); + piImpl.clear(); + + if (reducedError) { + qiImpl = moduliQGenerationReducedError(logFirstMod, logDefaultScale, + numScaleMod + 1, ringDim); + } else { + qiImpl = moduliQGenerationOpenFHEFixed(logFirstMod, logDefaultScale, + numScaleMod + 1, ringDim); + } + std::vector existingPrimes = qiImpl; + for (size_t i = 0; i < numPi; ++i) { + auto prime = findPrime(sizePi, ringDim, existingPrimes); + piImpl.push_back(prime); + existingPrimes.push_back(prime); + } + + // if generated primes are too large, increase ringDim + double newLogPQ = 0; + for (auto qi : qiImpl) { + newLogPQ += log2(qi); + } + for (auto pi : piImpl) { + newLogPQ += log2(pi); + } + auto newRingDim = computeRingDim(newLogPQ, minRingDim); + if (newRingDim != ringDim) { + ringDim = newRingDim; + redo = true; + } + } while (redo); + + std::vector logqi; + std::vector logpi; + for (auto qi : qiImpl) { + logqi.push_back(log2(qi)); + } + for (auto pi : piImpl) { + logpi.push_back(log2(pi)); + } + + return SchemeParam( + RLWESchemeParam(ringDim, numScaleMod + 1, logqi, qiImpl, dnum, logpi, + piImpl, usePublicKey, encryptionTechniqueExtended), + logDefaultScale); } SchemeParam SchemeParam::getSchemeParamFromAttr(SchemeParamAttr attr) { diff --git a/lib/Parameters/CKKS/Params.h b/lib/Parameters/CKKS/Params.h index ae7eac75be..89780bf095 100644 --- a/lib/Parameters/CKKS/Params.h +++ b/lib/Parameters/CKKS/Params.h @@ -26,10 +26,9 @@ class SchemeParam : public RLWESchemeParam { int64_t getLogDefaultScale() const { return logDefaultScale; } void print(llvm::raw_ostream& os) const override; - static SchemeParam getConcreteSchemeParam(std::vector logqi, - int logDefaultScale, int slotNumber, - bool usePublicKey, - bool encryptionTechniqueExtended); + static SchemeParam getConcreteSchemeParam( + int logFirstMod, int logDefaultScale, int numScaleMod, int slotNumber, + bool usePublicKey, bool encryptionTechniqueExtended, bool reducedError); static SchemeParam getSchemeParamFromAttr(SchemeParamAttr attr); }; diff --git a/lib/Parameters/RLWEParams.h b/lib/Parameters/RLWEParams.h index d0ecb7fde9..b0c3def0a0 100644 --- a/lib/Parameters/RLWEParams.h +++ b/lib/Parameters/RLWEParams.h @@ -131,6 +131,14 @@ class RLWELocalParam { int getDimension() const { return dimension; } }; +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +int computeDnum(int level); +int64_t findPrime(int qi, int ringDim, + const std::vector& existingPrimes); + } // namespace heir } // namespace mlir diff --git a/lib/Transforms/GenerateParam/GenerateParam.td b/lib/Transforms/GenerateParam/GenerateParam.td index 1e67c38e35..db102a39b8 100644 --- a/lib/Transforms/GenerateParam/GenerateParam.td +++ b/lib/Transforms/GenerateParam/GenerateParam.td @@ -144,9 +144,19 @@ def GenerateParamCKKS : Pass<"generate-param-ckks"> { the ciphertext level/dimension. These ops and attributes can be added by a pass like `--secret-insert-mgmt-` and `--annotate-mgmt`. - User can provide custom scheme parameters by annotating bgv::SchemeParamAttr + User can provide custom scheme parameters by annotating ckks::SchemeParamAttr at the module level. + There are two prime selection implementations available: + 1. Fixed (from OpenFHE FIXED*) + 2. Reduced Error (from https://eprint.iacr.org/2020/1118, OpenFHE FLEXIBLE*) + + There is a toggle called `reduced-error` that can choose between them. + The default one is "Fixed". + + Reduced Error implementation works better with level-specific scaling + factor. + (* example filepath=tests/Transforms/generate_param_ckks/doctest.mlir *) }]; @@ -170,6 +180,9 @@ def GenerateParamCKKS : Pass<"generate-param-ckks"> { Option<"inputRange", "input-range", "int", /*default=*/"1", "The range of the plaintexts for input ciphertexts " "for the CKKS scheme; default to [-1, 1]. For other ranges like [-D, D], use D.">, + Option<"reducedError", "reduced-error", "bool", + /*default=*/"false", "If true, uses the prime selection logic in Reduced Error paper " + "(https://eprint.iacr.org/2020/1118).">, ]; } diff --git a/lib/Transforms/GenerateParam/GenerateParamCKKS.cpp b/lib/Transforms/GenerateParam/GenerateParamCKKS.cpp index 5663b1121d..29f6870e8d 100644 --- a/lib/Transforms/GenerateParam/GenerateParamCKKS.cpp +++ b/lib/Transforms/GenerateParam/GenerateParamCKKS.cpp @@ -122,13 +122,9 @@ struct GenerateParamCKKS : impl::GenerateParamCKKSBase { encryptionTechniqueExtended = true; } - // generate scheme parameters - std::vector logPrimes(maxLevel.value_or(0) + 1, scalingModBits); - logPrimes[0] = firstModBits; - auto schemeParam = ckks::SchemeParam::getConcreteSchemeParam( - logPrimes, scalingModBits, slotNumber, usePublicKey, - encryptionTechniqueExtended); + firstModBits, scalingModBits, maxLevel.value_or(0), slotNumber, + usePublicKey, encryptionTechniqueExtended, reducedError); LLVM_DEBUG(llvm::dbgs() << "Scheme Param:\n" << schemeParam << "\n"); diff --git a/tests/Transforms/generate_param_ckks/doctest.mlir b/tests/Transforms/generate_param_ckks/doctest.mlir index 163a56a24b..6ca5861850 100644 --- a/tests/Transforms/generate_param_ckks/doctest.mlir +++ b/tests/Transforms/generate_param_ckks/doctest.mlir @@ -1,6 +1,6 @@ // RUN: heir-opt --generate-param-ckks %s | FileCheck %s -// CHECK: {ckks.schemeParam = #ckks.scheme_param} +// CHECK: {ckks.schemeParam = #ckks.scheme_param} module { func.func @add(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) -> (!secret.secret {mgmt.mgmt = #mgmt.mgmt}) { %0 = secret.generic(%arg0: !secret.secret {mgmt.mgmt = #mgmt.mgmt}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt}} {