diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td index a5a8c8dffb..f787ef719b 100644 --- a/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td @@ -80,6 +80,9 @@ def GenParamsOp : JaxiteWord_Op<"gen_params"> { - numEvalMult: Number of evaluation multiplications }]; let arguments = (ins + JaxiteWord_PublicKey:$publicKey, + JaxiteWord_PrivateKey:$secretKey, + JaxiteWord_EvalKey:$evaluationKey, // Scheme parameters I64Attr:$degree, I64Attr:$numSlots, @@ -134,7 +137,6 @@ def ProgramInitializationOp : JaxiteWord_Op<"program_initialization"> { }]; let arguments = (ins JaxiteWord_CryptoContext:$cryptoContext, - JaxiteWord_PrivateKey:$secretKey, I64Attr:$totalHemulLevels, DenseI64ArrayAttr:$totalRotationIndices, I32Attr:$dnum, diff --git a/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp b/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp index 367cee033b..06b1064fba 100644 --- a/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp +++ b/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp @@ -90,15 +90,24 @@ struct ConfigureCryptoContext LogicalResult generateGenFunc(func::FuncOp op, const std::string& genFuncName, ImplicitLocOpBuilder& builder) { Type ccType = CryptoContextType::get(builder.getContext()); + Type pkType = PublicKeyType::get(builder.getContext()); + Type skType = PrivateKeyType::get(builder.getContext()); + Type ekType = EvalKeyType::get(builder.getContext()); + + SmallVector funcArgTypes = {pkType, skType, ekType}; SmallVector funcResultTypes = {ccType}; FunctionType genFuncType = - FunctionType::get(builder.getContext(), {}, funcResultTypes); + FunctionType::get(builder.getContext(), funcArgTypes, funcResultTypes); auto genFuncOp = func::FuncOp::create(builder, genFuncName, genFuncType); builder.setInsertionPointToEnd(genFuncOp.addEntryBlock()); + Value publicKey = genFuncOp.getArgument(0); + Value secretKey = genFuncOp.getArgument(1); + Value evaluationKey = genFuncOp.getArgument(2); + Value cryptoContext = GenParamsOp::create( - builder, ccType, + builder, ccType, publicKey, secretKey, evaluationKey, /*degree=*/static_cast(config.degree), /*numSlots=*/static_cast(config.numSlots), /*scalingFactor=*/llvm::APFloat(config.scalingFactor), @@ -119,9 +128,8 @@ struct ConfigureCryptoContext const std::string& configFuncName, ImplicitLocOpBuilder& builder) { Type ccType = CryptoContextType::get(builder.getContext()); - Type skType = PrivateKeyType::get(builder.getContext()); - SmallVector funcArgTypes = {ccType, skType}; + SmallVector funcArgTypes = {ccType}; SmallVector funcResultTypes; FunctionType configFuncType = @@ -131,10 +139,9 @@ struct ConfigureCryptoContext builder.setInsertionPointToEnd(configFuncOp.addEntryBlock()); Value cryptoContext = configFuncOp.getArgument(0); - Value secretKey = configFuncOp.getArgument(1); ProgramInitializationOp::create( - builder, cryptoContext, secretKey, + builder, cryptoContext, /*totalHemulLevels=*/static_cast(config.mulDepth), /*totalRotationIndices=*/config.rotIndices, /*dnum=*/config.dnum, diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.td b/lib/Dialect/JaxiteWord/Transforms/Passes.td index f9ab1ebdba..9112a118ae 100644 --- a/lib/Dialect/JaxiteWord/Transforms/Passes.td +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.td @@ -16,10 +16,12 @@ def ConfigureCryptoContext : Pass<"jaxiteword-configure-crypto-context"> { For example, for an MLIR function `@my_func`, the generated helpers have the following signatures: ```mlir - func.func @my_func__generate_crypto_context() -> !jaxiteword.crypto_context + func.func @my_func__generate_crypto_context( + !jaxiteword.public_key, !jaxiteword.private_key, + !jaxiteword.eval_key) -> !jaxiteword.crypto_context func.func @my_func__configure_crypto_context( - !jaxiteword.crypto_context, !jaxiteword.private_key) + !jaxiteword.crypto_context) ``` }]; let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect"]; diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD index 49ee33ca34..e69178735d 100644 --- a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD @@ -18,6 +18,7 @@ cc_library( "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Utils", "@heir//lib/Utils:ConversionUtils", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp index 8949062aa8..2cc6ed89da 100644 --- a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp @@ -15,6 +15,7 @@ #include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Utils/ConversionUtils.h" #include "lib/Utils/Utils.h" +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -53,6 +54,60 @@ class JaxiteWordTypeConverter : public TypeConverter { namespace { +bool containsCryptoArgument(func::FuncOp funcOp) { + return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) { + return DialectEqual()( + &getElementTypeOrSelf(argType).getDialect()); + }); +} + +bool funcNeedsCryptoContextAndKeys(func::FuncOp funcOp) { + return containsDialects( + funcOp) || + containsCryptoArgument(funcOp); +} + +void insertCryptoContextAndKeys(func::FuncOp funcOp) { + if (!funcNeedsCryptoContextAndKeys(funcOp)) return; + if (funcOp.getFunctionType().getNumInputs() >= 2 && + mlir::isa( + funcOp.getFunctionType().getInput(0)) && + mlir::isa( + funcOp.getFunctionType().getInput(1))) { + return; + } + auto cryptoContextType = + jaxiteword::CryptoContextType::get(funcOp.getContext()); + auto evalKeyType = jaxiteword::EvalKeyType::get(funcOp.getContext()); + (void)funcOp.insertArgument(0, evalKeyType, nullptr, funcOp.getLoc()); + (void)funcOp.insertArgument(0, cryptoContextType, nullptr, funcOp.getLoc()); +} + +void updateCryptoFuncCalls(Operation* op) { + op->walk([&](func::CallOp callOp) { + auto callee = getCalledFunction(callOp); + if (failed(callee) || !funcNeedsCryptoContextAndKeys(callee.value())) { + return; + } + if (callOp.getNumOperands() == callee.value().getNumArguments()) { + return; + } + auto caller = callOp->getParentOfType(); + if (!caller || caller.getNumArguments() < 2 || + !mlir::isa( + caller.getArgument(0).getType()) || + !mlir::isa(caller.getArgument(1).getType())) { + return; + } + SmallVector newOperands; + newOperands.push_back(caller.getArgument(0)); + newOperands.push_back(caller.getArgument(1)); + newOperands.append(callOp.getOperands().begin(), + callOp.getOperands().end()); + callOp->setOperands(newOperands); + }); +} + FailureOr getContextualCryptoContextForJaxiteWord(Operation* op) { auto funcOp = op->getParentOfType(); if (!funcOp) return failure(); @@ -75,6 +130,25 @@ FailureOr getContextualEvalKeyForJaxiteWord(Operation* op) { return funcOp.getArgument(1); } +static FailureOr getStaticRotationIndex(ckks::RotateOp op, + Value dynamicShift) { + auto i64Type = IntegerType::get(op.getContext(), 64); + if (IntegerAttr staticShift = op.getStaticShiftAttr()) { + return IntegerAttr::get(i64Type, staticShift.getValue().getSExtValue()); + } + if (!dynamicShift) { + return failure(); + } + auto constOp = dynamicShift.getDefiningOp(); + if (!constOp) { + return failure(); + } + if (auto intAttr = dyn_cast(constOp.getValue())) { + return IntegerAttr::get(i64Type, intAttr.getValue().getSExtValue()); + } + return failure(); +} + struct AddCryptoContextAndKeys : public OpConversionPattern { AddCryptoContextAndKeys(mlir::MLIRContext* context) : OpConversionPattern(context, /* benefit= */ 2) {} @@ -84,10 +158,7 @@ struct AddCryptoContextAndKeys : public OpConversionPattern { LogicalResult matchAndRewrite( func::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto containsCryptoOps = - ::mlir::heir::containsDialects(op); - if (!containsCryptoOps) return failure(); + if (!funcNeedsCryptoContextAndKeys(op)) return failure(); auto cryptoContextType = jaxiteword::CryptoContextType::get(getContext()); auto evalKeyType = jaxiteword::EvalKeyType::get(getContext()); @@ -108,6 +179,43 @@ struct AddCryptoContextAndKeys : public OpConversionPattern { } }; +struct ConvertFuncCallOp : public OpConversionPattern { + ConvertFuncCallOp(mlir::MLIRContext* context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::CallOp op, typename func::CallOp::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto callee = getCalledFunction(op); + if (failed(callee) || !funcNeedsCryptoContextAndKeys(callee.value())) { + return failure(); + } + if (op.getNumOperands() >= callee.value().getNumArguments()) { + return failure(); + } + + FailureOr ctx = getContextualCryptoContextForJaxiteWord(op); + if (failed(ctx)) return failure(); + FailureOr evalKey = getContextualEvalKeyForJaxiteWord(op); + if (failed(evalKey)) return failure(); + + SmallVector newOperands; + newOperands.push_back(ctx.value()); + newOperands.push_back(evalKey.value()); + newOperands.append(adaptor.getOperands().begin(), + adaptor.getOperands().end()); + + SmallVector dialectAttrs(op->getDialectAttrs()); + rewriter + .replaceOpWithNewOp(op, op.getCallee(), + op.getResultTypes(), newOperands) + ->setDialectAttrs(dialectAttrs); + return success(); + } +}; + template struct ConvertBinOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -173,20 +281,16 @@ struct ConvertRotateOp : public OpConversionPattern { FailureOr evalKey = getContextualEvalKeyForJaxiteWord(op); if (failed(evalKey)) return failure(); - Value dynamicShift = adaptor.getDynamicShift(); - IntegerAttr staticShift = op.getStaticShiftAttr(); - if (!staticShift && !dynamicShift) { + FailureOr indexAttr = + getStaticRotationIndex(op, adaptor.getDynamicShift()); + if (failed(indexAttr)) { return rewriter.notifyMatchFailure( - op, "rotate op must have either static or dynamic shift"); - } - if (dynamicShift) { - return rewriter.notifyMatchFailure( - op, "jaxiteword rotation requires static shift"); + op, "jaxiteword rotation requires statically known shift"); } rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getOutput().getType()), - ctx.value(), adaptor.getInput(), evalKey.value(), staticShift); + ctx.value(), adaptor.getInput(), evalKey.value(), indexAttr.value()); return success(); } }; @@ -298,6 +402,9 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase { MLIRContext* context = &getContext(); Operation* op = getOperation(); + op->walk([&](func::FuncOp funcOp) { insertCryptoContextAndKeys(funcOp); }); + updateCryptoFuncCalls(op); + RewritePatternSet patterns(context); ConversionTarget target(*context); @@ -309,17 +416,25 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase { JaxiteWordTypeConverter typeConverter(context); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - auto containsCryptoOps = - ::mlir::heir::containsDialects(op); - if (!containsCryptoOps) return true; - bool hasArgs = op.getFunctionType().getNumInputs() >= 2; - return typeConverter.isSignatureLegal(op.getFunctionType()) && hasArgs && + target.addDynamicallyLegalOp([&](func::FuncOp funcOp) { + if (!funcNeedsCryptoContextAndKeys(funcOp)) return true; + bool hasArgs = funcOp.getFunctionType().getNumInputs() >= 2; + return typeConverter.isSignatureLegal(funcOp.getFunctionType()) && + hasArgs && mlir::isa( - op.getFunctionType().getInput(0)) && + funcOp.getFunctionType().getInput(0)) && mlir::isa( - op.getFunctionType().getInput(1)); + funcOp.getFunctionType().getInput(1)); + }); + + target.addDynamicallyLegalOp([&](func::CallOp callOp) { + if (auto callee = getCalledFunction(callOp); succeeded(callee)) { + if (funcNeedsCryptoContextAndKeys(callee.value())) { + return callOp.getNumOperands() == + callOp.getCalleeType().getNumInputs(); + } + } + return true; }); populateFunctionOpInterfaceTypeConversionPattern( @@ -327,6 +442,7 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase { addTensorConversionPatterns(typeConverter, patterns, target); patterns.add(typeConverter, context); + patterns.add(context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, diff --git a/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp b/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp index 0122cb4517..381e6f7b4c 100644 --- a/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp +++ b/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp @@ -30,6 +30,7 @@ #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project @@ -115,7 +116,7 @@ LogicalResult JaxiteWordEmitter::translate(Operation& op) { LogicalResult status = llvm::TypeSwitch(op) .Case([&](auto op) { return printOperation(op); }) - .Case( + .Case( [&](auto op) { return printOperation(op); }) .Case(arg.getType())) { - CiphertextArg_ = argName; - } } os.unindent(); os << ")"; @@ -213,6 +212,26 @@ LogicalResult JaxiteWordEmitter::printOperation(func::FuncOp funcOp) { return success(); } +LogicalResult JaxiteWordEmitter::printOperation(func::CallOp op) { + if (op.getNumResults() == 1) { + emitAssignPrefix(op.getResult(0)); + } else if (op.getNumResults() > 1) { + os << "("; + for (auto [idx, result] : llvm::enumerate(op.getResults())) { + if (idx > 0) os << ", "; + os << variableNames->getNameForValue(result); + } + os << ") = "; + } + + os << op.getCallee().str() << "("; + os << commaSeparatedValues(op.getOperands(), [&](Value value) { + return variableNames->getNameForValue(value); + }); + os << ")\n"; + return success(); +} + LogicalResult JaxiteWordEmitter::printOperation(func::ReturnOp op) { std::function resultValue = [&](Value value) { if (isa(value)) { @@ -236,38 +255,56 @@ LogicalResult JaxiteWordEmitter::printOperation(func::ReturnOp op) { } LogicalResult JaxiteWordEmitter::printOperation(AddOp op) { - return printBinaryOpHelper( - op.getResult(), op.getLhs(), op.getRhs(), - [&](StringRef lhs, StringRef rhs, StringRef result) { - os << lhs << "\n"; - os << llvm::formatv(kAddCoreTemplate.data(), result, rhs); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + auto result = variableNames->getNameForValue(op.getResult()); + + emitModularAdd(result, ctx, lhs, rhs); + return success(); } LogicalResult JaxiteWordEmitter::printOperation(SubOp op) { - return printBinaryOpHelper( - op.getResult(), op.getLhs(), op.getRhs(), - [&](StringRef lhs, StringRef rhs, StringRef result) { - os << lhs << "\n"; - std::string rhsCiphertext = (rhs + ".ciphertext").str(); - os << llvm::formatv(kSubTemplate.data(), result, rhsCiphertext); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string rhsWork = (result + "_rhs"); + + emitNormalizeCiphertext(result, ctx, lhs); + emitNormalizeCiphertext(rhsWork, ctx, rhs); + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << result << ".polynomial = jnp.where(" << result << ".polynomial < " + << rhsWork << ".polynomial, " << result << ".polynomial + _moduli - " + << rhsWork << ".polynomial, " << result << ".polynomial - " << rhsWork + << ".polynomial)\n"; + return success(); } LogicalResult JaxiteWordEmitter::printOperation(NegateOp op) { - emitAssignPrefix(op.getResult()); - os << variableNames->getNameForValue(op.getCiphertext()) << ".mul(-1)\n"; + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto ct = variableNames->getNameForValue(op.getCiphertext()); + auto result = variableNames->getNameForValue(op.getResult()); + + emitNormalizeCiphertext(result, ctx, ct); + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << result << ".polynomial = jnp.where(" << result << ".polynomial == 0, " + << result << ".polynomial, _moduli - " << result << ".polynomial)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(SquareOp op) { auto ct = variableNames->getNameForValue(op.getCiphertext()); auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto result = variableNames->getNameForValue(op.getResult()); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/1); + std::string inputLevel = ("(" + level + ") + 1"); + std::string ctWork = (result + "_arg"); - emitAssignPrefix(op.getResult()); - os << ctx << ".he_mul[" << level << "].mul(" << ct << ", " << ct << ")\n"; + emitNormalizeCiphertext(ctWork, ctx, ct, inputLevel); + os << result << " = " << ctx << ".he_mul[" << level << "].mul(" << ctWork + << ", " << ctWork << ")\n"; return success(); } @@ -275,26 +312,72 @@ void JaxiteWordEmitter::emitAssignPrefix(Value result) { os << variableNames->getNameForValue(result) << " = "; } -LogicalResult JaxiteWordEmitter::printBinaryOpHelper( - Value result, Value lhs, Value rhs, - llvm::function_ref callback) { - auto lhsName = variableNames->getNameForValue(lhs); - auto rhsName = variableNames->getNameForValue(rhs); - auto resultName = variableNames->getNameForValue(result); +void JaxiteWordEmitter::emitAssignCiphertext(StringRef targetName, + StringRef sourceName) { + os << "_assign_poly(" << targetName << ", " << sourceName << ")\n"; +} - emitAssignPrefix(result); - callback(lhsName, rhsName, resultName); - return success(); +void JaxiteWordEmitter::emitNormalizeCiphertext(StringRef resultName, + StringRef ctxName, + StringRef sourceName, + StringRef levelExpr) { + os << resultName << " = _ensure_poly(" << ctxName << ", " << sourceName; + if (!levelExpr.empty()) os << ", " << levelExpr; + os << ")\n"; } -LogicalResult JaxiteWordEmitter::printInPlaceBinaryOpHelper( - Value lhs, Value rhs, - llvm::function_ref callback) { - auto lhsName = variableNames->getNameForValue(lhs); - auto rhsName = variableNames->getNameForValue(rhs); +void JaxiteWordEmitter::emitModularAdd(StringRef resultName, StringRef ctxName, + StringRef lhsName, StringRef rhsName) { + std::string lhsData = (resultName + "_lhs").str(); + std::string rhsData = (resultName + "_rhs").str(); + std::string numModuli = (resultName + "_num_moduli").str(); + std::string moduliSrc = (resultName + "_moduli_src").str(); + std::string moduli = (resultName + "_moduli").str(); + std::string sum = (resultName + "_sum").str(); + + os << lhsData << " = " << lhsName << ".polynomial if hasattr(" << lhsName + << ", \"polynomial\") else " << lhsName << "\n"; + os << rhsData << " = " << rhsName << ".polynomial if hasattr(" << rhsName + << ", \"polynomial\") else " << rhsName << "\n"; + os << lhsData << " = " << lhsData << ".reshape(" << lhsData << ".shape[0], " + << lhsData << ".shape[1], " << ctxName << "._param_cache.r, " << ctxName + << "._param_cache.c, " << lhsData << ".shape[-1])\n"; + os << rhsData << " = " << rhsData << ".reshape(" << rhsData << ".shape[0], " + << rhsData << ".shape[1], " << ctxName << "._param_cache.r, " << ctxName + << "._param_cache.c, " << rhsData << ".shape[-1])\n"; + os << "if " << lhsData << ".shape != " << rhsData << ".shape:\n"; + os.indent(); + os << "raise ValueError(\"ciphertext add shape mismatch\")\n"; + os.unindent(); + os << numModuli << " = " << lhsData << ".shape[-1]\n"; + os << "if hasattr(" << lhsName << ", \"moduli\") and hasattr(" << rhsName + << ", \"moduli\"):\n"; + os.indent(); + os << "if list(" << lhsName << ".moduli)[:" << numModuli << "] != list(" + << rhsName << ".moduli)[:" << numModuli << "]:\n"; + os.indent(); + os << "raise ValueError(\"ciphertext add modulus mismatch\")\n"; + os.unindent(); + os.unindent(); + os << moduliSrc << " = getattr(" << lhsName << ", \"moduli\", getattr(" + << rhsName << ", \"moduli\", " << ctxName << ".q_towers))\n"; + os << "if isinstance(" << moduliSrc << ", (int, np.integer)):\n"; + os.indent(); + os << moduliSrc << " = [" << moduliSrc << "]\n"; + os.unindent(); + os << moduli << " = jnp.array(list(" << moduliSrc << ")[:" << numModuli + << "], dtype=jnp.uint64)\n"; + os << sum << " = " << lhsData << ".astype(jnp.uint64) + " << rhsData + << ".astype(jnp.uint64)\n"; + os << resultName << " = jnp.where(" << sum << " >= " << moduli << ", " << sum + << " - " << moduli << ", " << sum << ").astype(jnp.uint32)\n"; +} - callback(lhsName, rhsName); - return success(); +void JaxiteWordEmitter::emitModularReduce(StringRef targetName) { + os << "_moduli = jnp.array(" << targetName << ".moduli, dtype=jnp.uint32)\n"; + os << targetName << ".polynomial = jnp.where(" << targetName + << ".polynomial >= _moduli, " << targetName << ".polynomial - _moduli, " + << targetName << ".polynomial)\n"; } LogicalResult JaxiteWordEmitter::printMulOpHelper( @@ -308,7 +391,6 @@ LogicalResult JaxiteWordEmitter::printMulOpHelper( auto resultName = variableNames->getNameForValue(result); auto level = getCrossLevelExpr(lhs, ctxName, op, /*extraOffset=*/1); - emitAssignPrefix(result); callback(lhsName, rhsName, ctxName, resultName, level); return success(); } @@ -323,64 +405,96 @@ LogicalResult JaxiteWordEmitter::printOperation(EncodeOp op) { LogicalResult JaxiteWordEmitter::printOperation(EncryptOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto pk = variableNames->getNameForValue(op.getPublicKey()); - os << ctx << ".public_key = " << pk << "\n"; + auto pt = variableNames->getNameForValue(op.getPlaintext()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string raw = result + "_raw"; - emitAssignPrefix(op.getResult()); - os << ctx << ".encrypt(" << variableNames->getNameForValue(op.getPlaintext()) - << ")\n"; + os << ctx << ".public_key = " << pk << "\n"; + os << raw << " = " << ctx << ".encrypt(" << pt << ")\n"; + emitNormalizeCiphertext(result, ctx, raw); return success(); } LogicalResult JaxiteWordEmitter::printOperation(MulOp op) { - return printMulOpHelper(op.getResult(), op.getLhs(), op.getRhs(), - op.getCryptoContext(), op.getOperation(), - [&](StringRef lhs, StringRef rhs, StringRef ctx, - StringRef result, StringRef level) { - os << ctx << ".he_mul[" << level << "].hemul(" - << lhs << ", " << rhs << ")\n"; - }); + return printMulOpHelper( + op.getResult(), op.getLhs(), op.getRhs(), op.getCryptoContext(), + op.getOperation(), + [&](StringRef lhs, StringRef rhs, StringRef ctx, StringRef result, + StringRef level) { + std::string inputLevel = ("(" + level + ") + 1").str(); + std::string lhsWork = (result + "_lhs").str(); + std::string rhsWork = (result + "_rhs").str(); + emitNormalizeCiphertext(lhsWork, ctx, lhs, inputLevel); + emitNormalizeCiphertext(rhsWork, ctx, rhs, inputLevel); + os << result << " = " << ctx << ".he_mul[" << level << "].mul(" + << lhsWork << ", " << rhsWork << ")\n"; + }); } LogicalResult JaxiteWordEmitter::printOperation(MulNoRelinOp op) { - return printMulOpHelper(op.getResult(), op.getLhs(), op.getRhs(), - op.getCryptoContext(), op.getOperation(), - [&](StringRef lhs, StringRef rhs, StringRef ctx, - StringRef result, StringRef level) { - os << ctx << ".he_mul[" << level - << "].hemul_no_relin(" << lhs << ", " << rhs - << ")\n"; - }); + return printMulOpHelper( + op.getResult(), op.getLhs(), op.getRhs(), op.getCryptoContext(), + op.getOperation(), + [&](StringRef lhs, StringRef rhs, StringRef ctx, StringRef result, + StringRef level) { + std::string inputLevel = ("(" + level + ") + 1").str(); + std::string lhsWork = (result + "_lhs").str(); + std::string rhsWork = (result + "_rhs").str(); + std::string raw = (result + "_raw").str(); + emitNormalizeCiphertext(lhsWork, ctx, lhs, inputLevel); + emitNormalizeCiphertext(rhsWork, ctx, rhs, inputLevel); + os << raw << " = " << ctx << ".he_mul[" << level << "].hemul_no_relin(" + << lhsWork << ", " << rhsWork << ")\n"; + emitNormalizeCiphertext(result, ctx, raw, level); + }); } LogicalResult JaxiteWordEmitter::printOperation(RelinOp op) { auto ct = variableNames->getNameForValue(op.getCiphertext()); auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto result = variableNames->getNameForValue(op.getOutput()); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/1); + std::string ctData = result + "_ct_data"; - auto result = variableNames->getNameForValue(op.getOutput()); - os << llvm::formatv(kRelinTemplate.data(), result, ctx, level, ct); - + os << ctData << " = " << ct << ".polynomial if hasattr(" << ct + << ", \"polynomial\") else " << ct << "\n"; + os << result << " = " << ctx << ".he_mul[" << level << "].relinearize(" + << ctData << ")\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(ModReduceOp op) { + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto ct = variableNames->getNameForValue(op.getCiphertext()); - emitAssignPrefix(op.getResult()); - os << ct << "\n"; + auto result = variableNames->getNameForValue(op.getResult()); + auto srcLevel = + getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/0); + auto dstLevel = getCrossLevelExpr(op.getResult(), ctx, op, /*extraOffset=*/0); + std::string ctWork = result + "_arg"; + + emitNormalizeCiphertext(ctWork, ctx, ct, srcLevel); + if (srcLevel == dstLevel) { + emitNormalizeCiphertext(result, ctx, ctWork, dstLevel); + return success(); + } + os << result << " = " << ctx << ".he_rescale[" << srcLevel << ", " << dstLevel + << "](" << ctWork << ")\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(RotOp op) { auto ct = variableNames->getNameForValue(op.getCiphertext()); auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto result = variableNames->getNameForValue(op.getResult()); auto rotIndex = op.getIndex(); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/0); + std::string ctWork = result + "_arg"; - emitAssignPrefix(op.getResult()); - os << ctx << ".he_rot[" << level << ", " << rotIndex << "].rotate(" << ct - << ")\n"; + emitNormalizeCiphertext(ctWork, ctx, ct, level); + os << result << " = " << ctx << ".he_rot[" << level << ", " << rotIndex + << "].rotate(" << ctWork << ")\n"; return success(); } @@ -388,14 +502,23 @@ LogicalResult JaxiteWordEmitter::printOperation(DecryptOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto sk = variableNames->getNameForValue(op.getSecretKey()); auto ct = variableNames->getNameForValue(op.getCiphertext()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string ctWork = result + "_ct"; - auto ctType = cast(op.getCiphertext().getType()); - int current = ctType.getModulusChain().getCurrent(); - int maxCurrent = getMaxCurrentInModule(op); - int rescales = maxCurrent - current; - - os << llvm::formatv(kDecryptTemplate.data(), ctx, sk, rescales, ct); - + os << ctx << ".secret_key = " << sk << "\n"; + emitNormalizeCiphertext(ctWork, ctx, ct); + os << "_num_moduli = " << ctWork << ".polynomial.shape[-1]\n"; + os << "_q_sub = list(getattr(" << ctWork << ", \"moduli\", " << ctx + << ".q_towers))[:_num_moduli]\n"; + os << "_ct_for_dec = Polynomial({\"batch\": " << ctWork + << ".polynomial.shape[0], \"num_elements\": " << ctWork + << ".polynomial.shape[1], \"degree\": " << ctx + << ".degree, \"precision\": 32, \"num_moduli\": _num_moduli, " + "\"degree_layout\": (" + << ctx << ".degree,)}, {\"moduli\": _q_sub})\n"; + os << "_ct_for_dec.set_batch_polynomial(" << ctWork << ".polynomial.reshape(" + << ctWork << ".polynomial.shape[0], " << ctWork << ".polynomial.shape[1], " + << ctx << ".degree, _num_moduli))\n"; emitAssignPrefix(op.getResult()); os << ctx << ".decrypt(_ct_for_dec)\n"; return success(); @@ -424,48 +547,114 @@ LogicalResult JaxiteWordEmitter::printOperation(MulPlainOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto ct = variableNames->getNameForValue(op.getCiphertext()); auto pt = variableNames->getNameForValue(op.getPlaintext()); + auto result = variableNames->getNameForValue(op.getResult()); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/0); - - os << ctx << ".ptct_mul[" << level << "].set_plaintext(" << pt << ")\n"; - emitAssignPrefix(op.getResult()); - os << ctx << ".ptct_mul[" << level << "].mul(" << ct << ")\n"; + std::string ctWork = result + "_arg"; + std::string ptNtt = result + "_pt_ntt"; + std::string opName = result + "_ptct"; + + emitNormalizeCiphertext(ctWork, ctx, ct, level); + os << ptNtt << " = " << pt << ".polynomial[0, 0, :, :" << ctWork + << ".polynomial.shape[-1]].reshape(" << ctWork << ".r, " << ctWork + << ".c, " << ctWork << ".polynomial.shape[-1]).astype(jnp.uint32)\n"; + os << opName << " = " << ctx << ".ptct_mul[" << level << "]\n"; + os << opName << ".set_plaintext(" << ptNtt << ")\n"; + os << result << " = " << opName << ".mul(" << ctWork << ", use_bat=False)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(AddPlainOp op) { - auto lhs = variableNames->getNameForValue(op.getLhs()); - auto rhs = variableNames->getNameForValue(op.getRhs()); - os << lhs << ".ciphertext = " << lhs << ".ciphertext + " << rhs << "\n"; - os << llvm::formatv(kAddModReduceTemplate.data(), lhs); - - emitAssignPrefix(op.getResult()); - os << lhs << "\n"; + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + bool lhsPlain = + isa(getElementTypeOrSelf(op.getLhs().getType())); + auto ct = + variableNames->getNameForValue(lhsPlain ? op.getRhs() : op.getLhs()); + auto pt = + variableNames->getNameForValue(lhsPlain ? op.getLhs() : op.getRhs()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string numModuli = result + "_m"; + std::string ptData = result + "_pt_data"; + + emitNormalizeCiphertext(result, ctx, ct); + os << numModuli << " = " << result << ".polynomial.shape[-1]\n"; + os << ptData << " = " << pt << ".polynomial[0:1, 0:1, :, :" << numModuli + << "].reshape(1, 1, " << result << ".r, " << result << ".c, " << numModuli + << ")\n"; + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << "_c0 = " << result << ".polynomial[:, 0:1, ...] + " << ptData << "\n"; + os << "_c0 = jnp.where(_c0 >= _moduli, _c0 - _moduli, _c0)\n"; + os << result << ".polynomial = " << result + << ".polynomial.at[:, 0:1, ...].set(_c0)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(SubPlainOp op) { + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + bool lhsPlain = + isa(getElementTypeOrSelf(op.getLhs().getType())); auto lhs = variableNames->getNameForValue(op.getLhs()); auto rhs = variableNames->getNameForValue(op.getRhs()); - os << llvm::formatv(kSubTemplate.data(), lhs, rhs); - emitAssignPrefix(op.getResult()); - os << lhs << "\n"; + auto ct = lhsPlain ? rhs : lhs; + auto pt = lhsPlain ? lhs : rhs; + auto result = variableNames->getNameForValue(op.getResult()); + std::string numModuli = result + "_m"; + std::string ptData = result + "_pt_data"; + + emitNormalizeCiphertext(result, ctx, ct); + os << numModuli << " = " << result << ".polynomial.shape[-1]\n"; + os << ptData << " = " << pt << ".polynomial[0:1, 0:1, :, :" << numModuli + << "].reshape(1, 1, " << result << ".r, " << result << ".c, " << numModuli + << ")\n"; + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << "_c0 = " << result << ".polynomial[:, 0:1, ...]\n"; + if (lhsPlain) { + os << "_c0 = jnp.where(" << ptData << " < _c0, " << ptData + << " + _moduli - _c0, " << ptData << " - _c0)\n"; + os << "_c1 = " << result << ".polynomial[:, 1:2, ...]\n"; + os << "_c1 = jnp.where(_c1 == 0, _c1, _moduli - _c1)\n"; + os << result << ".polynomial = " << result + << ".polynomial.at[:, 1:2, ...].set(_c1)\n"; + } else { + os << "_c0 = jnp.where(_c0 < " << ptData << ", _c0 + _moduli - " << ptData + << ", _c0 - " << ptData << ")\n"; + } + os << result << ".polynomial = " << result + << ".polynomial.at[:, 0:1, ...].set(_c0)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(AddInPlaceOp op) { - return printInPlaceBinaryOpHelper( - op.getLhs(), op.getRhs(), [&](StringRef lhs, StringRef rhs) { - os << llvm::formatv(kAddCoreTemplate.data(), lhs, rhs); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + std::string tmp = lhs + "_inplace"; + std::string rhsWork = lhs + "_rhs"; + + emitNormalizeCiphertext(tmp, ctx, lhs); + emitNormalizeCiphertext(rhsWork, ctx, rhs); + os << tmp << ".add(" << rhsWork << ")\n"; + emitModularReduce(tmp); + emitAssignCiphertext(lhs, tmp); + return success(); } LogicalResult JaxiteWordEmitter::printOperation(SubInPlaceOp op) { - return printInPlaceBinaryOpHelper( - op.getLhs(), op.getRhs(), [&](StringRef lhs, StringRef rhs) { - std::string rhsCiphertext = (rhs + ".ciphertext").str(); - os << llvm::formatv(kSubTemplate.data(), lhs, rhsCiphertext); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + std::string tmp = lhs + "_inplace"; + std::string rhsWork = lhs + "_rhs"; + + emitNormalizeCiphertext(tmp, ctx, lhs); + emitNormalizeCiphertext(rhsWork, ctx, rhs); + os << "_moduli = jnp.array(" << tmp << ".moduli, dtype=jnp.uint32)\n"; + os << tmp << ".polynomial = jnp.where(" << tmp << ".polynomial < " << rhsWork + << ".polynomial, " << tmp << ".polynomial + _moduli - " << rhsWork + << ".polynomial, " << tmp << ".polynomial - " << rhsWork + << ".polynomial)\n"; + emitAssignCiphertext(lhs, tmp); + return success(); } LogicalResult JaxiteWordEmitter::printOperation(GenKeyPairOp op) { @@ -489,9 +678,6 @@ LogicalResult JaxiteWordEmitter::printOperation(GenMulKeyOp op) { << ".q_towers, " << "P=" << ctx << ".p_towers, " << "dnum=" << ctx << ".parameters.get('dnum', 3)" << ")\n"; - heMulVarName_ = "he_mul"; - os << llvm::formatv(kGenMulKeyTemplate.data(), ctx, heMulVarName_, ek); - return success(); } @@ -500,20 +686,24 @@ LogicalResult JaxiteWordEmitter::printOperation(GenRotKeyOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto sk = variableNames->getNameForValue(op.getSecretKey()); - rotKeysDictVarName_ = rk + "_dict"; - std::string indicesStr; - llvm::raw_string_ostream indicesOs(indicesStr); - llvm::interleaveComma(op.getIndices(), indicesOs); - - os << llvm::formatv(kGenRotKeyTemplate.data(), rotKeysDictVarName_, - indicesStr, sk, ctx, heRotVarName_, rk); + os << rk << " = {}\n"; + os << "for _rot_idx in ["; + llvm::interleaveComma(op.getIndices(), os); + os << "]:\n"; + os.indent(); + os << rk << "[_rot_idx] = key_gen.gen_rotation_key(" << sk << ", " << ctx + << ".q_towers, " << ctx << ".p_towers, rot_index=_rot_idx, dnum=" << ctx + << ".parameters.get('dnum', 3))[_rot_idx]\n"; + os.unindent(); return success(); } LogicalResult JaxiteWordEmitter::printOperation(GenParamsOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); - cryptoContextVarName_ = ctx; + auto publicKey = variableNames->getNameForValue(op.getPublicKey()); + auto secretKey = variableNames->getNameForValue(op.getSecretKey()); + auto evaluationKey = variableNames->getNameForValue(op.getEvaluationKey()); os << "params = {\n"; os.indent(); @@ -543,7 +733,10 @@ LogicalResult JaxiteWordEmitter::printOperation(GenParamsOp op) { os << "\"max_bits_in_word\": 61,\n"; os << "\"max_bits_value\": " << ((1ULL << 63) - (1ULL << 9) - 1) << ",\n"; os << "\"noise_scale_degree\": 1,\n"; - os << "\"CKKS_M_FACTOR\": 1\n"; + os << "\"CKKS_M_FACTOR\": 1,\n"; + os << "\"public_key\": " << publicKey << ",\n"; + os << "\"secret_key\": " << secretKey << ",\n"; + os << "\"evaluation_key\": " << evaluationKey << "\n"; os.unindent(); os << "}\n"; @@ -553,9 +746,7 @@ LogicalResult JaxiteWordEmitter::printOperation(GenParamsOp op) { LogicalResult JaxiteWordEmitter::printOperation(ProgramInitializationOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); - auto sk = variableNames->getNameForValue(op.getSecretKey()); - os << ctx << ".secret_key = " << sk << "\n"; os << ctx << ".program_initialization("; os << "total_hemul_levels=" << op.getTotalHemulLevels() << ", "; @@ -682,7 +873,8 @@ LogicalResult JaxiteWordEmitter::printOperation(tensor::EmptyOp op) { return success(); } - os << " = np.zeros(("; + emitAssignPrefix(op.getResult()); + os << "np.zeros(("; for (size_t i = 0; i < shape.size(); ++i) { if (i > 0) os << ", "; os << shape[i]; diff --git a/lib/Target/JaxiteWord/JaxiteWordEmitter.h b/lib/Target/JaxiteWord/JaxiteWordEmitter.h index c744b6205f..5a21d06c77 100644 --- a/lib/Target/JaxiteWord/JaxiteWordEmitter.h +++ b/lib/Target/JaxiteWord/JaxiteWordEmitter.h @@ -43,24 +43,9 @@ class JaxiteWordEmitter { // values. SelectVariableNames* variableNames; - // ciphertext arg. - std::string CiphertextArg_; - - // A list of modulus to be used for the add operation. - std::string ModulusListArg_; - - // Crypto context variable name (set by GenParamsOp, used by accessor calls) - std::string cryptoContextVarName_; - - // Legacy member variables kept for backward compatibility with old pipeline - // (GenMulKeyOp / GenRotKeyOp path). Not used by the new - // ProgramInitializationOp path. - std::string heMulVarName_; - std::string heRotVarName_; - std::string rotKeysDictVarName_; - LogicalResult printOperation(ModuleOp moduleOp); LogicalResult printOperation(func::FuncOp funcOp); + LogicalResult printOperation(func::CallOp op); LogicalResult printOperation(func::ReturnOp returnOp); LogicalResult printOperation(AddOp op); LogicalResult printOperation(SubOp op); @@ -125,14 +110,12 @@ class JaxiteWordEmitter { FailureOr convertType(Type type); void emitAssignPrefix(Value result); - - LogicalResult printBinaryOpHelper( - Value result, Value lhs, Value rhs, - llvm::function_ref callback); - - LogicalResult printInPlaceBinaryOpHelper( - Value lhs, Value rhs, - llvm::function_ref callback); + void emitAssignCiphertext(StringRef targetName, StringRef sourceName); + void emitNormalizeCiphertext(StringRef resultName, StringRef ctxName, + StringRef sourceName, StringRef levelExpr = ""); + void emitModularAdd(StringRef resultName, StringRef ctxName, + StringRef lhsName, StringRef rhsName); + void emitModularReduce(StringRef targetName); LogicalResult printMulOpHelper( Value result, Value lhs, Value rhs, Value ctx, Operation* op, diff --git a/lib/Target/JaxiteWord/JaxiteWordTemplates.h b/lib/Target/JaxiteWord/JaxiteWordTemplates.h index d979ef6c05..14c4da1958 100644 --- a/lib/Target/JaxiteWord/JaxiteWordTemplates.h +++ b/lib/Target/JaxiteWord/JaxiteWordTemplates.h @@ -10,6 +10,7 @@ namespace jaxiteword { constexpr std::string_view kModulePrelude = R"python( import jax import jax.numpy as jnp +import key_gen import numpy as np from ciphertext import Ciphertext from polynomial import Polynomial @@ -17,111 +18,69 @@ import ckks_ctx as ckks )python"; -// Template for GenMulKeyOp -// This template initializes HEMul for homomorphic multiplication and sets up -// relinearization. It computes r and c from the degree if they are not provided -// in the parameters. -constexpr std::string_view kGenMulKeyTemplate = R"python( -_degree = {0}.parameters.get('degree') -if _degree is not None: - _log_degree = int(math.log2(_degree)) - _half_k = _log_degree // 2 - _default_r = 2 ** _half_k - _default_c = _degree // _default_r -else: - _default_r = 4 - _default_c = 4 -{1} = HEMul( - batch={0}.parameters.get('batch', 1), - r={0}.parameters.get('r', _default_r), - c={0}.parameters.get('c', _default_c), - dnum={0}.parameters.get('dnum', 3), - num_eval_mult={0}.parameters.get('numEvalMult', 1), - original_moduli={0}.q_towers, - extend_moduli={0}.p_towers -) -{1}.control_gen(degree_layout=({0}.parameters.get('r', _default_r), {0}.parameters.get('c', _default_c))) -{1}.setup_relinearization(jnp.array({2}["a"], dtype=jnp.uint32).transpose(0,2,1), jnp.array({2}["b"], dtype=jnp.uint32).transpose(0,2,1)) -)python"; +constexpr std::string_view kEnsurePolyHelper = R"python( +def _ensure_poly(ctx, x, level=None): + _cache = ctx._param_cache + _r = _cache.r + _c = _cache.c + _m = _cache.num_q_at_level(level) if level is not None else None -// Template for GenRotKeyOp -// This template generates rotation keys for power-of-2 indices and initializes -// HERot. It computes r and c from the degree if they are not provided in the -// parameters. -constexpr std::string_view kGenRotKeyTemplate = R"python( -{0} = {{}} -_all_indices = [{1}] -_max_abs_rot_idx = max([abs(idx) for idx in _all_indices]) if _all_indices else 1 -_power_of_2_indices = [] -_pow2 = 1 -while _pow2 <= _max_abs_rot_idx: - _power_of_2_indices.append(_pow2) - _pow2 <<= 1 -_neg_power_of_2_indices = [-idx for idx in _power_of_2_indices] -_all_pow2_indices = _power_of_2_indices + _neg_power_of_2_indices + _data = x.polynomial if isinstance(x, Polynomial) else x + _m_in = _data.shape[-1] + if _m is None: + _m = _m_in + if _m > _m_in: + raise ValueError( + f"_ensure_poly: requested {_m} moduli but data only has {_m_in}" + ) -for _rot_idx in _all_pow2_indices: - {0}[_rot_idx] = key_gen.gen_rotation_key({2}, {3}.q_towers, {3}.p_towers, rot_index=_rot_idx, dnum={3}.parameters.get('dnum', 3)) + if level is not None: + _moduli = _cache.q_moduli_at_level(level) + else: + _moduli_src = getattr(x, "moduli", ctx.q_towers) + if isinstance(_moduli_src, (int, np.integer)): + _moduli_src = [int(_moduli_src)] + _moduli = list(_moduli_src)[:_m] -_degree_rot = {3}.parameters.get('degree') -if _degree_rot is not None: - _log_degree_rot = int(math.log2(_degree_rot)) - _default_r_rot = 1 << (_log_degree_rot // 2) - _default_c_rot = _degree_rot // _default_r_rot -else: - _default_r_rot = 4 - _default_c_rot = 4 -{4} = HERot( - r={3}.parameters.get('r', _default_r_rot), - c={3}.parameters.get('c', _default_c_rot), - dnum={3}.parameters.get('dnum', 3), - rotate_in_ciphertext_moduli={3}.q_towers, - extend_moduli={3}.p_towers -) -{4}.control_gen(batch=1, degree_layout=({3}.parameters.get('r', _default_r_rot), {3}.parameters.get('c', _default_c_rot))) -{5} = {0} -)python"; + # Return a fresh wrapper even when x is already tiled: emitted add/sub paths + # mutate the result object, so aliasing the source would violate SSA semantics. + _out = Polynomial( + { + "batch": _data.shape[0], + "num_elements": _data.shape[1], + "degree": ctx.degree, + "num_moduli": _m, + "precision": 32, + "degree_layout": (_r, _c), + }, + {"moduli": _moduli}, + ) + _out.polynomial = _data.reshape( + _data.shape[0], _data.shape[1], _r, _c, _m_in + )[..., :_m] + return _out -// Template for DecryptOp -// This template prepares the ciphertext for decryption by extracting the -// required moduli. -constexpr std::string_view kDecryptTemplate = R"python( -{0}.secret_key = {1} -_rescales = {2} -_num_moduli = len({0}.q_towers) - _rescales * {0}.composite_degree -_q_sub = {0}.q_towers[:_num_moduli] -_ct_for_dec = Polynomial( - {{'batch': 1, 'num_elements': 2, 'degree': {0}.degree, - 'precision': 32, 'num_moduli': _num_moduli, - 'degree_layout': ({0}.degree,)}}, - {{'moduli': _q_sub}}) -_ct_for_dec.set_batch_polynomial({3}.polynomial.reshape(1, 2, {0}.degree, _num_moduli)) -)python"; - -// Template for AddOp and AddInPlaceOp -// This template performs addition and modular reduction. -constexpr std::string_view kAddCoreTemplate = R"python( -{0}.add({1}) -{0}.ciphertext = jnp.where({0}.ciphertext >= {0}.moduli_array, {0}.ciphertext - {0}.moduli_array, {0}.ciphertext) -)python"; - -// Template for SubOp, SubInPlaceOp, and SubPlainOp -// This template performs subtraction and modular reduction. -// {1} should be rhs.ciphertext for Sub/SubInPlace and just rhs for SubPlain. -constexpr std::string_view kSubTemplate = R"python( -{0}.ciphertext = jnp.where({0}.ciphertext < {1}, {0}.ciphertext + {0}.moduli_array - {1}, {0}.ciphertext - {1}) -)python"; - -// Template for AddPlainOp modular reduction -constexpr std::string_view kAddModReduceTemplate = R"python( -{0}.ciphertext = jnp.where({0}.ciphertext >= {0}.moduli_array, {0}.ciphertext - {0}.moduli_array, {0}.ciphertext) -)python"; +def _assign_poly(dst, src): + for _attr in ( + "batch", + "num_elements", + "num_moduli", + "degree", + "precision", + "degree_layout", + "r", + "c", + "moduli", + "moduli_array", + "ntt_ctx", + "shape_in_ntt_all_limbs", + ): + if hasattr(src, _attr): + setattr(dst, _attr, getattr(src, _attr)) + dst.polynomial = src.polynomial + if hasattr(src, "extend_polynomial"): + dst.extend_polynomial = src.extend_polynomial -// Template for RelinOp -constexpr std::string_view kRelinTemplate = R"python( -{0} = {1}.he_mul[{2}].relinearize({3}) -_s = {0}.polynomial.shape -{0}.polynomial = {0}.polynomial.reshape(_s[0], _s[1], {0}.degree, _s[-1]) )python"; } // namespace jaxiteword diff --git a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir index 6ba1dcd271..2258a2f372 100644 --- a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir +++ b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir @@ -33,10 +33,15 @@ module attributes {ckks.schemeParam = #ckks.scheme_param diff --git a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir index e0d1df3ceb..431f13425f 100644 --- a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir +++ b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir @@ -32,6 +32,9 @@ module { } // CHECK: @simple_mul__generate_crypto_context +// CHECK-SAME: !jaxiteword.public_key +// CHECK-SAME: !jaxiteword.private_key +// CHECK-SAME: !jaxiteword.eval_key // CHECK: jaxiteword.gen_params // CHECK-SAME: batch = 1 : i32 // CHECK-SAME: c = 4 : i32 diff --git a/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir b/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir index cf1a999229..911993a2c8 100644 --- a/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir +++ b/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir @@ -15,7 +15,7 @@ func.func @test_add(%ctx: !jaxiteword.crypto_context<>, %ct1 : !ct_L1, %ct2 : !c } // CHECK: def test_mul( -// CHECK: hemul( +// CHECK: .he_mul[ func.func @test_mul(%ctx: !jaxiteword.crypto_context<>, %ct1 : !ct_L1, %ct2 : !ct_L1) -> !ct_L1 { %pk, %sk = jaxiteword.gen_keypair %ctx : (!jaxiteword.crypto_context<>) -> (!jaxiteword.public_key<>, !jaxiteword.private_key<>) %ek = jaxiteword.gen_mulkey %ctx, %sk : (!jaxiteword.crypto_context<>, !jaxiteword.private_key<>) -> !jaxiteword.eval_key<> @@ -32,8 +32,14 @@ func.func @test_mul_no_relin(%ctx: !jaxiteword.crypto_context<>, %ct1 : !ct_L1, // CHECK: def test_gen_params( // CHECK: "scaling_factor": 563019763943521 -func.func @test_gen_params() -> !jaxiteword.crypto_context<> { - %ctx = jaxiteword.gen_params { +// CHECK: "public_key": +// CHECK: "secret_key": +// CHECK: "evaluation_key": +func.func @test_gen_params( + %pk: !jaxiteword.public_key<>, + %sk: !jaxiteword.private_key<>, + %ek: !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> { + %ctx = jaxiteword.gen_params %pk, %sk, %ek { degree = 8192 : i64, numSlots = 4096 : i64, scalingFactor = 563019763943521.0 : f64, @@ -45,6 +51,6 @@ func.func @test_gen_params() -> !jaxiteword.crypto_context<> { dnum = 3 : i32, numEvalMult = 2 : i32, compositeDegree = 1 : i32 - } : () -> !jaxiteword.crypto_context<> + } : (!jaxiteword.public_key<>, !jaxiteword.private_key<>, !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> return %ctx : !jaxiteword.crypto_context<> }