Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 216 additions & 6 deletions lib/Parameters/CKKS/Params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,230 @@
#include <vector>

#include "lib/Parameters/RLWEParams.h"
#include "lib/Parameters/RLWESecurityParams.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "src/core/include/openfhecore.h" // from @openfhe

namespace mlir {
namespace heir {
namespace ckks {

/// By original we mean the method in RNS-CKKS implementation
/// Corresponds to FIXED* in OpenFHE
static std::vector<int64_t> moduliQGenerationOpenFHEFixed(int logFirstMod,
int logDefaultScale,
int numLevel,
int ringDim) {
auto cyclOrder = ringDim * 2;
std::vector<int64_t> moduliQ(numLevel);
lbcrypto::NativeInteger q =
lbcrypto::FirstPrime<NativeInteger>(logDefaultScale, cyclOrder);
moduliQ[numLevel - 1] = q.ConvertToInt();

auto maxPrime{q};
auto minPrime{q};

auto qPrev = q;
auto qNext = q;
if (numLevel > 2) {
for (size_t i = numLevel - 2, cnt = 0; i >= 1; --i, ++cnt) {
if ((cnt % 2) == 0) {
qPrev = PreviousPrime(qPrev, cyclOrder);
moduliQ[i] = qPrev.ConvertToInt();
} else {
qNext = NextPrime(qNext, cyclOrder);
moduliQ[i] = qNext.ConvertToInt();
}

if (moduliQ[i] > maxPrime)
maxPrime = moduliQ[i];
else if (moduliQ[i] < minPrime)
minPrime = moduliQ[i];
}
}

if (logFirstMod == logDefaultScale) { // this requires dcrtBits < 60
moduliQ[0] =
lbcrypto::NextPrime<NativeInteger>(maxPrime, cyclOrder).ConvertToInt();
} else {
moduliQ[0] = lbcrypto::LastPrime<NativeInteger>(logFirstMod, cyclOrder)
.ConvertToInt();

// find if the value of moduliQ[0] is already in the vector starting with
// moduliQ[1] and if there is, then get another prime for moduliQ[0]
const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]);
if (pos != moduliQ.end()) {
moduliQ[0] = lbcrypto::NextPrime<NativeInteger>(maxPrime, cyclOrder)
.ConvertToInt();
}
}
return moduliQ;
}

/// See "Reduced Error" paper https://eprint.iacr.org/2020/1118
/// Corresponds to FLEXIBLE* in OpenFHE
static std::vector<int64_t> moduliQGenerationReducedError(int logFirstMod,
int logDefaultScale,
int numLevel,
int ringDim) {
auto cyclOrder = ringDim * 2;
std::vector<int64_t> moduliQ(numLevel);
lbcrypto::NativeInteger q =
lbcrypto::FirstPrime<lbcrypto::NativeInteger>(logDefaultScale, cyclOrder);
moduliQ[numLevel - 1] = q.ConvertToInt();

auto maxPrime{q};
auto minPrime{q};

if (numLevel > 2) {
for (size_t i = numLevel - 2, cnt = 0; i >= 1; --i, ++cnt) {
// Comments from OpenFHE ckksrns-parametergeneration.cpp
/* Scaling factors in FLEXIBLEAUTO are a bit fragile,
* in the sense that once one scaling factor gets far enough from the
* original scaling factor, subsequent level scaling factors quickly
* diverge to either 0 or infinity. To mitigate this problem to a certain
* extend, we have a special prime selection process in place. The goal is
* to maintain the scaling factor of all levels as close to the original
* scale factor of level 0 as possible.
*/
double sf = static_cast<double>(moduliQ[numLevel - 1]);
for (size_t i = numLevel - 2, cnt = 0; i >= 1; --i, ++cnt) {
sf = pow(sf, 2) / static_cast<double>(moduliQ[i + 1]);
NativeInteger sfInt = std::llround(sf);
NativeInteger sfRem = sfInt.Mod(cyclOrder);
bool hasSameMod = true;
if ((cnt % 2) == 0) {
NativeInteger qPrev =
sfInt - NativeInteger(cyclOrder) - sfRem + NativeInteger(1);
while (hasSameMod) {
hasSameMod = false;
qPrev = lbcrypto::PreviousPrime(qPrev, cyclOrder);
for (size_t j = i + 1; j < numLevel; j++) {
if (qPrev == moduliQ[j]) {
hasSameMod = true;
break;
}
}
}
moduliQ[i] = qPrev.ConvertToInt();
} else {
NativeInteger qNext =
sfInt + NativeInteger(cyclOrder) - sfRem + NativeInteger(1);
while (hasSameMod) {
hasSameMod = false;
qNext = lbcrypto::NextPrime(qNext, cyclOrder);
for (size_t j = i + 1; j < numLevel; j++) {
if (qNext == moduliQ[j]) {
hasSameMod = true;
break;
}
}
}
moduliQ[i] = qNext.ConvertToInt();
}
if (moduliQ[i] > maxPrime)
maxPrime = moduliQ[i];
else if (moduliQ[i] < minPrime)
minPrime = moduliQ[i];
}
}
}

if (logFirstMod == logDefaultScale) { // this requires dcrtBits < 60
moduliQ[0] =
lbcrypto::NextPrime<lbcrypto::NativeInteger>(maxPrime, cyclOrder)
.ConvertToInt();
} else {
moduliQ[0] =
lbcrypto::LastPrime<lbcrypto::NativeInteger>(logFirstMod, cyclOrder)
.ConvertToInt();

// find if the value of moduliQ[0] is already in the vector starting with
// moduliQ[1] and if there is, then get another prime for moduliQ[0]
const auto pos = std::find(moduliQ.begin() + 1, moduliQ.end(), moduliQ[0]);
if (pos != moduliQ.end()) {
moduliQ[0] =
lbcrypto::NextPrime<lbcrypto::NativeInteger>(maxPrime, cyclOrder)
.ConvertToInt();
}
}
return moduliQ;
}

// numScaleMod is L
SchemeParam SchemeParam::getConcreteSchemeParam(
std::vector<double> logqi, int logDefaultScale, int slotNumber,
bool usePublicKey, bool encryptionTechniqueExtended) {
int logFirstMod, int logDefaultScale, int numScaleMod, int slotNumber,
bool usePublicKey, bool encryptionTechniqueExtended, bool reducedError) {
// CKKS slot number = ringDim / 2
return SchemeParam(RLWESchemeParam::getConcreteRLWESchemeParam(
std::move(logqi), 2 * slotNumber, usePublicKey,
encryptionTechniqueExtended),
logDefaultScale);
auto minRingDim = 2 * slotNumber;

auto dnum = computeDnum(numScaleMod);

// q0 + (q1 + ... + qL) = firstModBits + scalingModBits * L
double logQ = logFirstMod + logDefaultScale * numScaleMod;
// pi can be large
auto sizePi = 60;

// make P > Q / dnum
auto logP = ceil(logQ / dnum);
auto numPi = ceil(logP / sizePi);
// update logP
logP = numPi * sizePi;

auto logPQ = logQ + logP;

// ringDim will change if newLogPQ is too large
auto ringDim = computeRingDim(logPQ, minRingDim);
bool redo = false;
std::vector<int64_t> qiImpl;
std::vector<int64_t> piImpl;
do {
redo = false;
qiImpl.clear();
piImpl.clear();

if (reducedError) {
qiImpl = moduliQGenerationReducedError(logFirstMod, logDefaultScale,
numScaleMod + 1, ringDim);
} else {
qiImpl = moduliQGenerationOpenFHEFixed(logFirstMod, logDefaultScale,
numScaleMod + 1, ringDim);
}
std::vector<int64_t> existingPrimes = qiImpl;
for (size_t i = 0; i < numPi; ++i) {
auto prime = findPrime(sizePi, ringDim, existingPrimes);
piImpl.push_back(prime);
existingPrimes.push_back(prime);
}

// if generated primes are too large, increase ringDim
double newLogPQ = 0;
for (auto qi : qiImpl) {
newLogPQ += log2(qi);
}
for (auto pi : piImpl) {
newLogPQ += log2(pi);
}
auto newRingDim = computeRingDim(newLogPQ, minRingDim);
if (newRingDim != ringDim) {
ringDim = newRingDim;
redo = true;
}
} while (redo);

std::vector<double> logqi;
std::vector<double> logpi;
for (auto qi : qiImpl) {
logqi.push_back(log2(qi));
}
for (auto pi : piImpl) {
logpi.push_back(log2(pi));
}

return SchemeParam(
RLWESchemeParam(ringDim, numScaleMod + 1, logqi, qiImpl, dnum, logpi,
piImpl, usePublicKey, encryptionTechniqueExtended),
logDefaultScale);
}

SchemeParam SchemeParam::getSchemeParamFromAttr(SchemeParamAttr attr) {
Expand Down
7 changes: 3 additions & 4 deletions lib/Parameters/CKKS/Params.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ class SchemeParam : public RLWESchemeParam {
int64_t getLogDefaultScale() const { return logDefaultScale; }
void print(llvm::raw_ostream& os) const override;

static SchemeParam getConcreteSchemeParam(std::vector<double> logqi,
int logDefaultScale, int slotNumber,
bool usePublicKey,
bool encryptionTechniqueExtended);
static SchemeParam getConcreteSchemeParam(
int logFirstMod, int logDefaultScale, int numScaleMod, int slotNumber,
bool usePublicKey, bool encryptionTechniqueExtended, bool reducedError);

static SchemeParam getSchemeParamFromAttr(SchemeParamAttr attr);
};
Expand Down
8 changes: 8 additions & 0 deletions lib/Parameters/RLWEParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ class RLWELocalParam {
int getDimension() const { return dimension; }
};

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

int computeDnum(int level);
int64_t findPrime(int qi, int ringDim,
const std::vector<int64_t>& existingPrimes);

} // namespace heir
} // namespace mlir

