Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 107 additions & 59 deletions lib/Analysis/ScaleAnalysis/ScaleAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "lib/Analysis/ScaleAnalysis/ScaleAnalysis.h"

#include <cassert>
#include <cmath>
#include <cstdint>
#include <functional>

Expand All @@ -14,6 +15,7 @@
#include "lib/Utils/APIntUtils.h"
#include "lib/Utils/AttributeUtils.h"
#include "lib/Utils/Utils.h"
#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
Expand All @@ -36,70 +38,110 @@ namespace heir {
// ScaleModel
//===----------------------------------------------------------------------===//

int64_t BGVScaleModel::evalMulScale(const bgv::LocalParam& param, int64_t lhs,
int64_t rhs) {
llvm::APInt BGVScaleModel::evalMulScale(const bgv::LocalParam& param,
const llvm::APInt& lhs,
const llvm::APInt& rhs) {
const auto* schemeParam = param.getSchemeParam();
auto t = schemeParam->getPlaintextModulus();
return lhs * rhs % t;
auto t = llvm::APInt(64, schemeParam->getPlaintextModulus());
// (lhs * rhs) % t
return (lhs * rhs).urem(t);
}

int64_t BGVScaleModel::evalMulScaleBackward(const bgv::LocalParam& param,
int64_t result, int64_t lhs) {
llvm::APInt BGVScaleModel::evalMulScaleBackward(const bgv::LocalParam& param,
const llvm::APInt& result,
const llvm::APInt& lhs) {
const auto* schemeParam = param.getSchemeParam();
auto t = schemeParam->getPlaintextModulus();
auto lhsInv = multiplicativeInverse(APInt(64, lhs), APInt(64, t));
return result * lhsInv.getSExtValue() % t;
auto t = llvm::APInt(64, schemeParam->getPlaintextModulus());
auto lhsInv = multiplicativeInverse(lhs, t);
// (result * lhsInv) % t
return (result * lhsInv).urem(t);
}

int64_t BGVScaleModel::evalModReduceScale(const bgv::LocalParam& inputParam,
int64_t scale) {
llvm::APInt BGVScaleModel::evalModReduceScale(const bgv::LocalParam& inputParam,
const llvm::APInt& scale) {
const auto* schemeParam = inputParam.getSchemeParam();
auto t = schemeParam->getPlaintextModulus();
auto t = llvm::APInt(64, schemeParam->getPlaintextModulus());
auto qi = schemeParam->getQi();
auto level = inputParam.getCurrentLevel();
auto qInvT = multiplicativeInverse(APInt(64, qi[level] % t), APInt(64, t));
return scale * qInvT.getSExtValue() % t;
auto qiModT = llvm::APInt(64, qi[level]).urem(t);
auto qInvT = multiplicativeInverse(qiModT, t);
// (scale * qInvT) % t
return (scale * qInvT).urem(t);
}

int64_t BGVScaleModel::evalModReduceScaleBackward(
const bgv::LocalParam& inputParam, int64_t resultScale) {
llvm::APInt BGVScaleModel::evalModReduceScaleBackward(
const bgv::LocalParam& inputParam, const llvm::APInt& resultScale) {
const auto* schemeParam = inputParam.getSchemeParam();
auto t = schemeParam->getPlaintextModulus();
auto t = llvm::APInt(64, schemeParam->getPlaintextModulus());
auto qi = schemeParam->getQi();
auto level = inputParam.getCurrentLevel();
return resultScale * (qi[level] % t) % t;
auto qiModT = llvm::APInt(64, qi[level]).urem(t);
// (resultScale * qiModT) % t
return (resultScale * qiModT).urem(t);
}

int64_t CKKSScaleModel::evalMulScale(const ckks::LocalParam& param, int64_t lhs,
int64_t rhs) {
// TODO(#1640): support high-precision scale management
return lhs + rhs;
llvm::APInt CKKSScaleModel::evalMulScale(const ckks::LocalParam& param,
const llvm::APInt& lhs,
const llvm::APInt& rhs) {
// High-precision scale management (#2364): multiply actual scales
// In CKKS, ct1 * ct2 with scales s1, s2 produces ciphertext with scale s1*s2
return lhs * rhs;
}

int64_t CKKSScaleModel::evalMulScaleBackward(const ckks::LocalParam& param,
int64_t result, int64_t lhs) {
// TODO(#1640): support high-precision scale management
return result - lhs;
llvm::APInt CKKSScaleModel::evalMulScaleBackward(const ckks::LocalParam& param,
const llvm::APInt& result,
const llvm::APInt& lhs) {
// High-precision scale management (#2364): divide actual scales
// If result = lhs * rhs, then rhs = result / lhs
// Handle uninitialized scales (zero) - can't infer from uninitialized values
if (lhs.isZero() || result.isZero()) {
return llvm::APInt(64, 0);
}
return result.udiv(lhs);
}

int64_t CKKSScaleModel::evalModReduceScale(const ckks::LocalParam& inputParam,
int64_t scale) {
llvm::APInt CKKSScaleModel::evalModReduceScale(
const ckks::LocalParam& inputParam, const llvm::APInt& scale) {
const auto* schemeParam = inputParam.getSchemeParam();
// TODO(#1640): rescale using logqi instead of logDefaultScale
// auto logqi = schemeParam->getLogqi();
// auto level = inputParam.getCurrentLevel();
auto logDefaultScale = schemeParam->getLogDefaultScale();
return scale - logDefaultScale;
// High-precision scale management (#2364): divide by actual qi value
// Rescaling divides the ciphertext modulus by qi[level] and the scale by
// qi[level]
auto qi = schemeParam->getQi();
auto level = inputParam.getCurrentLevel();

// If qi is not populated, fall back to using default scale
if (qi.empty() || level >= static_cast<int>(qi.size())) {
// Fallback: use 2^logDefaultScale as the rescaling factor
auto logDefaultScale = schemeParam->getLogDefaultScale();
llvm::APInt defaultScale(scale.getBitWidth(), 1);
defaultScale = defaultScale.shl(logDefaultScale);
return scale.udiv(defaultScale);
}

llvm::APInt qiVal(scale.getBitWidth(), qi[level]);
return scale.udiv(qiVal);
}

int64_t CKKSScaleModel::evalModReduceScaleBackward(
const ckks::LocalParam& inputParam, int64_t resultScale) {
llvm::APInt CKKSScaleModel::evalModReduceScaleBackward(
const ckks::LocalParam& inputParam, const llvm::APInt& resultScale) {
const auto* schemeParam = inputParam.getSchemeParam();
// TODO(#1640): rescale using logqi instead of logDefaultScale
// auto logqi = schemeParam->getLogqi();
// auto level = inputParam.getCurrentLevel();
auto logDefaultScale = schemeParam->getLogDefaultScale();
return resultScale + logDefaultScale;
// High-precision scale management (#2364): multiply by actual qi value
// Reverse of evalModReduceScale: if result = scale / qi, then scale = result
// * qi
auto qi = schemeParam->getQi();
auto level = inputParam.getCurrentLevel();

// If qi is not populated, fall back to using default scale
if (qi.empty() || level >= static_cast<int>(qi.size())) {
// Fallback: use 2^logDefaultScale as the rescaling factor
auto logDefaultScale = schemeParam->getLogDefaultScale();
llvm::APInt defaultScale(resultScale.getBitWidth(), 1);
defaultScale = defaultScale.shl(logDefaultScale);
return resultScale * defaultScale;
}

llvm::APInt qiVal(resultScale.getBitWidth(), qi[level]);
return resultScale * qiVal;
}

//===----------------------------------------------------------------------===//
Expand All @@ -126,7 +168,8 @@ LogicalResult ScaleAnalysis<ScaleModelT>::visitOperation(
propagateIfChanged(lattice, changed);
};

auto getOperandScales = [&](Operation* op, SmallVectorImpl<int64_t>& scales) {
auto getOperandScales = [&](Operation* op,
SmallVectorImpl<llvm::APInt>& scales) {
SmallVector<OpOperand*> secretOperands;
this->getSecretOperands(op, secretOperands);

Expand All @@ -147,7 +190,7 @@ LogicalResult ScaleAnalysis<ScaleModelT>::visitOperation(

llvm::TypeSwitch<Operation&>(*op)
.template Case<arith::MulIOp, arith::MulFOp>([&](auto mulOp) {
SmallVector<int64_t> scales;
SmallVector<llvm::APInt> scales;
getOperandScales(mulOp, scales);
// there must be at least one secret operand that has scale
if (scales.empty()) {
Expand All @@ -166,7 +209,7 @@ LogicalResult ScaleAnalysis<ScaleModelT>::visitOperation(
propagate(mulOp.getResult(), ScaleState(result));
})
.template Case<mgmt::ModReduceOp>([&](auto modReduceOp) {
SmallVector<int64_t> scales;
SmallVector<llvm::APInt> scales;
getOperandScales(modReduceOp, scales);
// there must be at least one secret operand that has scale
if (scales.empty()) {
Expand Down Expand Up @@ -200,7 +243,7 @@ LogicalResult ScaleAnalysis<ScaleModelT>::visitOperation(
return;
}

SmallVector<int64_t> scales;
SmallVector<llvm::APInt> scales;
getOperandScales(&op, scales);
if (scales.empty()) {
return;
Expand Down Expand Up @@ -267,7 +310,7 @@ LogicalResult ScaleAnalysisBackward<ScaleModelT>::visitOperation(

auto getOperandScales =
[&](Operation* op, SmallVectorImpl<int64_t>& operandWithoutScaleIndices,
SmallVectorImpl<int64_t>& scales) {
SmallVectorImpl<llvm::APInt>& scales) {
LLVM_DEBUG(llvm::dbgs()
<< "Operand scales for " << op->getName() << ": ");
SmallVector<OpOperand*> secretOperands;
Expand All @@ -294,7 +337,8 @@ LogicalResult ScaleAnalysisBackward<ScaleModelT>::visitOperation(
LLVM_DEBUG(llvm::dbgs() << "\n");
};

auto getResultScales = [&](Operation* op, SmallVectorImpl<int64_t>& scales) {
auto getResultScales = [&](Operation* op,
SmallVectorImpl<llvm::APInt>& scales) {
LLVM_DEBUG(llvm::dbgs() << "Result scales for " << op->getName() << ": ");
SmallVector<OpResult> secretResults;
this->getSecretResults(op, secretResults);
Expand All @@ -315,14 +359,14 @@ LogicalResult ScaleAnalysisBackward<ScaleModelT>::visitOperation(
<< "\n");
llvm::TypeSwitch<Operation&>(*op)
.template Case<arith::MulIOp, arith::MulFOp>([&](auto mulOp) {
SmallVector<int64_t> resultScales;
SmallVector<llvm::APInt> resultScales;
getResultScales(mulOp, resultScales);
// there must be at least one secret result that has scale
if (resultScales.empty()) {
return;
}
SmallVector<int64_t> operandWithoutScaleIndices;
SmallVector<int64_t> operandScales;
SmallVector<llvm::APInt> operandScales;
getOperandScales(mulOp, operandWithoutScaleIndices, operandScales);
// there must be at least one secret operand that has scale
if (operandScales.empty()) {
Expand All @@ -342,14 +386,14 @@ LogicalResult ScaleAnalysisBackward<ScaleModelT>::visitOperation(
ScaleState(scaleOther));
})
.template Case<mgmt::ModReduceOp>([&](auto modReduceOp) {
SmallVector<int64_t> resultScales;
SmallVector<llvm::APInt> resultScales;
getResultScales(modReduceOp, resultScales);
// there must be at least one secret result that has scale
if (resultScales.empty()) {
return;
}
SmallVector<int64_t> operandWithoutScaleIndices;
SmallVector<int64_t> scales;
SmallVector<llvm::APInt> scales;
getOperandScales(modReduceOp, operandWithoutScaleIndices, scales);
// if all operands have scale, succeed.
if (!scales.empty()) {
Expand All @@ -376,7 +420,7 @@ LogicalResult ScaleAnalysisBackward<ScaleModelT>::visitOperation(
return;
}

SmallVector<int64_t> scales;
SmallVector<llvm::APInt> scales;
getResultScales(&op, scales);
if (scales.empty()) {
return;
Expand All @@ -399,37 +443,41 @@ template class ScaleAnalysisBackward<CKKSScaleModel>;
// Utils
//===----------------------------------------------------------------------===//

int64_t getScale(Value value, DataFlowSolver* solver) {
llvm::APInt getScale(Value value, DataFlowSolver* solver) {
auto* lattice = solver->lookupState<ScaleLattice>(value);
if (!lattice) {
assert(false && "ScaleLattice not found");
return 0;
return llvm::APInt(64, 0);
}
if (!lattice->getValue().isInitialized()) {
assert(false && "ScaleLattice not initialized");
return 0;
return llvm::APInt(64, 0);
}
return lattice->getValue().getScale();
}

int64_t getScaleFromMgmtAttr(Value value) {
llvm::APInt getScaleFromMgmtAttr(Value value) {
auto mgmtAttr = mgmt::findMgmtAttrAssociatedWith(value);
if (!mgmtAttr) {
assert(false && "MgmtAttr not found");
return 0;
return llvm::APInt(64, 0);
}
// High-precision scale management (#2364): MgmtAttr now stores APInt directly
return mgmtAttr.getScale();
}

void annotateScale(Operation* top, DataFlowSolver* solver) {
auto getIntegerAttr = [&](int scale) {
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), scale);
auto getStringAttr = [&](const llvm::APInt& scale) {
// Store APInt as a string in base 10 for full precision
llvm::SmallString<64> str;
scale.toString(str, 10, /*Signed=*/false);
return StringAttr::get(top->getContext(), str);
};

walkValues(top, [&](Value value) {
if (mgmt::shouldHaveMgmtAttribute(value, solver)) {
setAttributeAssociatedWith(value, kArgScaleAttrName,
getIntegerAttr(getScale(value, solver)));
getStringAttr(getScale(value, solver)));
}
});
}
Expand Down
Loading
Loading