Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def GenParamsOp : JaxiteWord_Op<"gen_params"> {
- numEvalMult: Number of evaluation multiplications
}];
let arguments = (ins
JaxiteWord_PublicKey:$publicKey,
JaxiteWord_PrivateKey:$secretKey,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the server need to generate parameters independently of any particular user? Taking a public/private/eval key as input seems like it would make it impossible to have a user-independent FHE program.

And for that matter, wouldn't you need to decide on all these parameters before generating key material? Surely the evaluation key depends on the data pre-computed by the crypto context at initialization...

JaxiteWord_EvalKey:$evaluationKey,
// Scheme parameters
I64Attr:$degree,
I64Attr:$numSlots,
Expand Down Expand Up @@ -134,7 +137,6 @@ def ProgramInitializationOp : JaxiteWord_Op<"program_initialization"> {
}];
let arguments = (ins
JaxiteWord_CryptoContext:$cryptoContext,
JaxiteWord_PrivateKey:$secretKey,
I64Attr:$totalHemulLevels,
DenseI64ArrayAttr:$totalRotationIndices,
I32Attr:$dnum,
Expand Down
19 changes: 13 additions & 6 deletions lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,24 @@ struct ConfigureCryptoContext
LogicalResult generateGenFunc(func::FuncOp op, const std::string& genFuncName,
ImplicitLocOpBuilder& builder) {
Type ccType = CryptoContextType::get(builder.getContext());
Type pkType = PublicKeyType::get(builder.getContext());
Type skType = PrivateKeyType::get(builder.getContext());
Type ekType = EvalKeyType::get(builder.getContext());

SmallVector<Type> funcArgTypes = {pkType, skType, ekType};
SmallVector<Type> funcResultTypes = {ccType};

FunctionType genFuncType =
FunctionType::get(builder.getContext(), {}, funcResultTypes);
FunctionType::get(builder.getContext(), funcArgTypes, funcResultTypes);
auto genFuncOp = func::FuncOp::create(builder, genFuncName, genFuncType);
builder.setInsertionPointToEnd(genFuncOp.addEntryBlock());

Value publicKey = genFuncOp.getArgument(0);
Value secretKey = genFuncOp.getArgument(1);
Value evaluationKey = genFuncOp.getArgument(2);

Value cryptoContext = GenParamsOp::create(
builder, ccType,
builder, ccType, publicKey, secretKey, evaluationKey,
/*degree=*/static_cast<uint64_t>(config.degree),
/*numSlots=*/static_cast<uint64_t>(config.numSlots),
/*scalingFactor=*/llvm::APFloat(config.scalingFactor),
Expand All @@ -119,9 +128,8 @@ struct ConfigureCryptoContext
const std::string& configFuncName,
ImplicitLocOpBuilder& builder) {
Type ccType = CryptoContextType::get(builder.getContext());
Type skType = PrivateKeyType::get(builder.getContext());

SmallVector<Type> funcArgTypes = {ccType, skType};
SmallVector<Type> funcArgTypes = {ccType};
SmallVector<Type> funcResultTypes;

FunctionType configFuncType =
Expand All @@ -131,10 +139,9 @@ struct ConfigureCryptoContext
builder.setInsertionPointToEnd(configFuncOp.addEntryBlock());

Value cryptoContext = configFuncOp.getArgument(0);
Value secretKey = configFuncOp.getArgument(1);

ProgramInitializationOp::create(
builder, cryptoContext, secretKey,
builder, cryptoContext,
/*totalHemulLevels=*/static_cast<int64_t>(config.mulDepth),
/*totalRotationIndices=*/config.rotIndices,
/*dnum=*/config.dnum,
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/JaxiteWord/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ def ConfigureCryptoContext : Pass<"jaxiteword-configure-crypto-context"> {
For example, for an MLIR function `@my_func`, the generated helpers have
the following signatures:
```mlir
func.func @my_func__generate_crypto_context() -> !jaxiteword.crypto_context
func.func @my_func__generate_crypto_context(
!jaxiteword.public_key, !jaxiteword.private_key,
!jaxiteword.eval_key) -> !jaxiteword.crypto_context

func.func @my_func__configure_crypto_context(
!jaxiteword.crypto_context, !jaxiteword.private_key)
!jaxiteword.crypto_context)
```
}];
let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect"];
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ cc_library(
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncTransforms",
"@llvm-project//mlir:IR",
Expand Down
160 changes: 138 additions & 22 deletions lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Utils/ConversionUtils.h"
#include "lib/Utils/Utils.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand Down Expand Up @@ -53,6 +54,60 @@ class JaxiteWordTypeConverter : public TypeConverter {

namespace {

bool containsCryptoArgument(func::FuncOp funcOp) {
return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) {
return DialectEqual<lwe::LWEDialect, ckks::CKKSDialect, bgv::BGVDialect>()(
&getElementTypeOrSelf(argType).getDialect());
});
}

bool funcNeedsCryptoContextAndKeys(func::FuncOp funcOp) {
return containsDialects<lwe::LWEDialect, ckks::CKKSDialect, bgv::BGVDialect>(
funcOp) ||
containsCryptoArgument(funcOp);
}

void insertCryptoContextAndKeys(func::FuncOp funcOp) {
if (!funcNeedsCryptoContextAndKeys(funcOp)) return;
if (funcOp.getFunctionType().getNumInputs() >= 2 &&
mlir::isa<jaxiteword::CryptoContextType>(
funcOp.getFunctionType().getInput(0)) &&
mlir::isa<jaxiteword::EvalKeyType>(
funcOp.getFunctionType().getInput(1))) {
return;
}
auto cryptoContextType =
jaxiteword::CryptoContextType::get(funcOp.getContext());
auto evalKeyType = jaxiteword::EvalKeyType::get(funcOp.getContext());
(void)funcOp.insertArgument(0, evalKeyType, nullptr, funcOp.getLoc());
(void)funcOp.insertArgument(0, cryptoContextType, nullptr, funcOp.getLoc());
}

void updateCryptoFuncCalls(Operation* op) {
op->walk([&](func::CallOp callOp) {
auto callee = getCalledFunction(callOp);
if (failed(callee) || !funcNeedsCryptoContextAndKeys(callee.value())) {
return;
}
if (callOp.getNumOperands() == callee.value().getNumArguments()) {
return;
}
auto caller = callOp->getParentOfType<func::FuncOp>();
if (!caller || caller.getNumArguments() < 2 ||
!mlir::isa<jaxiteword::CryptoContextType>(
caller.getArgument(0).getType()) ||
!mlir::isa<jaxiteword::EvalKeyType>(caller.getArgument(1).getType())) {
return;
}
SmallVector<Value> newOperands;
newOperands.push_back(caller.getArgument(0));
newOperands.push_back(caller.getArgument(1));
newOperands.append(callOp.getOperands().begin(),
callOp.getOperands().end());
callOp->setOperands(newOperands);
});
}

FailureOr<Value> getContextualCryptoContextForJaxiteWord(Operation* op) {
auto funcOp = op->getParentOfType<func::FuncOp>();
if (!funcOp) return failure();
Expand All @@ -75,6 +130,25 @@ FailureOr<Value> getContextualEvalKeyForJaxiteWord(Operation* op) {
return funcOp.getArgument(1);
}

static FailureOr<IntegerAttr> getStaticRotationIndex(ckks::RotateOp op,
Value dynamicShift) {
auto i64Type = IntegerType::get(op.getContext(), 64);
if (IntegerAttr staticShift = op.getStaticShiftAttr()) {
return IntegerAttr::get(i64Type, staticShift.getValue().getSExtValue());
}
if (!dynamicShift) {
return failure();
}
auto constOp = dynamicShift.getDefiningOp<arith::ConstantOp>();
if (!constOp) {
return failure();
}
if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) {
return IntegerAttr::get(i64Type, intAttr.getValue().getSExtValue());
}
return failure();
}

struct AddCryptoContextAndKeys : public OpConversionPattern<func::FuncOp> {
AddCryptoContextAndKeys(mlir::MLIRContext* context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}
Expand All @@ -84,10 +158,7 @@ struct AddCryptoContextAndKeys : public OpConversionPattern<func::FuncOp> {
LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto containsCryptoOps =
::mlir::heir::containsDialects<lwe::LWEDialect, ckks::CKKSDialect,
bgv::BGVDialect>(op);
if (!containsCryptoOps) return failure();
if (!funcNeedsCryptoContextAndKeys(op)) return failure();

auto cryptoContextType = jaxiteword::CryptoContextType::get(getContext());
auto evalKeyType = jaxiteword::EvalKeyType::get(getContext());
Expand All @@ -108,6 +179,43 @@ struct AddCryptoContextAndKeys : public OpConversionPattern<func::FuncOp> {
}
};

struct ConvertFuncCallOp : public OpConversionPattern<func::CallOp> {
ConvertFuncCallOp(mlir::MLIRContext* context)
: OpConversionPattern<func::CallOp>(context) {}

using OpConversionPattern<func::CallOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
func::CallOp op, typename func::CallOp::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto callee = getCalledFunction(op);
if (failed(callee) || !funcNeedsCryptoContextAndKeys(callee.value())) {
return failure();
}
if (op.getNumOperands() >= callee.value().getNumArguments()) {
return failure();
}

FailureOr<Value> ctx = getContextualCryptoContextForJaxiteWord(op);
if (failed(ctx)) return failure();
FailureOr<Value> evalKey = getContextualEvalKeyForJaxiteWord(op);
if (failed(evalKey)) return failure();

SmallVector<Value> newOperands;
newOperands.push_back(ctx.value());
newOperands.push_back(evalKey.value());
newOperands.append(adaptor.getOperands().begin(),
adaptor.getOperands().end());

SmallVector<NamedAttribute> dialectAttrs(op->getDialectAttrs());
rewriter
.replaceOpWithNewOp<func::CallOp>(op, op.getCallee(),
op.getResultTypes(), newOperands)
->setDialectAttrs(dialectAttrs);
return success();
}
};

template <typename SourceOp, typename TargetOp>
struct ConvertBinOp : public OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;
Expand Down Expand Up @@ -173,20 +281,16 @@ struct ConvertRotateOp : public OpConversionPattern<ckks::RotateOp> {
FailureOr<Value> evalKey = getContextualEvalKeyForJaxiteWord(op);
if (failed(evalKey)) return failure();

Value dynamicShift = adaptor.getDynamicShift();
IntegerAttr staticShift = op.getStaticShiftAttr();
if (!staticShift && !dynamicShift) {
FailureOr<IntegerAttr> indexAttr =
getStaticRotationIndex(op, adaptor.getDynamicShift());
if (failed(indexAttr)) {
return rewriter.notifyMatchFailure(
op, "rotate op must have either static or dynamic shift");
}
if (dynamicShift) {
return rewriter.notifyMatchFailure(
op, "jaxiteword rotation requires static shift");
op, "jaxiteword rotation requires statically known shift");
}

rewriter.replaceOpWithNewOp<jaxiteword::RotOp>(
op, this->getTypeConverter()->convertType(op.getOutput().getType()),
ctx.value(), adaptor.getInput(), evalKey.value(), staticShift);
ctx.value(), adaptor.getInput(), evalKey.value(), indexAttr.value());
return success();
}
};
Expand Down Expand Up @@ -298,6 +402,9 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase<LWEToJaxiteWord> {
MLIRContext* context = &getContext();
Operation* op = getOperation();

op->walk([&](func::FuncOp funcOp) { insertCryptoContextAndKeys(funcOp); });
updateCryptoFuncCalls(op);

RewritePatternSet patterns(context);
ConversionTarget target(*context);

Expand All @@ -309,24 +416,33 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase<LWEToJaxiteWord> {

JaxiteWordTypeConverter typeConverter(context);

target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
auto containsCryptoOps =
::mlir::heir::containsDialects<lwe::LWEDialect, ckks::CKKSDialect,
bgv::BGVDialect>(op);
if (!containsCryptoOps) return true;
bool hasArgs = op.getFunctionType().getNumInputs() >= 2;
return typeConverter.isSignatureLegal(op.getFunctionType()) && hasArgs &&
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
if (!funcNeedsCryptoContextAndKeys(funcOp)) return true;
bool hasArgs = funcOp.getFunctionType().getNumInputs() >= 2;
return typeConverter.isSignatureLegal(funcOp.getFunctionType()) &&
hasArgs &&
mlir::isa<jaxiteword::CryptoContextType>(
op.getFunctionType().getInput(0)) &&
funcOp.getFunctionType().getInput(0)) &&
mlir::isa<jaxiteword::EvalKeyType>(
op.getFunctionType().getInput(1));
funcOp.getFunctionType().getInput(1));
});

target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp callOp) {
if (auto callee = getCalledFunction(callOp); succeeded(callee)) {
if (funcNeedsCryptoContextAndKeys(callee.value())) {
return callOp.getNumOperands() ==
callOp.getCalleeType().getNumInputs();
}
}
return true;
});

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
addTensorConversionPatterns(typeConverter, patterns, target);

patterns.add<AddCryptoContextAndKeys>(typeConverter, context);
patterns.add<ConvertFuncCallOp>(context);
patterns.add<ConvertBinOp<lwe::AddOp, jaxiteword::AddOp>>(typeConverter,
context);
patterns.add<ConvertBinOp<lwe::RAddOp, jaxiteword::AddOp>>(typeConverter,
Expand Down
Loading
Loading