Expand Down
15 changes: 14 additions & 1 deletion lib/Transforms/GenerateParam/GenerateParam.td
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,19 @@ def GenerateParamCKKS : Pass<"generate-param-ckks"> {
the ciphertext level/dimension. These ops and attributes can be added by
a pass like `--secret-insert-mgmt-<scheme>` and `--annotate-mgmt`.

User can provide custom scheme parameters by annotating bgv::SchemeParamAttr
User can provide custom scheme parameters by annotating ckks::SchemeParamAttr
at the module level.

There are two prime selection implementations available:
1. Fixed (from OpenFHE FIXED*)
2. Reduced Error (from https://eprint.iacr.org/2020/1118, OpenFHE FLEXIBLE*)

There is a toggle called `reduced-error` that can choose between them.
The default one is "Fixed".

Reduced Error implementation works better with level-specific scaling
factor.

(* example filepath=tests/Transforms/generate_param_ckks/doctest.mlir *)
}];

Expand All @@ -170,6 +180,9 @@ def GenerateParamCKKS : Pass<"generate-param-ckks"> {
Option<"inputRange", "input-range", "int",
/*default=*/"1", "The range of the plaintexts for input ciphertexts "
"for the CKKS scheme; default to [-1, 1]. For other ranges like [-D, D], use D.">,
Option<"reducedError", "reduced-error", "bool",
/*default=*/"false", "If true, uses the prime selection logic in Reduced Error paper "
"(https://eprint.iacr.org/2020/1118).">,
];
}

Expand Down
8 changes: 2 additions & 6 deletions lib/Transforms/GenerateParam/GenerateParamCKKS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,9 @@ struct GenerateParamCKKS : impl::GenerateParamCKKSBase<GenerateParamCKKS> {
encryptionTechniqueExtended = true;
}

// generate scheme parameters
std::vector<double> logPrimes(maxLevel.value_or(0) + 1, scalingModBits);
logPrimes[0] = firstModBits;

auto schemeParam = ckks::SchemeParam::getConcreteSchemeParam(
logPrimes, scalingModBits, slotNumber, usePublicKey,
encryptionTechniqueExtended);
firstModBits, scalingModBits, maxLevel.value_or(0), slotNumber,
usePublicKey, encryptionTechniqueExtended, reducedError);

LLVM_DEBUG(llvm::dbgs() << "Scheme Param:\n" << schemeParam << "\n");

Expand Down
2 changes: 1 addition & 1 deletion tests/Transforms/generate_param_ckks/doctest.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: heir-opt --generate-param-ckks %s | FileCheck %s

// CHECK: {ckks.schemeParam = #ckks.scheme_param<logN = 13, Q = [36028797019389953], P = [36028797019488257], logDefaultScale = 45>}
// CHECK: {ckks.schemeParam = #ckks.scheme_param<logN = 13, Q = [36028797018652673], P = [1152921504606994433], logDefaultScale = 45>}
module {
func.func @add(%arg0: !secret.secret<f16> {mgmt.mgmt = #mgmt.mgmt<level = 0>}) -> (!secret.secret<f16> {mgmt.mgmt = #mgmt.mgmt<level = 0>}) {
%0 = secret.generic(%arg0: !secret.secret<f16> {mgmt.mgmt = #mgmt.mgmt<level = 0>}) attrs = {arg0 = {mgmt.mgmt = #mgmt.mgmt<level = 0>}} {
Expand Down
Loading