From abe375dfa66e3f34a977f3fb860bf911c0e778e4 Mon Sep 17 00:00:00 2001 From: Zohaib58 Date: Wed, 17 Jun 2026 05:42:27 +0000 Subject: [PATCH 1/3] Thread JaxiteWord keys through context generation and fix emitted ciphertext ops Change jaxiteword.gen_params to take public, secret, and evaluation keys, and emit them into the generated ckks.CKKSParameters. Update the crypto-context configuration pass so __generate_crypto_context accepts those three keys and __configure_crypto_context only takes the crypto context. Teach LWE-to-JaxiteWord lowering to add crypto context/eval-key arguments to functions with crypto-typed arguments, update func.call operands for converted callees, and lower CKKS rotations when a dynamic shift is defined by an arith constant. Rework the JaxiteWord Python emitter around tiled Polynomial values: - emit func.call operations - add _ensure_poly and _assign_poly helpers - normalize ciphertext operands before add/sub/mul/square/rotate/rescale/decrypt - use he_mul[level].mul for relined multiplication - emit he_rescale for modulus reduction across levels - implement plaintext add/sub against the c0 limb - generate rotation keys via key_gen.gen_rotation_key - store key material in generated CKKS params instead of mutating program init Update JaxiteWord emitter/configure-context FileCheck tests and add 8x8 matvec MLIR/Python artifacts for validating the emitted JaxiteWord path. --- bazel-heir-private | 1 + lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td | 4 +- .../Transforms/ConfigureCryptoContext.cpp | 19 +- lib/Dialect/JaxiteWord/Transforms/Passes.td | 6 +- .../LWE/Conversions/LWEToJaxiteWord/BUILD | 1 + .../LWEToJaxiteWord/LWEToJaxiteWord.cpp | 160 +- lib/Target/JaxiteWord/JaxiteWordEmitter.cpp | 412 +- lib/Target/JaxiteWord/JaxiteWordEmitter.h | 31 +- lib/Target/JaxiteWord/JaxiteWordTemplates.h | 159 +- matvec_8x8.mlir | 105 + matvec_8x8_cross.py | 2503 ++++ matvec_8x8_jaxite.mlir | 465 + matvec_8x8_jaxiteword.mlir | 0 matvec_8x8_jaxiteword.py | 11122 ++++++++++++++++ .../Transforms/configure_crypto_context.mlir | 5 + .../configure_crypto_context_defaults.mlir | 3 + tests/Emitter/JaxiteWord/emit_jaxiteword.mlir | 14 +- 17 files changed, 14741 insertions(+), 269 deletions(-) create mode 120000 bazel-heir-private create mode 100644 matvec_8x8.mlir create mode 100644 matvec_8x8_cross.py create mode 100644 matvec_8x8_jaxite.mlir create mode 100644 matvec_8x8_jaxiteword.mlir create mode 100644 matvec_8x8_jaxiteword.py diff --git a/bazel-heir-private b/bazel-heir-private new file mode 120000 index 0000000000..1786659bb8 --- /dev/null +++ b/bazel-heir-private @@ -0,0 +1 @@ +/home/zohaib/.cache/bazel/_bazel_zohaib/f8ad823d70143bc66e0160e8a7bf9f07/execroot/_main \ No newline at end of file diff --git a/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td index a5a8c8dffb..f787ef719b 100644 --- a/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td +++ b/lib/Dialect/JaxiteWord/IR/JaxiteWordOps.td @@ -80,6 +80,9 @@ def GenParamsOp : JaxiteWord_Op<"gen_params"> { - numEvalMult: Number of evaluation multiplications }]; let arguments = (ins + JaxiteWord_PublicKey:$publicKey, + JaxiteWord_PrivateKey:$secretKey, + JaxiteWord_EvalKey:$evaluationKey, // Scheme parameters I64Attr:$degree, I64Attr:$numSlots, @@ -134,7 +137,6 @@ def ProgramInitializationOp : JaxiteWord_Op<"program_initialization"> { }]; let arguments = (ins JaxiteWord_CryptoContext:$cryptoContext, - JaxiteWord_PrivateKey:$secretKey, I64Attr:$totalHemulLevels, DenseI64ArrayAttr:$totalRotationIndices, I32Attr:$dnum, diff --git a/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp b/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp index 367cee033b..06b1064fba 100644 --- a/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp +++ b/lib/Dialect/JaxiteWord/Transforms/ConfigureCryptoContext.cpp @@ -90,15 +90,24 @@ struct ConfigureCryptoContext LogicalResult generateGenFunc(func::FuncOp op, const std::string& genFuncName, ImplicitLocOpBuilder& builder) { Type ccType = CryptoContextType::get(builder.getContext()); + Type pkType = PublicKeyType::get(builder.getContext()); + Type skType = PrivateKeyType::get(builder.getContext()); + Type ekType = EvalKeyType::get(builder.getContext()); + + SmallVector funcArgTypes = {pkType, skType, ekType}; SmallVector funcResultTypes = {ccType}; FunctionType genFuncType = - FunctionType::get(builder.getContext(), {}, funcResultTypes); + FunctionType::get(builder.getContext(), funcArgTypes, funcResultTypes); auto genFuncOp = func::FuncOp::create(builder, genFuncName, genFuncType); builder.setInsertionPointToEnd(genFuncOp.addEntryBlock()); + Value publicKey = genFuncOp.getArgument(0); + Value secretKey = genFuncOp.getArgument(1); + Value evaluationKey = genFuncOp.getArgument(2); + Value cryptoContext = GenParamsOp::create( - builder, ccType, + builder, ccType, publicKey, secretKey, evaluationKey, /*degree=*/static_cast(config.degree), /*numSlots=*/static_cast(config.numSlots), /*scalingFactor=*/llvm::APFloat(config.scalingFactor), @@ -119,9 +128,8 @@ struct ConfigureCryptoContext const std::string& configFuncName, ImplicitLocOpBuilder& builder) { Type ccType = CryptoContextType::get(builder.getContext()); - Type skType = PrivateKeyType::get(builder.getContext()); - SmallVector funcArgTypes = {ccType, skType}; + SmallVector funcArgTypes = {ccType}; SmallVector funcResultTypes; FunctionType configFuncType = @@ -131,10 +139,9 @@ struct ConfigureCryptoContext builder.setInsertionPointToEnd(configFuncOp.addEntryBlock()); Value cryptoContext = configFuncOp.getArgument(0); - Value secretKey = configFuncOp.getArgument(1); ProgramInitializationOp::create( - builder, cryptoContext, secretKey, + builder, cryptoContext, /*totalHemulLevels=*/static_cast(config.mulDepth), /*totalRotationIndices=*/config.rotIndices, /*dnum=*/config.dnum, diff --git a/lib/Dialect/JaxiteWord/Transforms/Passes.td b/lib/Dialect/JaxiteWord/Transforms/Passes.td index f9ab1ebdba..9112a118ae 100644 --- a/lib/Dialect/JaxiteWord/Transforms/Passes.td +++ b/lib/Dialect/JaxiteWord/Transforms/Passes.td @@ -16,10 +16,12 @@ def ConfigureCryptoContext : Pass<"jaxiteword-configure-crypto-context"> { For example, for an MLIR function `@my_func`, the generated helpers have the following signatures: ```mlir - func.func @my_func__generate_crypto_context() -> !jaxiteword.crypto_context + func.func @my_func__generate_crypto_context( + !jaxiteword.public_key, !jaxiteword.private_key, + !jaxiteword.eval_key) -> !jaxiteword.crypto_context func.func @my_func__configure_crypto_context( - !jaxiteword.crypto_context, !jaxiteword.private_key) + !jaxiteword.crypto_context) ``` }]; let dependentDialects = ["mlir::heir::jaxiteword::JaxiteWordDialect"]; diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD index 49ee33ca34..e69178735d 100644 --- a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/BUILD @@ -18,6 +18,7 @@ cc_library( "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Utils", "@heir//lib/Utils:ConversionUtils", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", diff --git a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp index 8949062aa8..2cc6ed89da 100644 --- a/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToJaxiteWord/LWEToJaxiteWord.cpp @@ -15,6 +15,7 @@ #include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Utils/ConversionUtils.h" #include "lib/Utils/Utils.h" +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -53,6 +54,60 @@ class JaxiteWordTypeConverter : public TypeConverter { namespace { +bool containsCryptoArgument(func::FuncOp funcOp) { + return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) { + return DialectEqual()( + &getElementTypeOrSelf(argType).getDialect()); + }); +} + +bool funcNeedsCryptoContextAndKeys(func::FuncOp funcOp) { + return containsDialects( + funcOp) || + containsCryptoArgument(funcOp); +} + +void insertCryptoContextAndKeys(func::FuncOp funcOp) { + if (!funcNeedsCryptoContextAndKeys(funcOp)) return; + if (funcOp.getFunctionType().getNumInputs() >= 2 && + mlir::isa( + funcOp.getFunctionType().getInput(0)) && + mlir::isa( + funcOp.getFunctionType().getInput(1))) { + return; + } + auto cryptoContextType = + jaxiteword::CryptoContextType::get(funcOp.getContext()); + auto evalKeyType = jaxiteword::EvalKeyType::get(funcOp.getContext()); + (void)funcOp.insertArgument(0, evalKeyType, nullptr, funcOp.getLoc()); + (void)funcOp.insertArgument(0, cryptoContextType, nullptr, funcOp.getLoc()); +} + +void updateCryptoFuncCalls(Operation* op) { + op->walk([&](func::CallOp callOp) { + auto callee = getCalledFunction(callOp); + if (failed(callee) || !funcNeedsCryptoContextAndKeys(callee.value())) { + return; + } + if (callOp.getNumOperands() == callee.value().getNumArguments()) { + return; + } + auto caller = callOp->getParentOfType(); + if (!caller || caller.getNumArguments() < 2 || + !mlir::isa( + caller.getArgument(0).getType()) || + !mlir::isa(caller.getArgument(1).getType())) { + return; + } + SmallVector newOperands; + newOperands.push_back(caller.getArgument(0)); + newOperands.push_back(caller.getArgument(1)); + newOperands.append(callOp.getOperands().begin(), + callOp.getOperands().end()); + callOp->setOperands(newOperands); + }); +} + FailureOr getContextualCryptoContextForJaxiteWord(Operation* op) { auto funcOp = op->getParentOfType(); if (!funcOp) return failure(); @@ -75,6 +130,25 @@ FailureOr getContextualEvalKeyForJaxiteWord(Operation* op) { return funcOp.getArgument(1); } +static FailureOr getStaticRotationIndex(ckks::RotateOp op, + Value dynamicShift) { + auto i64Type = IntegerType::get(op.getContext(), 64); + if (IntegerAttr staticShift = op.getStaticShiftAttr()) { + return IntegerAttr::get(i64Type, staticShift.getValue().getSExtValue()); + } + if (!dynamicShift) { + return failure(); + } + auto constOp = dynamicShift.getDefiningOp(); + if (!constOp) { + return failure(); + } + if (auto intAttr = dyn_cast(constOp.getValue())) { + return IntegerAttr::get(i64Type, intAttr.getValue().getSExtValue()); + } + return failure(); +} + struct AddCryptoContextAndKeys : public OpConversionPattern { AddCryptoContextAndKeys(mlir::MLIRContext* context) : OpConversionPattern(context, /* benefit= */ 2) {} @@ -84,10 +158,7 @@ struct AddCryptoContextAndKeys : public OpConversionPattern { LogicalResult matchAndRewrite( func::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto containsCryptoOps = - ::mlir::heir::containsDialects(op); - if (!containsCryptoOps) return failure(); + if (!funcNeedsCryptoContextAndKeys(op)) return failure(); auto cryptoContextType = jaxiteword::CryptoContextType::get(getContext()); auto evalKeyType = jaxiteword::EvalKeyType::get(getContext()); @@ -108,6 +179,43 @@ struct AddCryptoContextAndKeys : public OpConversionPattern { } }; +struct ConvertFuncCallOp : public OpConversionPattern { + ConvertFuncCallOp(mlir::MLIRContext* context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::CallOp op, typename func::CallOp::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto callee = getCalledFunction(op); + if (failed(callee) || !funcNeedsCryptoContextAndKeys(callee.value())) { + return failure(); + } + if (op.getNumOperands() >= callee.value().getNumArguments()) { + return failure(); + } + + FailureOr ctx = getContextualCryptoContextForJaxiteWord(op); + if (failed(ctx)) return failure(); + FailureOr evalKey = getContextualEvalKeyForJaxiteWord(op); + if (failed(evalKey)) return failure(); + + SmallVector newOperands; + newOperands.push_back(ctx.value()); + newOperands.push_back(evalKey.value()); + newOperands.append(adaptor.getOperands().begin(), + adaptor.getOperands().end()); + + SmallVector dialectAttrs(op->getDialectAttrs()); + rewriter + .replaceOpWithNewOp(op, op.getCallee(), + op.getResultTypes(), newOperands) + ->setDialectAttrs(dialectAttrs); + return success(); + } +}; + template struct ConvertBinOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -173,20 +281,16 @@ struct ConvertRotateOp : public OpConversionPattern { FailureOr evalKey = getContextualEvalKeyForJaxiteWord(op); if (failed(evalKey)) return failure(); - Value dynamicShift = adaptor.getDynamicShift(); - IntegerAttr staticShift = op.getStaticShiftAttr(); - if (!staticShift && !dynamicShift) { + FailureOr indexAttr = + getStaticRotationIndex(op, adaptor.getDynamicShift()); + if (failed(indexAttr)) { return rewriter.notifyMatchFailure( - op, "rotate op must have either static or dynamic shift"); - } - if (dynamicShift) { - return rewriter.notifyMatchFailure( - op, "jaxiteword rotation requires static shift"); + op, "jaxiteword rotation requires statically known shift"); } rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getOutput().getType()), - ctx.value(), adaptor.getInput(), evalKey.value(), staticShift); + ctx.value(), adaptor.getInput(), evalKey.value(), indexAttr.value()); return success(); } }; @@ -298,6 +402,9 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase { MLIRContext* context = &getContext(); Operation* op = getOperation(); + op->walk([&](func::FuncOp funcOp) { insertCryptoContextAndKeys(funcOp); }); + updateCryptoFuncCalls(op); + RewritePatternSet patterns(context); ConversionTarget target(*context); @@ -309,17 +416,25 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase { JaxiteWordTypeConverter typeConverter(context); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - auto containsCryptoOps = - ::mlir::heir::containsDialects(op); - if (!containsCryptoOps) return true; - bool hasArgs = op.getFunctionType().getNumInputs() >= 2; - return typeConverter.isSignatureLegal(op.getFunctionType()) && hasArgs && + target.addDynamicallyLegalOp([&](func::FuncOp funcOp) { + if (!funcNeedsCryptoContextAndKeys(funcOp)) return true; + bool hasArgs = funcOp.getFunctionType().getNumInputs() >= 2; + return typeConverter.isSignatureLegal(funcOp.getFunctionType()) && + hasArgs && mlir::isa( - op.getFunctionType().getInput(0)) && + funcOp.getFunctionType().getInput(0)) && mlir::isa( - op.getFunctionType().getInput(1)); + funcOp.getFunctionType().getInput(1)); + }); + + target.addDynamicallyLegalOp([&](func::CallOp callOp) { + if (auto callee = getCalledFunction(callOp); succeeded(callee)) { + if (funcNeedsCryptoContextAndKeys(callee.value())) { + return callOp.getNumOperands() == + callOp.getCalleeType().getNumInputs(); + } + } + return true; }); populateFunctionOpInterfaceTypeConversionPattern( @@ -327,6 +442,7 @@ struct LWEToJaxiteWord : public impl::LWEToJaxiteWordBase { addTensorConversionPatterns(typeConverter, patterns, target); patterns.add(typeConverter, context); + patterns.add(context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, diff --git a/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp b/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp index 0122cb4517..381e6f7b4c 100644 --- a/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp +++ b/lib/Target/JaxiteWord/JaxiteWordEmitter.cpp @@ -30,6 +30,7 @@ #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project @@ -115,7 +116,7 @@ LogicalResult JaxiteWordEmitter::translate(Operation& op) { LogicalResult status = llvm::TypeSwitch(op) .Case([&](auto op) { return printOperation(op); }) - .Case( + .Case( [&](auto op) { return printOperation(op); }) .Case(arg.getType())) { - CiphertextArg_ = argName; - } } os.unindent(); os << ")"; @@ -213,6 +212,26 @@ LogicalResult JaxiteWordEmitter::printOperation(func::FuncOp funcOp) { return success(); } +LogicalResult JaxiteWordEmitter::printOperation(func::CallOp op) { + if (op.getNumResults() == 1) { + emitAssignPrefix(op.getResult(0)); + } else if (op.getNumResults() > 1) { + os << "("; + for (auto [idx, result] : llvm::enumerate(op.getResults())) { + if (idx > 0) os << ", "; + os << variableNames->getNameForValue(result); + } + os << ") = "; + } + + os << op.getCallee().str() << "("; + os << commaSeparatedValues(op.getOperands(), [&](Value value) { + return variableNames->getNameForValue(value); + }); + os << ")\n"; + return success(); +} + LogicalResult JaxiteWordEmitter::printOperation(func::ReturnOp op) { std::function resultValue = [&](Value value) { if (isa(value)) { @@ -236,38 +255,56 @@ LogicalResult JaxiteWordEmitter::printOperation(func::ReturnOp op) { } LogicalResult JaxiteWordEmitter::printOperation(AddOp op) { - return printBinaryOpHelper( - op.getResult(), op.getLhs(), op.getRhs(), - [&](StringRef lhs, StringRef rhs, StringRef result) { - os << lhs << "\n"; - os << llvm::formatv(kAddCoreTemplate.data(), result, rhs); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + auto result = variableNames->getNameForValue(op.getResult()); + + emitModularAdd(result, ctx, lhs, rhs); + return success(); } LogicalResult JaxiteWordEmitter::printOperation(SubOp op) { - return printBinaryOpHelper( - op.getResult(), op.getLhs(), op.getRhs(), - [&](StringRef lhs, StringRef rhs, StringRef result) { - os << lhs << "\n"; - std::string rhsCiphertext = (rhs + ".ciphertext").str(); - os << llvm::formatv(kSubTemplate.data(), result, rhsCiphertext); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string rhsWork = (result + "_rhs"); + + emitNormalizeCiphertext(result, ctx, lhs); + emitNormalizeCiphertext(rhsWork, ctx, rhs); + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << result << ".polynomial = jnp.where(" << result << ".polynomial < " + << rhsWork << ".polynomial, " << result << ".polynomial + _moduli - " + << rhsWork << ".polynomial, " << result << ".polynomial - " << rhsWork + << ".polynomial)\n"; + return success(); } LogicalResult JaxiteWordEmitter::printOperation(NegateOp op) { - emitAssignPrefix(op.getResult()); - os << variableNames->getNameForValue(op.getCiphertext()) << ".mul(-1)\n"; + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto ct = variableNames->getNameForValue(op.getCiphertext()); + auto result = variableNames->getNameForValue(op.getResult()); + + emitNormalizeCiphertext(result, ctx, ct); + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << result << ".polynomial = jnp.where(" << result << ".polynomial == 0, " + << result << ".polynomial, _moduli - " << result << ".polynomial)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(SquareOp op) { auto ct = variableNames->getNameForValue(op.getCiphertext()); auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto result = variableNames->getNameForValue(op.getResult()); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/1); + std::string inputLevel = ("(" + level + ") + 1"); + std::string ctWork = (result + "_arg"); - emitAssignPrefix(op.getResult()); - os << ctx << ".he_mul[" << level << "].mul(" << ct << ", " << ct << ")\n"; + emitNormalizeCiphertext(ctWork, ctx, ct, inputLevel); + os << result << " = " << ctx << ".he_mul[" << level << "].mul(" << ctWork + << ", " << ctWork << ")\n"; return success(); } @@ -275,26 +312,72 @@ void JaxiteWordEmitter::emitAssignPrefix(Value result) { os << variableNames->getNameForValue(result) << " = "; } -LogicalResult JaxiteWordEmitter::printBinaryOpHelper( - Value result, Value lhs, Value rhs, - llvm::function_ref callback) { - auto lhsName = variableNames->getNameForValue(lhs); - auto rhsName = variableNames->getNameForValue(rhs); - auto resultName = variableNames->getNameForValue(result); +void JaxiteWordEmitter::emitAssignCiphertext(StringRef targetName, + StringRef sourceName) { + os << "_assign_poly(" << targetName << ", " << sourceName << ")\n"; +} - emitAssignPrefix(result); - callback(lhsName, rhsName, resultName); - return success(); +void JaxiteWordEmitter::emitNormalizeCiphertext(StringRef resultName, + StringRef ctxName, + StringRef sourceName, + StringRef levelExpr) { + os << resultName << " = _ensure_poly(" << ctxName << ", " << sourceName; + if (!levelExpr.empty()) os << ", " << levelExpr; + os << ")\n"; } -LogicalResult JaxiteWordEmitter::printInPlaceBinaryOpHelper( - Value lhs, Value rhs, - llvm::function_ref callback) { - auto lhsName = variableNames->getNameForValue(lhs); - auto rhsName = variableNames->getNameForValue(rhs); +void JaxiteWordEmitter::emitModularAdd(StringRef resultName, StringRef ctxName, + StringRef lhsName, StringRef rhsName) { + std::string lhsData = (resultName + "_lhs").str(); + std::string rhsData = (resultName + "_rhs").str(); + std::string numModuli = (resultName + "_num_moduli").str(); + std::string moduliSrc = (resultName + "_moduli_src").str(); + std::string moduli = (resultName + "_moduli").str(); + std::string sum = (resultName + "_sum").str(); + + os << lhsData << " = " << lhsName << ".polynomial if hasattr(" << lhsName + << ", \"polynomial\") else " << lhsName << "\n"; + os << rhsData << " = " << rhsName << ".polynomial if hasattr(" << rhsName + << ", \"polynomial\") else " << rhsName << "\n"; + os << lhsData << " = " << lhsData << ".reshape(" << lhsData << ".shape[0], " + << lhsData << ".shape[1], " << ctxName << "._param_cache.r, " << ctxName + << "._param_cache.c, " << lhsData << ".shape[-1])\n"; + os << rhsData << " = " << rhsData << ".reshape(" << rhsData << ".shape[0], " + << rhsData << ".shape[1], " << ctxName << "._param_cache.r, " << ctxName + << "._param_cache.c, " << rhsData << ".shape[-1])\n"; + os << "if " << lhsData << ".shape != " << rhsData << ".shape:\n"; + os.indent(); + os << "raise ValueError(\"ciphertext add shape mismatch\")\n"; + os.unindent(); + os << numModuli << " = " << lhsData << ".shape[-1]\n"; + os << "if hasattr(" << lhsName << ", \"moduli\") and hasattr(" << rhsName + << ", \"moduli\"):\n"; + os.indent(); + os << "if list(" << lhsName << ".moduli)[:" << numModuli << "] != list(" + << rhsName << ".moduli)[:" << numModuli << "]:\n"; + os.indent(); + os << "raise ValueError(\"ciphertext add modulus mismatch\")\n"; + os.unindent(); + os.unindent(); + os << moduliSrc << " = getattr(" << lhsName << ", \"moduli\", getattr(" + << rhsName << ", \"moduli\", " << ctxName << ".q_towers))\n"; + os << "if isinstance(" << moduliSrc << ", (int, np.integer)):\n"; + os.indent(); + os << moduliSrc << " = [" << moduliSrc << "]\n"; + os.unindent(); + os << moduli << " = jnp.array(list(" << moduliSrc << ")[:" << numModuli + << "], dtype=jnp.uint64)\n"; + os << sum << " = " << lhsData << ".astype(jnp.uint64) + " << rhsData + << ".astype(jnp.uint64)\n"; + os << resultName << " = jnp.where(" << sum << " >= " << moduli << ", " << sum + << " - " << moduli << ", " << sum << ").astype(jnp.uint32)\n"; +} - callback(lhsName, rhsName); - return success(); +void JaxiteWordEmitter::emitModularReduce(StringRef targetName) { + os << "_moduli = jnp.array(" << targetName << ".moduli, dtype=jnp.uint32)\n"; + os << targetName << ".polynomial = jnp.where(" << targetName + << ".polynomial >= _moduli, " << targetName << ".polynomial - _moduli, " + << targetName << ".polynomial)\n"; } LogicalResult JaxiteWordEmitter::printMulOpHelper( @@ -308,7 +391,6 @@ LogicalResult JaxiteWordEmitter::printMulOpHelper( auto resultName = variableNames->getNameForValue(result); auto level = getCrossLevelExpr(lhs, ctxName, op, /*extraOffset=*/1); - emitAssignPrefix(result); callback(lhsName, rhsName, ctxName, resultName, level); return success(); } @@ -323,64 +405,96 @@ LogicalResult JaxiteWordEmitter::printOperation(EncodeOp op) { LogicalResult JaxiteWordEmitter::printOperation(EncryptOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto pk = variableNames->getNameForValue(op.getPublicKey()); - os << ctx << ".public_key = " << pk << "\n"; + auto pt = variableNames->getNameForValue(op.getPlaintext()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string raw = result + "_raw"; - emitAssignPrefix(op.getResult()); - os << ctx << ".encrypt(" << variableNames->getNameForValue(op.getPlaintext()) - << ")\n"; + os << ctx << ".public_key = " << pk << "\n"; + os << raw << " = " << ctx << ".encrypt(" << pt << ")\n"; + emitNormalizeCiphertext(result, ctx, raw); return success(); } LogicalResult JaxiteWordEmitter::printOperation(MulOp op) { - return printMulOpHelper(op.getResult(), op.getLhs(), op.getRhs(), - op.getCryptoContext(), op.getOperation(), - [&](StringRef lhs, StringRef rhs, StringRef ctx, - StringRef result, StringRef level) { - os << ctx << ".he_mul[" << level << "].hemul(" - << lhs << ", " << rhs << ")\n"; - }); + return printMulOpHelper( + op.getResult(), op.getLhs(), op.getRhs(), op.getCryptoContext(), + op.getOperation(), + [&](StringRef lhs, StringRef rhs, StringRef ctx, StringRef result, + StringRef level) { + std::string inputLevel = ("(" + level + ") + 1").str(); + std::string lhsWork = (result + "_lhs").str(); + std::string rhsWork = (result + "_rhs").str(); + emitNormalizeCiphertext(lhsWork, ctx, lhs, inputLevel); + emitNormalizeCiphertext(rhsWork, ctx, rhs, inputLevel); + os << result << " = " << ctx << ".he_mul[" << level << "].mul(" + << lhsWork << ", " << rhsWork << ")\n"; + }); } LogicalResult JaxiteWordEmitter::printOperation(MulNoRelinOp op) { - return printMulOpHelper(op.getResult(), op.getLhs(), op.getRhs(), - op.getCryptoContext(), op.getOperation(), - [&](StringRef lhs, StringRef rhs, StringRef ctx, - StringRef result, StringRef level) { - os << ctx << ".he_mul[" << level - << "].hemul_no_relin(" << lhs << ", " << rhs - << ")\n"; - }); + return printMulOpHelper( + op.getResult(), op.getLhs(), op.getRhs(), op.getCryptoContext(), + op.getOperation(), + [&](StringRef lhs, StringRef rhs, StringRef ctx, StringRef result, + StringRef level) { + std::string inputLevel = ("(" + level + ") + 1").str(); + std::string lhsWork = (result + "_lhs").str(); + std::string rhsWork = (result + "_rhs").str(); + std::string raw = (result + "_raw").str(); + emitNormalizeCiphertext(lhsWork, ctx, lhs, inputLevel); + emitNormalizeCiphertext(rhsWork, ctx, rhs, inputLevel); + os << raw << " = " << ctx << ".he_mul[" << level << "].hemul_no_relin(" + << lhsWork << ", " << rhsWork << ")\n"; + emitNormalizeCiphertext(result, ctx, raw, level); + }); } LogicalResult JaxiteWordEmitter::printOperation(RelinOp op) { auto ct = variableNames->getNameForValue(op.getCiphertext()); auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto result = variableNames->getNameForValue(op.getOutput()); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/1); + std::string ctData = result + "_ct_data"; - auto result = variableNames->getNameForValue(op.getOutput()); - os << llvm::formatv(kRelinTemplate.data(), result, ctx, level, ct); - + os << ctData << " = " << ct << ".polynomial if hasattr(" << ct + << ", \"polynomial\") else " << ct << "\n"; + os << result << " = " << ctx << ".he_mul[" << level << "].relinearize(" + << ctData << ")\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(ModReduceOp op) { + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto ct = variableNames->getNameForValue(op.getCiphertext()); - emitAssignPrefix(op.getResult()); - os << ct << "\n"; + auto result = variableNames->getNameForValue(op.getResult()); + auto srcLevel = + getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/0); + auto dstLevel = getCrossLevelExpr(op.getResult(), ctx, op, /*extraOffset=*/0); + std::string ctWork = result + "_arg"; + + emitNormalizeCiphertext(ctWork, ctx, ct, srcLevel); + if (srcLevel == dstLevel) { + emitNormalizeCiphertext(result, ctx, ctWork, dstLevel); + return success(); + } + os << result << " = " << ctx << ".he_rescale[" << srcLevel << ", " << dstLevel + << "](" << ctWork << ")\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(RotOp op) { auto ct = variableNames->getNameForValue(op.getCiphertext()); auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto result = variableNames->getNameForValue(op.getResult()); auto rotIndex = op.getIndex(); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/0); + std::string ctWork = result + "_arg"; - emitAssignPrefix(op.getResult()); - os << ctx << ".he_rot[" << level << ", " << rotIndex << "].rotate(" << ct - << ")\n"; + emitNormalizeCiphertext(ctWork, ctx, ct, level); + os << result << " = " << ctx << ".he_rot[" << level << ", " << rotIndex + << "].rotate(" << ctWork << ")\n"; return success(); } @@ -388,14 +502,23 @@ LogicalResult JaxiteWordEmitter::printOperation(DecryptOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto sk = variableNames->getNameForValue(op.getSecretKey()); auto ct = variableNames->getNameForValue(op.getCiphertext()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string ctWork = result + "_ct"; - auto ctType = cast(op.getCiphertext().getType()); - int current = ctType.getModulusChain().getCurrent(); - int maxCurrent = getMaxCurrentInModule(op); - int rescales = maxCurrent - current; - - os << llvm::formatv(kDecryptTemplate.data(), ctx, sk, rescales, ct); - + os << ctx << ".secret_key = " << sk << "\n"; + emitNormalizeCiphertext(ctWork, ctx, ct); + os << "_num_moduli = " << ctWork << ".polynomial.shape[-1]\n"; + os << "_q_sub = list(getattr(" << ctWork << ", \"moduli\", " << ctx + << ".q_towers))[:_num_moduli]\n"; + os << "_ct_for_dec = Polynomial({\"batch\": " << ctWork + << ".polynomial.shape[0], \"num_elements\": " << ctWork + << ".polynomial.shape[1], \"degree\": " << ctx + << ".degree, \"precision\": 32, \"num_moduli\": _num_moduli, " + "\"degree_layout\": (" + << ctx << ".degree,)}, {\"moduli\": _q_sub})\n"; + os << "_ct_for_dec.set_batch_polynomial(" << ctWork << ".polynomial.reshape(" + << ctWork << ".polynomial.shape[0], " << ctWork << ".polynomial.shape[1], " + << ctx << ".degree, _num_moduli))\n"; emitAssignPrefix(op.getResult()); os << ctx << ".decrypt(_ct_for_dec)\n"; return success(); @@ -424,48 +547,114 @@ LogicalResult JaxiteWordEmitter::printOperation(MulPlainOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto ct = variableNames->getNameForValue(op.getCiphertext()); auto pt = variableNames->getNameForValue(op.getPlaintext()); + auto result = variableNames->getNameForValue(op.getResult()); auto level = getCrossLevelExpr(op.getCiphertext(), ctx, op, /*extraOffset=*/0); - - os << ctx << ".ptct_mul[" << level << "].set_plaintext(" << pt << ")\n"; - emitAssignPrefix(op.getResult()); - os << ctx << ".ptct_mul[" << level << "].mul(" << ct << ")\n"; + std::string ctWork = result + "_arg"; + std::string ptNtt = result + "_pt_ntt"; + std::string opName = result + "_ptct"; + + emitNormalizeCiphertext(ctWork, ctx, ct, level); + os << ptNtt << " = " << pt << ".polynomial[0, 0, :, :" << ctWork + << ".polynomial.shape[-1]].reshape(" << ctWork << ".r, " << ctWork + << ".c, " << ctWork << ".polynomial.shape[-1]).astype(jnp.uint32)\n"; + os << opName << " = " << ctx << ".ptct_mul[" << level << "]\n"; + os << opName << ".set_plaintext(" << ptNtt << ")\n"; + os << result << " = " << opName << ".mul(" << ctWork << ", use_bat=False)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(AddPlainOp op) { - auto lhs = variableNames->getNameForValue(op.getLhs()); - auto rhs = variableNames->getNameForValue(op.getRhs()); - os << lhs << ".ciphertext = " << lhs << ".ciphertext + " << rhs << "\n"; - os << llvm::formatv(kAddModReduceTemplate.data(), lhs); - - emitAssignPrefix(op.getResult()); - os << lhs << "\n"; + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + bool lhsPlain = + isa(getElementTypeOrSelf(op.getLhs().getType())); + auto ct = + variableNames->getNameForValue(lhsPlain ? op.getRhs() : op.getLhs()); + auto pt = + variableNames->getNameForValue(lhsPlain ? op.getLhs() : op.getRhs()); + auto result = variableNames->getNameForValue(op.getResult()); + std::string numModuli = result + "_m"; + std::string ptData = result + "_pt_data"; + + emitNormalizeCiphertext(result, ctx, ct); + os << numModuli << " = " << result << ".polynomial.shape[-1]\n"; + os << ptData << " = " << pt << ".polynomial[0:1, 0:1, :, :" << numModuli + << "].reshape(1, 1, " << result << ".r, " << result << ".c, " << numModuli + << ")\n"; + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << "_c0 = " << result << ".polynomial[:, 0:1, ...] + " << ptData << "\n"; + os << "_c0 = jnp.where(_c0 >= _moduli, _c0 - _moduli, _c0)\n"; + os << result << ".polynomial = " << result + << ".polynomial.at[:, 0:1, ...].set(_c0)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(SubPlainOp op) { + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + bool lhsPlain = + isa(getElementTypeOrSelf(op.getLhs().getType())); auto lhs = variableNames->getNameForValue(op.getLhs()); auto rhs = variableNames->getNameForValue(op.getRhs()); - os << llvm::formatv(kSubTemplate.data(), lhs, rhs); - emitAssignPrefix(op.getResult()); - os << lhs << "\n"; + auto ct = lhsPlain ? rhs : lhs; + auto pt = lhsPlain ? lhs : rhs; + auto result = variableNames->getNameForValue(op.getResult()); + std::string numModuli = result + "_m"; + std::string ptData = result + "_pt_data"; + + emitNormalizeCiphertext(result, ctx, ct); + os << numModuli << " = " << result << ".polynomial.shape[-1]\n"; + os << ptData << " = " << pt << ".polynomial[0:1, 0:1, :, :" << numModuli + << "].reshape(1, 1, " << result << ".r, " << result << ".c, " << numModuli + << ")\n"; + os << "_moduli = jnp.array(" << result << ".moduli, dtype=jnp.uint32)\n"; + os << "_c0 = " << result << ".polynomial[:, 0:1, ...]\n"; + if (lhsPlain) { + os << "_c0 = jnp.where(" << ptData << " < _c0, " << ptData + << " + _moduli - _c0, " << ptData << " - _c0)\n"; + os << "_c1 = " << result << ".polynomial[:, 1:2, ...]\n"; + os << "_c1 = jnp.where(_c1 == 0, _c1, _moduli - _c1)\n"; + os << result << ".polynomial = " << result + << ".polynomial.at[:, 1:2, ...].set(_c1)\n"; + } else { + os << "_c0 = jnp.where(_c0 < " << ptData << ", _c0 + _moduli - " << ptData + << ", _c0 - " << ptData << ")\n"; + } + os << result << ".polynomial = " << result + << ".polynomial.at[:, 0:1, ...].set(_c0)\n"; return success(); } LogicalResult JaxiteWordEmitter::printOperation(AddInPlaceOp op) { - return printInPlaceBinaryOpHelper( - op.getLhs(), op.getRhs(), [&](StringRef lhs, StringRef rhs) { - os << llvm::formatv(kAddCoreTemplate.data(), lhs, rhs); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + std::string tmp = lhs + "_inplace"; + std::string rhsWork = lhs + "_rhs"; + + emitNormalizeCiphertext(tmp, ctx, lhs); + emitNormalizeCiphertext(rhsWork, ctx, rhs); + os << tmp << ".add(" << rhsWork << ")\n"; + emitModularReduce(tmp); + emitAssignCiphertext(lhs, tmp); + return success(); } LogicalResult JaxiteWordEmitter::printOperation(SubInPlaceOp op) { - return printInPlaceBinaryOpHelper( - op.getLhs(), op.getRhs(), [&](StringRef lhs, StringRef rhs) { - std::string rhsCiphertext = (rhs + ".ciphertext").str(); - os << llvm::formatv(kSubTemplate.data(), lhs, rhsCiphertext); - }); + auto ctx = variableNames->getNameForValue(op.getCryptoContext()); + auto lhs = variableNames->getNameForValue(op.getLhs()); + auto rhs = variableNames->getNameForValue(op.getRhs()); + std::string tmp = lhs + "_inplace"; + std::string rhsWork = lhs + "_rhs"; + + emitNormalizeCiphertext(tmp, ctx, lhs); + emitNormalizeCiphertext(rhsWork, ctx, rhs); + os << "_moduli = jnp.array(" << tmp << ".moduli, dtype=jnp.uint32)\n"; + os << tmp << ".polynomial = jnp.where(" << tmp << ".polynomial < " << rhsWork + << ".polynomial, " << tmp << ".polynomial + _moduli - " << rhsWork + << ".polynomial, " << tmp << ".polynomial - " << rhsWork + << ".polynomial)\n"; + emitAssignCiphertext(lhs, tmp); + return success(); } LogicalResult JaxiteWordEmitter::printOperation(GenKeyPairOp op) { @@ -489,9 +678,6 @@ LogicalResult JaxiteWordEmitter::printOperation(GenMulKeyOp op) { << ".q_towers, " << "P=" << ctx << ".p_towers, " << "dnum=" << ctx << ".parameters.get('dnum', 3)" << ")\n"; - heMulVarName_ = "he_mul"; - os << llvm::formatv(kGenMulKeyTemplate.data(), ctx, heMulVarName_, ek); - return success(); } @@ -500,20 +686,24 @@ LogicalResult JaxiteWordEmitter::printOperation(GenRotKeyOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); auto sk = variableNames->getNameForValue(op.getSecretKey()); - rotKeysDictVarName_ = rk + "_dict"; - std::string indicesStr; - llvm::raw_string_ostream indicesOs(indicesStr); - llvm::interleaveComma(op.getIndices(), indicesOs); - - os << llvm::formatv(kGenRotKeyTemplate.data(), rotKeysDictVarName_, - indicesStr, sk, ctx, heRotVarName_, rk); + os << rk << " = {}\n"; + os << "for _rot_idx in ["; + llvm::interleaveComma(op.getIndices(), os); + os << "]:\n"; + os.indent(); + os << rk << "[_rot_idx] = key_gen.gen_rotation_key(" << sk << ", " << ctx + << ".q_towers, " << ctx << ".p_towers, rot_index=_rot_idx, dnum=" << ctx + << ".parameters.get('dnum', 3))[_rot_idx]\n"; + os.unindent(); return success(); } LogicalResult JaxiteWordEmitter::printOperation(GenParamsOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); - cryptoContextVarName_ = ctx; + auto publicKey = variableNames->getNameForValue(op.getPublicKey()); + auto secretKey = variableNames->getNameForValue(op.getSecretKey()); + auto evaluationKey = variableNames->getNameForValue(op.getEvaluationKey()); os << "params = {\n"; os.indent(); @@ -543,7 +733,10 @@ LogicalResult JaxiteWordEmitter::printOperation(GenParamsOp op) { os << "\"max_bits_in_word\": 61,\n"; os << "\"max_bits_value\": " << ((1ULL << 63) - (1ULL << 9) - 1) << ",\n"; os << "\"noise_scale_degree\": 1,\n"; - os << "\"CKKS_M_FACTOR\": 1\n"; + os << "\"CKKS_M_FACTOR\": 1,\n"; + os << "\"public_key\": " << publicKey << ",\n"; + os << "\"secret_key\": " << secretKey << ",\n"; + os << "\"evaluation_key\": " << evaluationKey << "\n"; os.unindent(); os << "}\n"; @@ -553,9 +746,7 @@ LogicalResult JaxiteWordEmitter::printOperation(GenParamsOp op) { LogicalResult JaxiteWordEmitter::printOperation(ProgramInitializationOp op) { auto ctx = variableNames->getNameForValue(op.getCryptoContext()); - auto sk = variableNames->getNameForValue(op.getSecretKey()); - os << ctx << ".secret_key = " << sk << "\n"; os << ctx << ".program_initialization("; os << "total_hemul_levels=" << op.getTotalHemulLevels() << ", "; @@ -682,7 +873,8 @@ LogicalResult JaxiteWordEmitter::printOperation(tensor::EmptyOp op) { return success(); } - os << " = np.zeros(("; + emitAssignPrefix(op.getResult()); + os << "np.zeros(("; for (size_t i = 0; i < shape.size(); ++i) { if (i > 0) os << ", "; os << shape[i]; diff --git a/lib/Target/JaxiteWord/JaxiteWordEmitter.h b/lib/Target/JaxiteWord/JaxiteWordEmitter.h index c744b6205f..5a21d06c77 100644 --- a/lib/Target/JaxiteWord/JaxiteWordEmitter.h +++ b/lib/Target/JaxiteWord/JaxiteWordEmitter.h @@ -43,24 +43,9 @@ class JaxiteWordEmitter { // values. SelectVariableNames* variableNames; - // ciphertext arg. - std::string CiphertextArg_; - - // A list of modulus to be used for the add operation. - std::string ModulusListArg_; - - // Crypto context variable name (set by GenParamsOp, used by accessor calls) - std::string cryptoContextVarName_; - - // Legacy member variables kept for backward compatibility with old pipeline - // (GenMulKeyOp / GenRotKeyOp path). Not used by the new - // ProgramInitializationOp path. - std::string heMulVarName_; - std::string heRotVarName_; - std::string rotKeysDictVarName_; - LogicalResult printOperation(ModuleOp moduleOp); LogicalResult printOperation(func::FuncOp funcOp); + LogicalResult printOperation(func::CallOp op); LogicalResult printOperation(func::ReturnOp returnOp); LogicalResult printOperation(AddOp op); LogicalResult printOperation(SubOp op); @@ -125,14 +110,12 @@ class JaxiteWordEmitter { FailureOr convertType(Type type); void emitAssignPrefix(Value result); - - LogicalResult printBinaryOpHelper( - Value result, Value lhs, Value rhs, - llvm::function_ref callback); - - LogicalResult printInPlaceBinaryOpHelper( - Value lhs, Value rhs, - llvm::function_ref callback); + void emitAssignCiphertext(StringRef targetName, StringRef sourceName); + void emitNormalizeCiphertext(StringRef resultName, StringRef ctxName, + StringRef sourceName, StringRef levelExpr = ""); + void emitModularAdd(StringRef resultName, StringRef ctxName, + StringRef lhsName, StringRef rhsName); + void emitModularReduce(StringRef targetName); LogicalResult printMulOpHelper( Value result, Value lhs, Value rhs, Value ctx, Operation* op, diff --git a/lib/Target/JaxiteWord/JaxiteWordTemplates.h b/lib/Target/JaxiteWord/JaxiteWordTemplates.h index d979ef6c05..14c4da1958 100644 --- a/lib/Target/JaxiteWord/JaxiteWordTemplates.h +++ b/lib/Target/JaxiteWord/JaxiteWordTemplates.h @@ -10,6 +10,7 @@ namespace jaxiteword { constexpr std::string_view kModulePrelude = R"python( import jax import jax.numpy as jnp +import key_gen import numpy as np from ciphertext import Ciphertext from polynomial import Polynomial @@ -17,111 +18,69 @@ import ckks_ctx as ckks )python"; -// Template for GenMulKeyOp -// This template initializes HEMul for homomorphic multiplication and sets up -// relinearization. It computes r and c from the degree if they are not provided -// in the parameters. -constexpr std::string_view kGenMulKeyTemplate = R"python( -_degree = {0}.parameters.get('degree') -if _degree is not None: - _log_degree = int(math.log2(_degree)) - _half_k = _log_degree // 2 - _default_r = 2 ** _half_k - _default_c = _degree // _default_r -else: - _default_r = 4 - _default_c = 4 -{1} = HEMul( - batch={0}.parameters.get('batch', 1), - r={0}.parameters.get('r', _default_r), - c={0}.parameters.get('c', _default_c), - dnum={0}.parameters.get('dnum', 3), - num_eval_mult={0}.parameters.get('numEvalMult', 1), - original_moduli={0}.q_towers, - extend_moduli={0}.p_towers -) -{1}.control_gen(degree_layout=({0}.parameters.get('r', _default_r), {0}.parameters.get('c', _default_c))) -{1}.setup_relinearization(jnp.array({2}["a"], dtype=jnp.uint32).transpose(0,2,1), jnp.array({2}["b"], dtype=jnp.uint32).transpose(0,2,1)) -)python"; +constexpr std::string_view kEnsurePolyHelper = R"python( +def _ensure_poly(ctx, x, level=None): + _cache = ctx._param_cache + _r = _cache.r + _c = _cache.c + _m = _cache.num_q_at_level(level) if level is not None else None -// Template for GenRotKeyOp -// This template generates rotation keys for power-of-2 indices and initializes -// HERot. It computes r and c from the degree if they are not provided in the -// parameters. -constexpr std::string_view kGenRotKeyTemplate = R"python( -{0} = {{}} -_all_indices = [{1}] -_max_abs_rot_idx = max([abs(idx) for idx in _all_indices]) if _all_indices else 1 -_power_of_2_indices = [] -_pow2 = 1 -while _pow2 <= _max_abs_rot_idx: - _power_of_2_indices.append(_pow2) - _pow2 <<= 1 -_neg_power_of_2_indices = [-idx for idx in _power_of_2_indices] -_all_pow2_indices = _power_of_2_indices + _neg_power_of_2_indices + _data = x.polynomial if isinstance(x, Polynomial) else x + _m_in = _data.shape[-1] + if _m is None: + _m = _m_in + if _m > _m_in: + raise ValueError( + f"_ensure_poly: requested {_m} moduli but data only has {_m_in}" + ) -for _rot_idx in _all_pow2_indices: - {0}[_rot_idx] = key_gen.gen_rotation_key({2}, {3}.q_towers, {3}.p_towers, rot_index=_rot_idx, dnum={3}.parameters.get('dnum', 3)) + if level is not None: + _moduli = _cache.q_moduli_at_level(level) + else: + _moduli_src = getattr(x, "moduli", ctx.q_towers) + if isinstance(_moduli_src, (int, np.integer)): + _moduli_src = [int(_moduli_src)] + _moduli = list(_moduli_src)[:_m] -_degree_rot = {3}.parameters.get('degree') -if _degree_rot is not None: - _log_degree_rot = int(math.log2(_degree_rot)) - _default_r_rot = 1 << (_log_degree_rot // 2) - _default_c_rot = _degree_rot // _default_r_rot -else: - _default_r_rot = 4 - _default_c_rot = 4 -{4} = HERot( - r={3}.parameters.get('r', _default_r_rot), - c={3}.parameters.get('c', _default_c_rot), - dnum={3}.parameters.get('dnum', 3), - rotate_in_ciphertext_moduli={3}.q_towers, - extend_moduli={3}.p_towers -) -{4}.control_gen(batch=1, degree_layout=({3}.parameters.get('r', _default_r_rot), {3}.parameters.get('c', _default_c_rot))) -{5} = {0} -)python"; + # Return a fresh wrapper even when x is already tiled: emitted add/sub paths + # mutate the result object, so aliasing the source would violate SSA semantics. + _out = Polynomial( + { + "batch": _data.shape[0], + "num_elements": _data.shape[1], + "degree": ctx.degree, + "num_moduli": _m, + "precision": 32, + "degree_layout": (_r, _c), + }, + {"moduli": _moduli}, + ) + _out.polynomial = _data.reshape( + _data.shape[0], _data.shape[1], _r, _c, _m_in + )[..., :_m] + return _out -// Template for DecryptOp -// This template prepares the ciphertext for decryption by extracting the -// required moduli. -constexpr std::string_view kDecryptTemplate = R"python( -{0}.secret_key = {1} -_rescales = {2} -_num_moduli = len({0}.q_towers) - _rescales * {0}.composite_degree -_q_sub = {0}.q_towers[:_num_moduli] -_ct_for_dec = Polynomial( - {{'batch': 1, 'num_elements': 2, 'degree': {0}.degree, - 'precision': 32, 'num_moduli': _num_moduli, - 'degree_layout': ({0}.degree,)}}, - {{'moduli': _q_sub}}) -_ct_for_dec.set_batch_polynomial({3}.polynomial.reshape(1, 2, {0}.degree, _num_moduli)) -)python"; - -// Template for AddOp and AddInPlaceOp -// This template performs addition and modular reduction. -constexpr std::string_view kAddCoreTemplate = R"python( -{0}.add({1}) -{0}.ciphertext = jnp.where({0}.ciphertext >= {0}.moduli_array, {0}.ciphertext - {0}.moduli_array, {0}.ciphertext) -)python"; - -// Template for SubOp, SubInPlaceOp, and SubPlainOp -// This template performs subtraction and modular reduction. -// {1} should be rhs.ciphertext for Sub/SubInPlace and just rhs for SubPlain. -constexpr std::string_view kSubTemplate = R"python( -{0}.ciphertext = jnp.where({0}.ciphertext < {1}, {0}.ciphertext + {0}.moduli_array - {1}, {0}.ciphertext - {1}) -)python"; - -// Template for AddPlainOp modular reduction -constexpr std::string_view kAddModReduceTemplate = R"python( -{0}.ciphertext = jnp.where({0}.ciphertext >= {0}.moduli_array, {0}.ciphertext - {0}.moduli_array, {0}.ciphertext) -)python"; +def _assign_poly(dst, src): + for _attr in ( + "batch", + "num_elements", + "num_moduli", + "degree", + "precision", + "degree_layout", + "r", + "c", + "moduli", + "moduli_array", + "ntt_ctx", + "shape_in_ntt_all_limbs", + ): + if hasattr(src, _attr): + setattr(dst, _attr, getattr(src, _attr)) + dst.polynomial = src.polynomial + if hasattr(src, "extend_polynomial"): + dst.extend_polynomial = src.extend_polynomial -// Template for RelinOp -constexpr std::string_view kRelinTemplate = R"python( -{0} = {1}.he_mul[{2}].relinearize({3}) -_s = {0}.polynomial.shape -{0}.polynomial = {0}.polynomial.reshape(_s[0], _s[1], {0}.degree, _s[-1]) )python"; } // namespace jaxiteword diff --git a/matvec_8x8.mlir b/matvec_8x8.mlir new file mode 100644 index 0000000000..fb41004848 --- /dev/null +++ b/matvec_8x8.mlir @@ -0,0 +1,105 @@ +// MLIR mimicking CROSS_dev/jaxite_word/matvec_test.py (degree=16, 8 slots). +// Test vector (for later validation): [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0] +// +// CROSS crypto params (matvec_test.py setUp): +// degree=16, num_slots=8, dnum=3, r=4, c=4 +// scaling_factor=563019763943521 (pass flag at lowering time, see below) +// q_towers / p_towers preset via ckks.schemeParam on this module +// +// Lowering flags to match remaining CROSS params: +// --torch-linalg-to-ckks=ciphertext-degree=8 +// --jaxiteword-configure-crypto-context=entry-function=matvec_identity,dnum=3,r=4,c=4,scaling-factor=563019763943521 + +module attributes { + scheme.ckks, + ckks.schemeParam = #ckks.scheme_param< + logN = 4, + Q = [1073742881, 1073742721, 1073741441, 1073741857, 524353], + P = [1073740609, 1073739937, 1073739649], + logDefaultScale = 45 + > + +} { + func.func @matvec_identity(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { + %matrix = arith.constant dense<[ + [1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00] + ]> : tensor<8x8xf32> + %out = arith.constant dense<0.000000e+00> : tensor<8xf32> + %0 = linalg.matvec ins(%matrix, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out : tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> + } + + // Shift matrix from CROSS matvec_test.py: result[i] = vector[(i + 1) % n]. + func.func @matvec_shift(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { + %matrix = arith.constant dense<[ + [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00], + [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00], + [1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00] + ]> : tensor<8x8xf32> + %out = arith.constant dense<0.000000e+00> : tensor<8xf32> + %0 = linalg.matvec ins(%matrix, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out : tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> + } + + + // Matrix from np.random.seed(42); np.random.uniform(0.1, 2.0, (8, 8)) + func.func @matvec_random(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { + %matrix = arith.constant dense<[ + [8.11626226e-01, 1.90635718e+00, 1.49078849e+00, 1.23745112e+00, 3.96435417e-01, 3.96389589e-01, 2.10358863e-01, 1.74573468e+00], + [1.24211852e+00, 1.44533790e+00, 1.39110539e-01, 1.94282872e+00, 1.68164102e+00, 5.03444310e-01, 4.45467438e-01, 4.48468569e-01], + [6.78060262e-01, 1.09703722e+00, 9.20695535e-01, 6.53335366e-01, 1.26252050e+00, 3.65038335e-01, 6.55074832e-01, 7.96087502e-01], + [9.66532970e-01, 1.59183433e+00, 4.79380186e-01, 1.07704543e+00, 1.22558768e+00, 1.88255784e-01, 1.25433522e+00, 4.23995835e-01], + [2.23598027e-01, 1.90288252e+00, 1.93470086e+00, 1.63595496e+00, 6.78766161e-01, 2.85577017e-01, 1.40004275e+00, 9.36289738e-01], + [3.31872646e-01, 1.04083613e+00, 1.65338190e-01, 1.82770876e+00, 5.91681965e-01, 1.35879234e+00, 6.92251045e-01, 1.08812924e+00], + [1.13874953e+00, 4.51223465e-01, 1.94221079e+00, 1.57275236e+00, 1.88504799e+00, 1.80017197e+00, 1.23600996e+00, 1.85156105e+00], + [2.68135754e-01, 4.72367439e-01, 1.85931849e-01, 7.18127628e-01, 8.38486850e-01, 6.15563160e-01, 1.67460127e+00, 7.77831321e-01] + ]> : tensor<8x8xf32> + %out = arith.constant dense<0.000000e+00> : tensor<8xf32> + %0 = linalg.matvec ins(%matrix, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out : tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> + } + + // Matmat-vector chain from CROSS matvec_test.py: + // np.random.seed(123); A = round(uniform(0.5, 1.5), 2); + // B = round(uniform(0.5, 1.5), 2); computes A @ (B @ v). + func.func @matvec_chain(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { + %matrix_b = arith.constant dense<[ + [1.340000e+00, 5.800000e-01, 1.260000e+00, 7.400000e-01, 6.900000e-01, 1.070000e+00, 6.000000e-01, 1.390000e+00], + [1.130000e+00, 1.220000e+00, 5.200000e-01, 1.090000e+00, 1.060000e+00, 6.600000e-01, 6.500000e-01, 1.200000e+00], + [8.200000e-01, 1.190000e+00, 1.050000e+00, 8.900000e-01, 1.430000e+00, 1.340000e+00, 8.600000e-01, 5.400000e-01], + [8.000000e-01, 9.000000e-01, 1.200000e+00, 1.500000e+00, 8.600000e-01, 1.260000e+00, 1.090000e+00, 1.190000e+00], + [6.500000e-01, 9.000000e-01, 7.400000e-01, 8.400000e-01, 1.010000e+00, 1.170000e+00, 6.100000e-01, 6.300000e-01], + [8.200000e-01, 1.160000e+00, 1.350000e+00, 1.050000e+00, 1.350000e+00, 8.800000e-01, 8.200000e-01, 8.500000e-01], + [6.700000e-01, 1.330000e+00, 8.400000e-01, 1.050000e+00, 1.080000e+00, 1.020000e+00, 5.000000e-01, 1.490000e+00], + [1.410000e+00, 7.100000e-01, 7.900000e-01, 1.020000e+00, 1.400000e+00, 1.480000e+00, 7.600000e-01, 1.060000e+00] + ]> : tensor<8x8xf32> + %matrix_a = arith.constant dense<[ + [1.200000e+00, 7.900000e-01, 7.300000e-01, 1.050000e+00, 1.220000e+00, 9.200000e-01, 1.480000e+00, 1.180000e+00], + [9.800000e-01, 8.900000e-01, 8.400000e-01, 1.230000e+00, 9.400000e-01, 5.600000e-01, 9.000000e-01, 1.240000e+00], + [6.800000e-01, 6.800000e-01, 1.030000e+00, 1.030000e+00, 1.130000e+00, 1.350000e+00, 1.220000e+00, 1.110000e+00], + [1.220000e+00, 8.200000e-01, 8.600000e-01, 7.300000e-01, 7.900000e-01, 1.130000e+00, 5.900000e-01, 9.300000e-01], + [9.300000e-01, 9.900000e-01, 9.300000e-01, 8.100000e-01, 9.300000e-01, 1.390000e+00, 1.440000e+00, 1.000000e+00], + [1.120000e+00, 6.200000e-01, 8.200000e-01, 9.100000e-01, 1.370000e+00, 7.500000e-01, 9.800000e-01, 1.490000e+00], + [1.020000e+00, 1.110000e+00, 6.200000e-01, 1.330000e+00, 1.100000e+00, 1.050000e+00, 8.400000e-01, 8.000000e-01], + [9.200000e-01, 1.180000e+00, 1.380000e+00, 1.010000e+00, 1.170000e+00, 1.090000e+00, 1.120000e+00, 1.170000e+00] + ]> : tensor<8x8xf32> + %out_b = arith.constant dense<0.000000e+00> : tensor<8xf32> + %out_a = arith.constant dense<0.000000e+00> : tensor<8xf32> + %0 = linalg.matvec ins(%matrix_b, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out_b : tensor<8xf32>) -> tensor<8xf32> + %1 = linalg.matvec ins(%matrix_a, %0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out_a : tensor<8xf32>) -> tensor<8xf32> + return %1 : tensor<8xf32> + } + +} diff --git a/matvec_8x8_cross.py b/matvec_8x8_cross.py new file mode 100644 index 0000000000..d3d34ae7d5 --- /dev/null +++ b/matvec_8x8_cross.py @@ -0,0 +1,2503 @@ +import jax +import jax.numpy as jnp +import key_gen +import numpy as np +from ciphertext import Ciphertext +from polynomial import Polynomial +import ckks_ctx as ckks + + +def _ensure_poly(ctx, x, level=None): + _cache = ctx._param_cache + _r = _cache.r + _c = _cache.c + _m = _cache.num_q_at_level(level) if level is not None else None + + _data = x.polynomial if isinstance(x, Polynomial) else x + _m_in = _data.shape[-1] + if _m is None: + _m = _m_in + if _m > _m_in: + raise ValueError( + f"_ensure_poly: requested {_m} moduli but data only has {_m_in}" + ) + + if level is not None: + _moduli = _cache.q_moduli_at_level(level) + else: + _moduli_src = getattr(x, "moduli", ctx.q_towers) + if isinstance(_moduli_src, (int, np.integer)): + _moduli_src = [int(_moduli_src)] + _moduli = list(_moduli_src)[:_m] + + # Return a fresh wrapper even when x is already tiled: emitted add/sub paths + # mutate the result object, so aliasing the source would violate SSA semantics. + _out = Polynomial( + { + "batch": _data.shape[0], + "num_elements": _data.shape[1], + "degree": ctx.degree, + "num_moduli": _m, + "precision": 32, + "degree_layout": (_r, _c), + }, + {"moduli": _moduli}, + ) + _out.polynomial = _data.reshape( + _data.shape[0], _data.shape[1], _r, _c, _m_in + )[..., :_m] + return _out + + +def _assign_poly(dst, src): + for _attr in ( + "batch", + "num_elements", + "num_moduli", + "degree", + "precision", + "degree_layout", + "r", + "c", + "moduli", + "moduli_array", + "ntt_ctx", + "shape_in_ntt_all_limbs", + ): + if hasattr(src, _attr): + setattr(dst, _attr, getattr(src, _attr)) + dst.polynomial = src.polynomial + if hasattr(src, "extend_polynomial"): + dst.extend_polynomial = src.extend_polynomial + + +def matvec_identity__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> (np.ndarray, np.ndarray): + v2 = np.full((8,), 0.000000e00, dtype=np.float32) + v3 = np.full((8,), 1.000000e00, dtype=np.float32) + pt = v0.encode(v2) + pt1 = v0.encode(v3) + v4 = [pt] + v5 = [pt1] + return (v4, v5) + + +def matvec_identity__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, +) -> np.ndarray: + v5 = 1 + v6 = 2 + v7 = 3 + v8 = 6 + v9 = 0 + pt = v3[0] + pt1 = v4[0] + ct = v2[0] + ct1_arg = _ensure_poly(v0, ct, v0.max_level) + ct1 = v0.he_rot[v0.max_level, 1].rotate(ct1_arg) + ct2_arg = _ensure_poly(v0, ct1, v0.max_level) + ct2_pt_ntt = ( + pt.polynomial[0, 0, :, : ct2_arg.polynomial.shape[-1]] + .reshape(ct2_arg.r, ct2_arg.c, ct2_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct2_ptct = v0.ptct_mul[v0.max_level] + ct2_ptct.set_plaintext(ct2_pt_ntt) + ct2 = ct2_ptct.mul(ct2_arg, use_bat=False) + ct3_arg = _ensure_poly(v0, ct, v0.max_level) + ct3 = v0.he_rot[v0.max_level, 2].rotate(ct3_arg) + ct4_arg = _ensure_poly(v0, ct3, v0.max_level) + ct4_pt_ntt = ( + pt.polynomial[0, 0, :, : ct4_arg.polynomial.shape[-1]] + .reshape(ct4_arg.r, ct4_arg.c, ct4_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct4_ptct = v0.ptct_mul[v0.max_level] + ct4_ptct.set_plaintext(ct4_pt_ntt) + ct4 = ct4_ptct.mul(ct4_arg, use_bat=False) + ct5_arg = _ensure_poly(v0, ct, v0.max_level) + ct5_pt_ntt = ( + pt.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) + ct6_lhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + ct6_rhs = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + ct6_lhs = ct6_lhs.reshape( + ct6_lhs.shape[0], + ct6_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct6_lhs.shape[-1], + ) + ct6_rhs = ct6_rhs.reshape( + ct6_rhs.shape[0], + ct6_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct6_rhs.shape[-1], + ) + if ct6_lhs.shape != ct6_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct6_num_moduli = ct6_lhs.shape[-1] + if hasattr(ct5, "moduli") and hasattr(ct2, "moduli"): + if list(ct5.moduli)[:ct6_num_moduli] != list(ct2.moduli)[:ct6_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct6_moduli_src = getattr(ct5, "moduli", getattr(ct2, "moduli", v0.q_towers)) + if isinstance(ct6_moduli_src, (int, np.integer)): + ct6_moduli_src = [ct6_moduli_src] + ct6_moduli = jnp.array( + list(ct6_moduli_src)[:ct6_num_moduli], dtype=jnp.uint64 + ) + ct6_sum = ct6_lhs.astype(jnp.uint64) + ct6_rhs.astype(jnp.uint64) + ct6 = jnp.where(ct6_sum >= ct6_moduli, ct6_sum - ct6_moduli, ct6_sum).astype( + jnp.uint32 + ) + ct7_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + ct7_rhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + ct7_lhs = ct7_lhs.reshape( + ct7_lhs.shape[0], + ct7_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct7_lhs.shape[-1], + ) + ct7_rhs = ct7_rhs.reshape( + ct7_rhs.shape[0], + ct7_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct7_rhs.shape[-1], + ) + if ct7_lhs.shape != ct7_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct7_num_moduli = ct7_lhs.shape[-1] + if hasattr(ct6, "moduli") and hasattr(ct4, "moduli"): + if list(ct6.moduli)[:ct7_num_moduli] != list(ct4.moduli)[:ct7_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct7_moduli_src = getattr(ct6, "moduli", getattr(ct4, "moduli", v0.q_towers)) + if isinstance(ct7_moduli_src, (int, np.integer)): + ct7_moduli_src = [ct7_moduli_src] + ct7_moduli = jnp.array( + list(ct7_moduli_src)[:ct7_num_moduli], dtype=jnp.uint64 + ) + ct7_sum = ct7_lhs.astype(jnp.uint64) + ct7_rhs.astype(jnp.uint64) + ct7 = jnp.where(ct7_sum >= ct7_moduli, ct7_sum - ct7_moduli, ct7_sum).astype( + jnp.uint32 + ) + ct8_arg = _ensure_poly(v0, ct7, v0.max_level) + ct8 = v0.he_rot[v0.max_level, 3].rotate(ct8_arg) + ct9_arg = _ensure_poly(v0, ct6, v0.max_level) + ct9 = v0.he_rot[v0.max_level, 6].rotate(ct9_arg) + ct10_arg = _ensure_poly(v0, ct, v0.max_level) + ct10_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct10_arg.polynomial.shape[-1]] + .reshape(ct10_arg.r, ct10_arg.c, ct10_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct10_ptct = v0.ptct_mul[v0.max_level] + ct10_ptct.set_plaintext(ct10_pt_ntt) + ct10 = ct10_ptct.mul(ct10_arg, use_bat=False) + ct11_lhs = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 + ct11_rhs = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + ct11_lhs = ct11_lhs.reshape( + ct11_lhs.shape[0], + ct11_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct11_lhs.shape[-1], + ) + ct11_rhs = ct11_rhs.reshape( + ct11_rhs.shape[0], + ct11_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct11_rhs.shape[-1], + ) + if ct11_lhs.shape != ct11_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct11_num_moduli = ct11_lhs.shape[-1] + if hasattr(ct10, "moduli") and hasattr(ct2, "moduli"): + if ( + list(ct10.moduli)[:ct11_num_moduli] + != list(ct2.moduli)[:ct11_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct11_moduli_src = getattr(ct10, "moduli", getattr(ct2, "moduli", v0.q_towers)) + if isinstance(ct11_moduli_src, (int, np.integer)): + ct11_moduli_src = [ct11_moduli_src] + ct11_moduli = jnp.array( + list(ct11_moduli_src)[:ct11_num_moduli], dtype=jnp.uint64 + ) + ct11_sum = ct11_lhs.astype(jnp.uint64) + ct11_rhs.astype(jnp.uint64) + ct11 = jnp.where( + ct11_sum >= ct11_moduli, ct11_sum - ct11_moduli, ct11_sum + ).astype(jnp.uint32) + ct12_lhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + ct12_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + ct12_lhs = ct12_lhs.reshape( + ct12_lhs.shape[0], + ct12_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct12_lhs.shape[-1], + ) + ct12_rhs = ct12_rhs.reshape( + ct12_rhs.shape[0], + ct12_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct12_rhs.shape[-1], + ) + if ct12_lhs.shape != ct12_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct12_num_moduli = ct12_lhs.shape[-1] + if hasattr(ct4, "moduli") and hasattr(ct8, "moduli"): + if list(ct4.moduli)[:ct12_num_moduli] != list(ct8.moduli)[:ct12_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct12_moduli_src = getattr(ct4, "moduli", getattr(ct8, "moduli", v0.q_towers)) + if isinstance(ct12_moduli_src, (int, np.integer)): + ct12_moduli_src = [ct12_moduli_src] + ct12_moduli = jnp.array( + list(ct12_moduli_src)[:ct12_num_moduli], dtype=jnp.uint64 + ) + ct12_sum = ct12_lhs.astype(jnp.uint64) + ct12_rhs.astype(jnp.uint64) + ct12 = jnp.where( + ct12_sum >= ct12_moduli, ct12_sum - ct12_moduli, ct12_sum + ).astype(jnp.uint32) + ct13_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + ct13_rhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + ct13_lhs = ct13_lhs.reshape( + ct13_lhs.shape[0], + ct13_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct13_lhs.shape[-1], + ) + ct13_rhs = ct13_rhs.reshape( + ct13_rhs.shape[0], + ct13_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct13_rhs.shape[-1], + ) + if ct13_lhs.shape != ct13_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct13_num_moduli = ct13_lhs.shape[-1] + if hasattr(ct12, "moduli") and hasattr(ct9, "moduli"): + if ( + list(ct12.moduli)[:ct13_num_moduli] + != list(ct9.moduli)[:ct13_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct13_moduli_src = getattr(ct12, "moduli", getattr(ct9, "moduli", v0.q_towers)) + if isinstance(ct13_moduli_src, (int, np.integer)): + ct13_moduli_src = [ct13_moduli_src] + ct13_moduli = jnp.array( + list(ct13_moduli_src)[:ct13_num_moduli], dtype=jnp.uint64 + ) + ct13_sum = ct13_lhs.astype(jnp.uint64) + ct13_rhs.astype(jnp.uint64) + ct13 = jnp.where( + ct13_sum >= ct13_moduli, ct13_sum - ct13_moduli, ct13_sum + ).astype(jnp.uint32) + ct14_lhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + ct14_lhs = ct14_lhs.reshape( + ct14_lhs.shape[0], + ct14_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_lhs.shape[-1], + ) + ct14_rhs = ct14_rhs.reshape( + ct14_rhs.shape[0], + ct14_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_rhs.shape[-1], + ) + if ct14_lhs.shape != ct14_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct14_num_moduli = ct14_lhs.shape[-1] + if hasattr(ct11, "moduli") and hasattr(ct13, "moduli"): + if ( + list(ct11.moduli)[:ct14_num_moduli] + != list(ct13.moduli)[:ct14_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct14_moduli_src = getattr( + ct11, "moduli", getattr(ct13, "moduli", v0.q_towers) + ) + if isinstance(ct14_moduli_src, (int, np.integer)): + ct14_moduli_src = [ct14_moduli_src] + ct14_moduli = jnp.array( + list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 + ) + ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) + ct14 = jnp.where( + ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum + ).astype(jnp.uint32) + v10 = [None] * 1 + ct15_arg = _ensure_poly(v0, ct14, v0.max_level) + ct15 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct15_arg) + v10[0] = ct15 + v11 = v10 + return v11 + + +def matvec_identity( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4) = matvec_identity__preprocessing(v0, v1) + v5 = matvec_identity__preprocessed(v0, v1, v2, v3, v4) + return v5 + + +def matvec_identity__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + ct = _ensure_poly(v0, ct_raw) + v16 = [ct] + return v16 + + +def matvec_identity__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 7 + v8 = 0 + v9 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + pt_ct = _ensure_poly(v0, ct) + _num_moduli = pt_ct.polynomial.shape[-1] + _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": pt_ct.polynomial.shape[0], + "num_elements": pt_ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + pt_ct.polynomial.reshape( + pt_ct.polynomial.shape[0], + pt_ct.polynomial.shape[1], + v0.degree, + _num_moduli, + ) + ) + pt = v0.decrypt(_ct_for_dec) + v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v11 = v9.copy() + for v12 in range(0, 8): + v14 = v7 - v12 + v15 = int(v14) + v16 = v10[0, v15] + v11[v15] = v16 + return v11 + + +def matvec_shift__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> (np.ndarray, np.ndarray): + v2 = np.full((8,), 0.000000e00, dtype=np.float32) + v3 = np.full((8,), 1.000000e00, dtype=np.float32) + pt = v0.encode(v2) + pt1 = v0.encode(v3) + v4 = [pt] + v5 = [pt1] + return (v4, v5) + + +def matvec_shift__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, +) -> np.ndarray: + v5 = 1 + v6 = 2 + v7 = 3 + v8 = 6 + v9 = 0 + pt = v3[0] + pt1 = v4[0] + ct = v2[0] + ct1_arg = _ensure_poly(v0, ct, v0.max_level) + ct1_pt_ntt = ( + pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] + .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct1_ptct = v0.ptct_mul[v0.max_level] + ct1_ptct.set_plaintext(ct1_pt_ntt) + ct1 = ct1_ptct.mul(ct1_arg, use_bat=False) + ct2_arg = _ensure_poly(v0, ct, v0.max_level) + ct2 = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) + ct3_arg = _ensure_poly(v0, ct, v0.max_level) + ct3 = v0.he_rot[v0.max_level, 2].rotate(ct3_arg) + ct4_arg = _ensure_poly(v0, ct3, v0.max_level) + ct4_pt_ntt = ( + pt.polynomial[0, 0, :, : ct4_arg.polynomial.shape[-1]] + .reshape(ct4_arg.r, ct4_arg.c, ct4_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct4_ptct = v0.ptct_mul[v0.max_level] + ct4_ptct.set_plaintext(ct4_pt_ntt) + ct4 = ct4_ptct.mul(ct4_arg, use_bat=False) + ct5_arg = _ensure_poly(v0, ct2, v0.max_level) + ct5_pt_ntt = ( + pt.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) + ct6_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + ct6_rhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + ct6_lhs = ct6_lhs.reshape( + ct6_lhs.shape[0], + ct6_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct6_lhs.shape[-1], + ) + ct6_rhs = ct6_rhs.reshape( + ct6_rhs.shape[0], + ct6_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct6_rhs.shape[-1], + ) + if ct6_lhs.shape != ct6_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct6_num_moduli = ct6_lhs.shape[-1] + if hasattr(ct1, "moduli") and hasattr(ct5, "moduli"): + if list(ct1.moduli)[:ct6_num_moduli] != list(ct5.moduli)[:ct6_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct6_moduli_src = getattr(ct1, "moduli", getattr(ct5, "moduli", v0.q_towers)) + if isinstance(ct6_moduli_src, (int, np.integer)): + ct6_moduli_src = [ct6_moduli_src] + ct6_moduli = jnp.array( + list(ct6_moduli_src)[:ct6_num_moduli], dtype=jnp.uint64 + ) + ct6_sum = ct6_lhs.astype(jnp.uint64) + ct6_rhs.astype(jnp.uint64) + ct6 = jnp.where(ct6_sum >= ct6_moduli, ct6_sum - ct6_moduli, ct6_sum).astype( + jnp.uint32 + ) + ct7_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + ct7_rhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + ct7_lhs = ct7_lhs.reshape( + ct7_lhs.shape[0], + ct7_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct7_lhs.shape[-1], + ) + ct7_rhs = ct7_rhs.reshape( + ct7_rhs.shape[0], + ct7_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct7_rhs.shape[-1], + ) + if ct7_lhs.shape != ct7_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct7_num_moduli = ct7_lhs.shape[-1] + if hasattr(ct6, "moduli") and hasattr(ct4, "moduli"): + if list(ct6.moduli)[:ct7_num_moduli] != list(ct4.moduli)[:ct7_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct7_moduli_src = getattr(ct6, "moduli", getattr(ct4, "moduli", v0.q_towers)) + if isinstance(ct7_moduli_src, (int, np.integer)): + ct7_moduli_src = [ct7_moduli_src] + ct7_moduli = jnp.array( + list(ct7_moduli_src)[:ct7_num_moduli], dtype=jnp.uint64 + ) + ct7_sum = ct7_lhs.astype(jnp.uint64) + ct7_rhs.astype(jnp.uint64) + ct7 = jnp.where(ct7_sum >= ct7_moduli, ct7_sum - ct7_moduli, ct7_sum).astype( + jnp.uint32 + ) + ct8_arg = _ensure_poly(v0, ct7, v0.max_level) + ct8 = v0.he_rot[v0.max_level, 3].rotate(ct8_arg) + ct9_arg = _ensure_poly(v0, ct6, v0.max_level) + ct9 = v0.he_rot[v0.max_level, 6].rotate(ct9_arg) + ct10_arg = _ensure_poly(v0, ct2, v0.max_level) + ct10_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct10_arg.polynomial.shape[-1]] + .reshape(ct10_arg.r, ct10_arg.c, ct10_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct10_ptct = v0.ptct_mul[v0.max_level] + ct10_ptct.set_plaintext(ct10_pt_ntt) + ct10 = ct10_ptct.mul(ct10_arg, use_bat=False) + ct11_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + ct11_rhs = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 + ct11_lhs = ct11_lhs.reshape( + ct11_lhs.shape[0], + ct11_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct11_lhs.shape[-1], + ) + ct11_rhs = ct11_rhs.reshape( + ct11_rhs.shape[0], + ct11_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct11_rhs.shape[-1], + ) + if ct11_lhs.shape != ct11_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct11_num_moduli = ct11_lhs.shape[-1] + if hasattr(ct1, "moduli") and hasattr(ct10, "moduli"): + if ( + list(ct1.moduli)[:ct11_num_moduli] + != list(ct10.moduli)[:ct11_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct11_moduli_src = getattr(ct1, "moduli", getattr(ct10, "moduli", v0.q_towers)) + if isinstance(ct11_moduli_src, (int, np.integer)): + ct11_moduli_src = [ct11_moduli_src] + ct11_moduli = jnp.array( + list(ct11_moduli_src)[:ct11_num_moduli], dtype=jnp.uint64 + ) + ct11_sum = ct11_lhs.astype(jnp.uint64) + ct11_rhs.astype(jnp.uint64) + ct11 = jnp.where( + ct11_sum >= ct11_moduli, ct11_sum - ct11_moduli, ct11_sum + ).astype(jnp.uint32) + ct12_lhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + ct12_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + ct12_lhs = ct12_lhs.reshape( + ct12_lhs.shape[0], + ct12_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct12_lhs.shape[-1], + ) + ct12_rhs = ct12_rhs.reshape( + ct12_rhs.shape[0], + ct12_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct12_rhs.shape[-1], + ) + if ct12_lhs.shape != ct12_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct12_num_moduli = ct12_lhs.shape[-1] + if hasattr(ct4, "moduli") and hasattr(ct8, "moduli"): + if list(ct4.moduli)[:ct12_num_moduli] != list(ct8.moduli)[:ct12_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct12_moduli_src = getattr(ct4, "moduli", getattr(ct8, "moduli", v0.q_towers)) + if isinstance(ct12_moduli_src, (int, np.integer)): + ct12_moduli_src = [ct12_moduli_src] + ct12_moduli = jnp.array( + list(ct12_moduli_src)[:ct12_num_moduli], dtype=jnp.uint64 + ) + ct12_sum = ct12_lhs.astype(jnp.uint64) + ct12_rhs.astype(jnp.uint64) + ct12 = jnp.where( + ct12_sum >= ct12_moduli, ct12_sum - ct12_moduli, ct12_sum + ).astype(jnp.uint32) + ct13_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + ct13_rhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + ct13_lhs = ct13_lhs.reshape( + ct13_lhs.shape[0], + ct13_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct13_lhs.shape[-1], + ) + ct13_rhs = ct13_rhs.reshape( + ct13_rhs.shape[0], + ct13_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct13_rhs.shape[-1], + ) + if ct13_lhs.shape != ct13_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct13_num_moduli = ct13_lhs.shape[-1] + if hasattr(ct12, "moduli") and hasattr(ct9, "moduli"): + if ( + list(ct12.moduli)[:ct13_num_moduli] + != list(ct9.moduli)[:ct13_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct13_moduli_src = getattr(ct12, "moduli", getattr(ct9, "moduli", v0.q_towers)) + if isinstance(ct13_moduli_src, (int, np.integer)): + ct13_moduli_src = [ct13_moduli_src] + ct13_moduli = jnp.array( + list(ct13_moduli_src)[:ct13_num_moduli], dtype=jnp.uint64 + ) + ct13_sum = ct13_lhs.astype(jnp.uint64) + ct13_rhs.astype(jnp.uint64) + ct13 = jnp.where( + ct13_sum >= ct13_moduli, ct13_sum - ct13_moduli, ct13_sum + ).astype(jnp.uint32) + ct14_lhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + ct14_lhs = ct14_lhs.reshape( + ct14_lhs.shape[0], + ct14_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_lhs.shape[-1], + ) + ct14_rhs = ct14_rhs.reshape( + ct14_rhs.shape[0], + ct14_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_rhs.shape[-1], + ) + if ct14_lhs.shape != ct14_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct14_num_moduli = ct14_lhs.shape[-1] + if hasattr(ct11, "moduli") and hasattr(ct13, "moduli"): + if ( + list(ct11.moduli)[:ct14_num_moduli] + != list(ct13.moduli)[:ct14_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct14_moduli_src = getattr( + ct11, "moduli", getattr(ct13, "moduli", v0.q_towers) + ) + if isinstance(ct14_moduli_src, (int, np.integer)): + ct14_moduli_src = [ct14_moduli_src] + ct14_moduli = jnp.array( + list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 + ) + ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) + ct14 = jnp.where( + ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum + ).astype(jnp.uint32) + v10 = [None] * 1 + ct15_arg = _ensure_poly(v0, ct14, v0.max_level) + ct15 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct15_arg) + v10[0] = ct15 + v11 = v10 + return v11 + + +def matvec_shift( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4) = matvec_shift__preprocessing(v0, v1) + v5 = matvec_shift__preprocessed(v0, v1, v2, v3, v4) + return v5 + + +def matvec_shift__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + ct = _ensure_poly(v0, ct_raw) + v16 = [ct] + return v16 + + +def matvec_shift__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 7 + v8 = 0 + v9 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + pt_ct = _ensure_poly(v0, ct) + _num_moduli = pt_ct.polynomial.shape[-1] + _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": pt_ct.polynomial.shape[0], + "num_elements": pt_ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + pt_ct.polynomial.reshape( + pt_ct.polynomial.shape[0], + pt_ct.polynomial.shape[1], + v0.degree, + _num_moduli, + ) + ) + pt = v0.decrypt(_ct_for_dec) + v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v11 = v9.copy() + for v12 in range(0, 8): + v14 = v7 - v12 + v15 = int(v14) + v16 = v10[0, v15] + v11[v15] = v16 + return v11 + + +def matvec_random__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> ( + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +): + v2 = np.array( + [ + 8.116263e-01, + 1.445338e00, + 9.206955e-01, + 1.077045e00, + 6.787661e-01, + 1.358792e00, + 1.236010e00, + 7.778313e-01, + ], + dtype=np.float32, + ) + v3 = np.array( + [ + 1.906357e00, + 1.391105e-01, + 6.533354e-01, + 1.225588e00, + 2.855770e-01, + 6.922510e-01, + 1.851561e00, + 2.681358e-01, + ], + dtype=np.float32, + ) + v4 = np.array( + [ + 1.490788e00, + 1.942829e00, + 1.262521e00, + 1.882558e-01, + 1.400043e00, + 1.088129e00, + 1.138749e00, + 4.723674e-01, + ], + dtype=np.float32, + ) + v5 = np.array( + [ + 3.318726e-01, + 4.512235e-01, + 1.859318e-01, + 1.237451e00, + 1.681641e00, + 3.650383e-01, + 1.254335e00, + 9.362897e-01, + ], + dtype=np.float32, + ) + v6 = np.array( + [ + 1.040836e00, + 1.942211e00, + 7.181276e-01, + 3.964354e-01, + 5.034443e-01, + 6.550748e-01, + 4.239958e-01, + 2.235980e-01, + ], + dtype=np.float32, + ) + v7 = np.array( + [ + 1.653382e-01, + 1.572752e00, + 8.384869e-01, + 3.963896e-01, + 4.454674e-01, + 7.960875e-01, + 9.665329e-01, + 1.902883e00, + ], + dtype=np.float32, + ) + v8 = np.array( + [ + 6.780602e-01, + 1.591834e00, + 1.934701e00, + 1.827709e00, + 1.885048e00, + 6.155632e-01, + 2.103589e-01, + 4.484686e-01, + ], + dtype=np.float32, + ) + v9 = np.array( + [ + 1.097037e00, + 4.793802e-01, + 1.635955e00, + 5.916820e-01, + 1.800172e00, + 1.674601e00, + 1.745735e00, + 1.242118e00, + ], + dtype=np.float32, + ) + pt = v0.encode(v2) + pt1 = v0.encode(v3) + pt2 = v0.encode(v4) + pt3 = v0.encode(v5) + pt4 = v0.encode(v6) + pt5 = v0.encode(v7) + pt6 = v0.encode(v8) + pt7 = v0.encode(v9) + v10 = [pt] + v11 = [pt1] + v12 = [pt2] + v13 = [pt3] + v14 = [pt4] + v15 = [pt5] + v16 = [pt6] + v17 = [pt7] + return (v10, v11, v12, v13, v14, v15, v16, v17) + + +def matvec_random__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, + v5: np.ndarray, + v6: np.ndarray, + v7: np.ndarray, + v8: np.ndarray, + v9: np.ndarray, + v10: np.ndarray, +) -> np.ndarray: + v11 = 1 + v12 = 2 + v13 = 3 + v14 = 6 + v15 = 0 + pt = v3[0] + pt1 = v4[0] + pt2 = v5[0] + pt3 = v6[0] + pt4 = v7[0] + pt5 = v8[0] + pt6 = v9[0] + pt7 = v10[0] + ct = v2[0] + ct1_arg = _ensure_poly(v0, ct, v0.max_level) + ct1_pt_ntt = ( + pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] + .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct1_ptct = v0.ptct_mul[v0.max_level] + ct1_ptct.set_plaintext(ct1_pt_ntt) + ct1 = ct1_ptct.mul(ct1_arg, use_bat=False) + ct2_arg = _ensure_poly(v0, ct, v0.max_level) + ct2 = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) + ct3_arg = _ensure_poly(v0, ct2, v0.max_level) + ct3_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] + .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct3_ptct = v0.ptct_mul[v0.max_level] + ct3_ptct.set_plaintext(ct3_pt_ntt) + ct3 = ct3_ptct.mul(ct3_arg, use_bat=False) + ct4_arg = _ensure_poly(v0, ct, v0.max_level) + ct4 = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) + ct5_arg = _ensure_poly(v0, ct4, v0.max_level) + ct5_pt_ntt = ( + pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) + ct6_arg = _ensure_poly(v0, ct, v0.max_level) + ct6_pt_ntt = ( + pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] + .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct6_ptct = v0.ptct_mul[v0.max_level] + ct6_ptct.set_plaintext(ct6_pt_ntt) + ct6 = ct6_ptct.mul(ct6_arg, use_bat=False) + ct7_arg = _ensure_poly(v0, ct2, v0.max_level) + ct7_pt_ntt = ( + pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] + .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct7_ptct = v0.ptct_mul[v0.max_level] + ct7_ptct.set_plaintext(ct7_pt_ntt) + ct7 = ct7_ptct.mul(ct7_arg, use_bat=False) + ct8_arg = _ensure_poly(v0, ct4, v0.max_level) + ct8_pt_ntt = ( + pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] + .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct8_ptct = v0.ptct_mul[v0.max_level] + ct8_ptct.set_plaintext(ct8_pt_ntt) + ct8 = ct8_ptct.mul(ct8_arg, use_bat=False) + ct9_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + ct9_rhs = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 + ct9_lhs = ct9_lhs.reshape( + ct9_lhs.shape[0], + ct9_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct9_lhs.shape[-1], + ) + ct9_rhs = ct9_rhs.reshape( + ct9_rhs.shape[0], + ct9_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct9_rhs.shape[-1], + ) + if ct9_lhs.shape != ct9_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct9_num_moduli = ct9_lhs.shape[-1] + if hasattr(ct6, "moduli") and hasattr(ct7, "moduli"): + if list(ct6.moduli)[:ct9_num_moduli] != list(ct7.moduli)[:ct9_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct9_moduli_src = getattr(ct6, "moduli", getattr(ct7, "moduli", v0.q_towers)) + if isinstance(ct9_moduli_src, (int, np.integer)): + ct9_moduli_src = [ct9_moduli_src] + ct9_moduli = jnp.array( + list(ct9_moduli_src)[:ct9_num_moduli], dtype=jnp.uint64 + ) + ct9_sum = ct9_lhs.astype(jnp.uint64) + ct9_rhs.astype(jnp.uint64) + ct9 = jnp.where(ct9_sum >= ct9_moduli, ct9_sum - ct9_moduli, ct9_sum).astype( + jnp.uint32 + ) + ct10_lhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + ct10_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + ct10_lhs = ct10_lhs.reshape( + ct10_lhs.shape[0], + ct10_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct10_lhs.shape[-1], + ) + ct10_rhs = ct10_rhs.reshape( + ct10_rhs.shape[0], + ct10_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct10_rhs.shape[-1], + ) + if ct10_lhs.shape != ct10_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct10_num_moduli = ct10_lhs.shape[-1] + if hasattr(ct9, "moduli") and hasattr(ct8, "moduli"): + if list(ct9.moduli)[:ct10_num_moduli] != list(ct8.moduli)[:ct10_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct10_moduli_src = getattr(ct9, "moduli", getattr(ct8, "moduli", v0.q_towers)) + if isinstance(ct10_moduli_src, (int, np.integer)): + ct10_moduli_src = [ct10_moduli_src] + ct10_moduli = jnp.array( + list(ct10_moduli_src)[:ct10_num_moduli], dtype=jnp.uint64 + ) + ct10_sum = ct10_lhs.astype(jnp.uint64) + ct10_rhs.astype(jnp.uint64) + ct10 = jnp.where( + ct10_sum >= ct10_moduli, ct10_sum - ct10_moduli, ct10_sum + ).astype(jnp.uint32) + ct11_arg = _ensure_poly(v0, ct10, v0.max_level) + ct11 = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) + ct12_arg = _ensure_poly(v0, ct, v0.max_level) + ct12_pt_ntt = ( + pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] + .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct12_ptct = v0.ptct_mul[v0.max_level] + ct12_ptct.set_plaintext(ct12_pt_ntt) + ct12 = ct12_ptct.mul(ct12_arg, use_bat=False) + ct13_arg = _ensure_poly(v0, ct2, v0.max_level) + ct13_pt_ntt = ( + pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] + .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct13_ptct = v0.ptct_mul[v0.max_level] + ct13_ptct.set_plaintext(ct13_pt_ntt) + ct13 = ct13_ptct.mul(ct13_arg, use_bat=False) + ct14_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + ct14_lhs = ct14_lhs.reshape( + ct14_lhs.shape[0], + ct14_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_lhs.shape[-1], + ) + ct14_rhs = ct14_rhs.reshape( + ct14_rhs.shape[0], + ct14_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_rhs.shape[-1], + ) + if ct14_lhs.shape != ct14_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct14_num_moduli = ct14_lhs.shape[-1] + if hasattr(ct12, "moduli") and hasattr(ct13, "moduli"): + if ( + list(ct12.moduli)[:ct14_num_moduli] + != list(ct13.moduli)[:ct14_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct14_moduli_src = getattr( + ct12, "moduli", getattr(ct13, "moduli", v0.q_towers) + ) + if isinstance(ct14_moduli_src, (int, np.integer)): + ct14_moduli_src = [ct14_moduli_src] + ct14_moduli = jnp.array( + list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 + ) + ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) + ct14 = jnp.where( + ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum + ).astype(jnp.uint32) + ct15_arg = _ensure_poly(v0, ct14, v0.max_level) + ct15 = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) + ct16_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + ct16_rhs = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 + ct16_lhs = ct16_lhs.reshape( + ct16_lhs.shape[0], + ct16_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct16_lhs.shape[-1], + ) + ct16_rhs = ct16_rhs.reshape( + ct16_rhs.shape[0], + ct16_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct16_rhs.shape[-1], + ) + if ct16_lhs.shape != ct16_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct16_num_moduli = ct16_lhs.shape[-1] + if hasattr(ct1, "moduli") and hasattr(ct3, "moduli"): + if list(ct1.moduli)[:ct16_num_moduli] != list(ct3.moduli)[:ct16_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct16_moduli_src = getattr(ct1, "moduli", getattr(ct3, "moduli", v0.q_towers)) + if isinstance(ct16_moduli_src, (int, np.integer)): + ct16_moduli_src = [ct16_moduli_src] + ct16_moduli = jnp.array( + list(ct16_moduli_src)[:ct16_num_moduli], dtype=jnp.uint64 + ) + ct16_sum = ct16_lhs.astype(jnp.uint64) + ct16_rhs.astype(jnp.uint64) + ct16 = jnp.where( + ct16_sum >= ct16_moduli, ct16_sum - ct16_moduli, ct16_sum + ).astype(jnp.uint32) + ct17_lhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + ct17_rhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + ct17_lhs = ct17_lhs.reshape( + ct17_lhs.shape[0], + ct17_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct17_lhs.shape[-1], + ) + ct17_rhs = ct17_rhs.reshape( + ct17_rhs.shape[0], + ct17_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct17_rhs.shape[-1], + ) + if ct17_lhs.shape != ct17_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct17_num_moduli = ct17_lhs.shape[-1] + if hasattr(ct5, "moduli") and hasattr(ct11, "moduli"): + if ( + list(ct5.moduli)[:ct17_num_moduli] + != list(ct11.moduli)[:ct17_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct17_moduli_src = getattr(ct5, "moduli", getattr(ct11, "moduli", v0.q_towers)) + if isinstance(ct17_moduli_src, (int, np.integer)): + ct17_moduli_src = [ct17_moduli_src] + ct17_moduli = jnp.array( + list(ct17_moduli_src)[:ct17_num_moduli], dtype=jnp.uint64 + ) + ct17_sum = ct17_lhs.astype(jnp.uint64) + ct17_rhs.astype(jnp.uint64) + ct17 = jnp.where( + ct17_sum >= ct17_moduli, ct17_sum - ct17_moduli, ct17_sum + ).astype(jnp.uint32) + ct18_lhs = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 + ct18_rhs = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 + ct18_lhs = ct18_lhs.reshape( + ct18_lhs.shape[0], + ct18_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct18_lhs.shape[-1], + ) + ct18_rhs = ct18_rhs.reshape( + ct18_rhs.shape[0], + ct18_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct18_rhs.shape[-1], + ) + if ct18_lhs.shape != ct18_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct18_num_moduli = ct18_lhs.shape[-1] + if hasattr(ct17, "moduli") and hasattr(ct15, "moduli"): + if ( + list(ct17.moduli)[:ct18_num_moduli] + != list(ct15.moduli)[:ct18_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct18_moduli_src = getattr( + ct17, "moduli", getattr(ct15, "moduli", v0.q_towers) + ) + if isinstance(ct18_moduli_src, (int, np.integer)): + ct18_moduli_src = [ct18_moduli_src] + ct18_moduli = jnp.array( + list(ct18_moduli_src)[:ct18_num_moduli], dtype=jnp.uint64 + ) + ct18_sum = ct18_lhs.astype(jnp.uint64) + ct18_rhs.astype(jnp.uint64) + ct18 = jnp.where( + ct18_sum >= ct18_moduli, ct18_sum - ct18_moduli, ct18_sum + ).astype(jnp.uint32) + ct19_lhs = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 + ct19_rhs = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 + ct19_lhs = ct19_lhs.reshape( + ct19_lhs.shape[0], + ct19_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct19_lhs.shape[-1], + ) + ct19_rhs = ct19_rhs.reshape( + ct19_rhs.shape[0], + ct19_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct19_rhs.shape[-1], + ) + if ct19_lhs.shape != ct19_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct19_num_moduli = ct19_lhs.shape[-1] + if hasattr(ct16, "moduli") and hasattr(ct18, "moduli"): + if ( + list(ct16.moduli)[:ct19_num_moduli] + != list(ct18.moduli)[:ct19_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct19_moduli_src = getattr( + ct16, "moduli", getattr(ct18, "moduli", v0.q_towers) + ) + if isinstance(ct19_moduli_src, (int, np.integer)): + ct19_moduli_src = [ct19_moduli_src] + ct19_moduli = jnp.array( + list(ct19_moduli_src)[:ct19_num_moduli], dtype=jnp.uint64 + ) + ct19_sum = ct19_lhs.astype(jnp.uint64) + ct19_rhs.astype(jnp.uint64) + ct19 = jnp.where( + ct19_sum >= ct19_moduli, ct19_sum - ct19_moduli, ct19_sum + ).astype(jnp.uint32) + v16 = [None] * 1 + ct20_arg = _ensure_poly(v0, ct19, v0.max_level) + ct20 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) + v16[0] = ct20 + v17 = v16 + return v17 + + +def matvec_random( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_random__preprocessing(v0, v1) + v11 = matvec_random__preprocessed(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) + return v11 + + +def matvec_random__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + ct = _ensure_poly(v0, ct_raw) + v16 = [ct] + return v16 + + +def matvec_random__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 7 + v8 = 0 + v9 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + pt_ct = _ensure_poly(v0, ct) + _num_moduli = pt_ct.polynomial.shape[-1] + _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": pt_ct.polynomial.shape[0], + "num_elements": pt_ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + pt_ct.polynomial.reshape( + pt_ct.polynomial.shape[0], + pt_ct.polynomial.shape[1], + v0.degree, + _num_moduli, + ) + ) + pt = v0.decrypt(_ct_for_dec) + v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v11 = v9.copy() + for v12 in range(0, 8): + v14 = v7 - v12 + v15 = int(v14) + v16 = v10[0, v15] + v11[v15] = v16 + return v11 + + +def matvec_chain__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> ( + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +): + v2 = np.array( + [ + 1.340000e00, + 1.220000e00, + 1.050000e00, + 1.500000e00, + 1.010000e00, + 8.800000e-01, + 5.000000e-01, + 1.060000e00, + ], + dtype=np.float32, + ) + v3 = np.array( + [ + 5.800000e-01, + 5.200000e-01, + 8.900000e-01, + 8.600000e-01, + 1.170000e00, + 8.200000e-01, + 1.490000e00, + 1.410000e00, + ], + dtype=np.float32, + ) + v4 = np.array( + [ + 1.260000e00, + 1.090000e00, + 1.430000e00, + 1.260000e00, + 6.100000e-01, + 8.500000e-01, + 6.700000e-01, + 7.100000e-01, + ], + dtype=np.float32, + ) + v5 = np.array( + [ + 8.200000e-01, + 1.330000e00, + 7.900000e-01, + 7.400000e-01, + 1.060000e00, + 1.340000e00, + 1.090000e00, + 6.300000e-01, + ], + dtype=np.float32, + ) + v6 = np.array( + [ + 1.160000e00, + 8.400000e-01, + 1.020000e00, + 6.900000e-01, + 6.600000e-01, + 8.600000e-01, + 1.190000e00, + 6.500000e-01, + ], + dtype=np.float32, + ) + v7 = np.array( + [ + 1.350000e00, + 1.050000e00, + 1.400000e00, + 1.070000e00, + 6.500000e-01, + 5.400000e-01, + 8.000000e-01, + 9.000000e-01, + ], + dtype=np.float32, + ) + v8 = np.array( + [ + 8.200000e-01, + 9.000000e-01, + 7.400000e-01, + 1.050000e00, + 1.080000e00, + 1.480000e00, + 6.000000e-01, + 1.200000e00, + ], + dtype=np.float32, + ) + v9 = np.array( + [ + 1.190000e00, + 1.200000e00, + 8.400000e-01, + 1.350000e00, + 1.020000e00, + 7.600000e-01, + 1.390000e00, + 1.130000e00, + ], + dtype=np.float32, + ) + v10 = np.array( + [ + 1.200000e00, + 8.900000e-01, + 1.030000e00, + 7.300000e-01, + 9.300000e-01, + 7.500000e-01, + 8.400000e-01, + 1.170000e00, + ], + dtype=np.float32, + ) + v11 = np.array( + [ + 7.900000e-01, + 8.400000e-01, + 1.030000e00, + 7.900000e-01, + 1.390000e00, + 9.800000e-01, + 8.000000e-01, + 9.200000e-01, + ], + dtype=np.float32, + ) + v12 = np.array( + [ + 7.300000e-01, + 1.230000e00, + 1.130000e00, + 1.130000e00, + 1.440000e00, + 1.490000e00, + 1.020000e00, + 1.180000e00, + ], + dtype=np.float32, + ) + v13 = np.array( + [ + 1.120000e00, + 1.110000e00, + 1.380000e00, + 1.050000e00, + 9.400000e-01, + 1.350000e00, + 5.900000e-01, + 1.000000e00, + ], + dtype=np.float32, + ) + v14 = np.array( + [ + 6.200000e-01, + 6.200000e-01, + 1.010000e00, + 1.220000e00, + 5.600000e-01, + 1.220000e00, + 9.300000e-01, + 9.300000e-01, + ], + dtype=np.float32, + ) + v15 = np.array( + [ + 8.200000e-01, + 1.330000e00, + 1.170000e00, + 9.200000e-01, + 9.000000e-01, + 1.110000e00, + 1.220000e00, + 9.900000e-01, + ], + dtype=np.float32, + ) + v16 = np.array( + [ + 6.800000e-01, + 8.200000e-01, + 9.300000e-01, + 9.100000e-01, + 1.100000e00, + 1.090000e00, + 1.480000e00, + 1.240000e00, + ], + dtype=np.float32, + ) + v17 = np.array( + [ + 6.800000e-01, + 8.600000e-01, + 8.100000e-01, + 1.370000e00, + 1.050000e00, + 1.120000e00, + 1.180000e00, + 9.800000e-01, + ], + dtype=np.float32, + ) + pt = v0.encode(v2) + pt1 = v0.encode(v3) + pt2 = v0.encode(v4) + pt3 = v0.encode(v5) + pt4 = v0.encode(v6) + pt5 = v0.encode(v7) + pt6 = v0.encode(v8) + pt7 = v0.encode(v9) + pt8 = v0.encode(v10) + pt9 = v0.encode(v11) + pt10 = v0.encode(v12) + pt11 = v0.encode(v13) + pt12 = v0.encode(v14) + pt13 = v0.encode(v15) + pt14 = v0.encode(v16) + pt15 = v0.encode(v17) + v18 = [pt] + v19 = [pt1] + v20 = [pt2] + v21 = [pt3] + v22 = [pt4] + v23 = [pt5] + v24 = [pt6] + v25 = [pt7] + v26 = [pt8, pt9] + v27 = [pt10, pt11] + v28 = [pt12, pt13] + v29 = [pt14, pt15] + return (v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) + + +def matvec_chain__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, + v5: np.ndarray, + v6: np.ndarray, + v7: np.ndarray, + v8: np.ndarray, + v9: np.ndarray, + v10: np.ndarray, + v11: np.ndarray, + v12: np.ndarray, + v13: np.ndarray, + v14: np.ndarray, +) -> np.ndarray: + v15 = 1 + v16 = 2 + v17 = 3 + v18 = 6 + v19 = 0 + pt = v3[0] + pt1 = v4[0] + pt2 = v5[0] + pt3 = v6[0] + pt4 = v7[0] + pt5 = v8[0] + pt6 = v9[0] + pt7 = v10[0] + pt8 = v11[0] + pt9 = v11[1] + pt10 = v12[0] + pt11 = v12[1] + pt12 = v13[0] + pt13 = v13[1] + pt14 = v14[0] + pt15 = v14[1] + ct = v2[0] + ct1_arg = _ensure_poly(v0, ct, v0.max_level) + ct1_pt_ntt = ( + pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] + .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct1_ptct = v0.ptct_mul[v0.max_level] + ct1_ptct.set_plaintext(ct1_pt_ntt) + ct1 = ct1_ptct.mul(ct1_arg, use_bat=False) + ct2_arg = _ensure_poly(v0, ct, v0.max_level) + ct2 = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) + ct3_arg = _ensure_poly(v0, ct2, v0.max_level) + ct3_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] + .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct3_ptct = v0.ptct_mul[v0.max_level] + ct3_ptct.set_plaintext(ct3_pt_ntt) + ct3 = ct3_ptct.mul(ct3_arg, use_bat=False) + ct4_arg = _ensure_poly(v0, ct, v0.max_level) + ct4 = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) + ct5_arg = _ensure_poly(v0, ct4, v0.max_level) + ct5_pt_ntt = ( + pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) + ct6_arg = _ensure_poly(v0, ct, v0.max_level) + ct6_pt_ntt = ( + pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] + .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct6_ptct = v0.ptct_mul[v0.max_level] + ct6_ptct.set_plaintext(ct6_pt_ntt) + ct6 = ct6_ptct.mul(ct6_arg, use_bat=False) + ct7_arg = _ensure_poly(v0, ct2, v0.max_level) + ct7_pt_ntt = ( + pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] + .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct7_ptct = v0.ptct_mul[v0.max_level] + ct7_ptct.set_plaintext(ct7_pt_ntt) + ct7 = ct7_ptct.mul(ct7_arg, use_bat=False) + ct8_arg = _ensure_poly(v0, ct4, v0.max_level) + ct8_pt_ntt = ( + pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] + .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct8_ptct = v0.ptct_mul[v0.max_level] + ct8_ptct.set_plaintext(ct8_pt_ntt) + ct8 = ct8_ptct.mul(ct8_arg, use_bat=False) + ct9_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + ct9_rhs = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 + ct9_lhs = ct9_lhs.reshape( + ct9_lhs.shape[0], + ct9_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct9_lhs.shape[-1], + ) + ct9_rhs = ct9_rhs.reshape( + ct9_rhs.shape[0], + ct9_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct9_rhs.shape[-1], + ) + if ct9_lhs.shape != ct9_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct9_num_moduli = ct9_lhs.shape[-1] + if hasattr(ct6, "moduli") and hasattr(ct7, "moduli"): + if list(ct6.moduli)[:ct9_num_moduli] != list(ct7.moduli)[:ct9_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct9_moduli_src = getattr(ct6, "moduli", getattr(ct7, "moduli", v0.q_towers)) + if isinstance(ct9_moduli_src, (int, np.integer)): + ct9_moduli_src = [ct9_moduli_src] + ct9_moduli = jnp.array( + list(ct9_moduli_src)[:ct9_num_moduli], dtype=jnp.uint64 + ) + ct9_sum = ct9_lhs.astype(jnp.uint64) + ct9_rhs.astype(jnp.uint64) + ct9 = jnp.where(ct9_sum >= ct9_moduli, ct9_sum - ct9_moduli, ct9_sum).astype( + jnp.uint32 + ) + ct10_lhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + ct10_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + ct10_lhs = ct10_lhs.reshape( + ct10_lhs.shape[0], + ct10_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct10_lhs.shape[-1], + ) + ct10_rhs = ct10_rhs.reshape( + ct10_rhs.shape[0], + ct10_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct10_rhs.shape[-1], + ) + if ct10_lhs.shape != ct10_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct10_num_moduli = ct10_lhs.shape[-1] + if hasattr(ct9, "moduli") and hasattr(ct8, "moduli"): + if list(ct9.moduli)[:ct10_num_moduli] != list(ct8.moduli)[:ct10_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct10_moduli_src = getattr(ct9, "moduli", getattr(ct8, "moduli", v0.q_towers)) + if isinstance(ct10_moduli_src, (int, np.integer)): + ct10_moduli_src = [ct10_moduli_src] + ct10_moduli = jnp.array( + list(ct10_moduli_src)[:ct10_num_moduli], dtype=jnp.uint64 + ) + ct10_sum = ct10_lhs.astype(jnp.uint64) + ct10_rhs.astype(jnp.uint64) + ct10 = jnp.where( + ct10_sum >= ct10_moduli, ct10_sum - ct10_moduli, ct10_sum + ).astype(jnp.uint32) + ct11_arg = _ensure_poly(v0, ct10, v0.max_level) + ct11 = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) + ct12_arg = _ensure_poly(v0, ct, v0.max_level) + ct12_pt_ntt = ( + pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] + .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct12_ptct = v0.ptct_mul[v0.max_level] + ct12_ptct.set_plaintext(ct12_pt_ntt) + ct12 = ct12_ptct.mul(ct12_arg, use_bat=False) + ct13_arg = _ensure_poly(v0, ct2, v0.max_level) + ct13_pt_ntt = ( + pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] + .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct13_ptct = v0.ptct_mul[v0.max_level] + ct13_ptct.set_plaintext(ct13_pt_ntt) + ct13 = ct13_ptct.mul(ct13_arg, use_bat=False) + ct14_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + ct14_lhs = ct14_lhs.reshape( + ct14_lhs.shape[0], + ct14_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_lhs.shape[-1], + ) + ct14_rhs = ct14_rhs.reshape( + ct14_rhs.shape[0], + ct14_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct14_rhs.shape[-1], + ) + if ct14_lhs.shape != ct14_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct14_num_moduli = ct14_lhs.shape[-1] + if hasattr(ct12, "moduli") and hasattr(ct13, "moduli"): + if ( + list(ct12.moduli)[:ct14_num_moduli] + != list(ct13.moduli)[:ct14_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct14_moduli_src = getattr( + ct12, "moduli", getattr(ct13, "moduli", v0.q_towers) + ) + if isinstance(ct14_moduli_src, (int, np.integer)): + ct14_moduli_src = [ct14_moduli_src] + ct14_moduli = jnp.array( + list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 + ) + ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) + ct14 = jnp.where( + ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum + ).astype(jnp.uint32) + ct15_arg = _ensure_poly(v0, ct14, v0.max_level) + ct15 = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) + ct16_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + ct16_rhs = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 + ct16_lhs = ct16_lhs.reshape( + ct16_lhs.shape[0], + ct16_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct16_lhs.shape[-1], + ) + ct16_rhs = ct16_rhs.reshape( + ct16_rhs.shape[0], + ct16_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct16_rhs.shape[-1], + ) + if ct16_lhs.shape != ct16_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct16_num_moduli = ct16_lhs.shape[-1] + if hasattr(ct1, "moduli") and hasattr(ct3, "moduli"): + if list(ct1.moduli)[:ct16_num_moduli] != list(ct3.moduli)[:ct16_num_moduli]: + raise ValueError("ciphertext add modulus mismatch") + ct16_moduli_src = getattr(ct1, "moduli", getattr(ct3, "moduli", v0.q_towers)) + if isinstance(ct16_moduli_src, (int, np.integer)): + ct16_moduli_src = [ct16_moduli_src] + ct16_moduli = jnp.array( + list(ct16_moduli_src)[:ct16_num_moduli], dtype=jnp.uint64 + ) + ct16_sum = ct16_lhs.astype(jnp.uint64) + ct16_rhs.astype(jnp.uint64) + ct16 = jnp.where( + ct16_sum >= ct16_moduli, ct16_sum - ct16_moduli, ct16_sum + ).astype(jnp.uint32) + ct17_lhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + ct17_rhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + ct17_lhs = ct17_lhs.reshape( + ct17_lhs.shape[0], + ct17_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct17_lhs.shape[-1], + ) + ct17_rhs = ct17_rhs.reshape( + ct17_rhs.shape[0], + ct17_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct17_rhs.shape[-1], + ) + if ct17_lhs.shape != ct17_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct17_num_moduli = ct17_lhs.shape[-1] + if hasattr(ct5, "moduli") and hasattr(ct11, "moduli"): + if ( + list(ct5.moduli)[:ct17_num_moduli] + != list(ct11.moduli)[:ct17_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct17_moduli_src = getattr(ct5, "moduli", getattr(ct11, "moduli", v0.q_towers)) + if isinstance(ct17_moduli_src, (int, np.integer)): + ct17_moduli_src = [ct17_moduli_src] + ct17_moduli = jnp.array( + list(ct17_moduli_src)[:ct17_num_moduli], dtype=jnp.uint64 + ) + ct17_sum = ct17_lhs.astype(jnp.uint64) + ct17_rhs.astype(jnp.uint64) + ct17 = jnp.where( + ct17_sum >= ct17_moduli, ct17_sum - ct17_moduli, ct17_sum + ).astype(jnp.uint32) + ct18_lhs = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 + ct18_rhs = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 + ct18_lhs = ct18_lhs.reshape( + ct18_lhs.shape[0], + ct18_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct18_lhs.shape[-1], + ) + ct18_rhs = ct18_rhs.reshape( + ct18_rhs.shape[0], + ct18_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct18_rhs.shape[-1], + ) + if ct18_lhs.shape != ct18_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct18_num_moduli = ct18_lhs.shape[-1] + if hasattr(ct17, "moduli") and hasattr(ct15, "moduli"): + if ( + list(ct17.moduli)[:ct18_num_moduli] + != list(ct15.moduli)[:ct18_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct18_moduli_src = getattr( + ct17, "moduli", getattr(ct15, "moduli", v0.q_towers) + ) + if isinstance(ct18_moduli_src, (int, np.integer)): + ct18_moduli_src = [ct18_moduli_src] + ct18_moduli = jnp.array( + list(ct18_moduli_src)[:ct18_num_moduli], dtype=jnp.uint64 + ) + ct18_sum = ct18_lhs.astype(jnp.uint64) + ct18_rhs.astype(jnp.uint64) + ct18 = jnp.where( + ct18_sum >= ct18_moduli, ct18_sum - ct18_moduli, ct18_sum + ).astype(jnp.uint32) + ct19_lhs = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 + ct19_rhs = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 + ct19_lhs = ct19_lhs.reshape( + ct19_lhs.shape[0], + ct19_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct19_lhs.shape[-1], + ) + ct19_rhs = ct19_rhs.reshape( + ct19_rhs.shape[0], + ct19_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct19_rhs.shape[-1], + ) + if ct19_lhs.shape != ct19_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct19_num_moduli = ct19_lhs.shape[-1] + if hasattr(ct16, "moduli") and hasattr(ct18, "moduli"): + if ( + list(ct16.moduli)[:ct19_num_moduli] + != list(ct18.moduli)[:ct19_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct19_moduli_src = getattr( + ct16, "moduli", getattr(ct18, "moduli", v0.q_towers) + ) + if isinstance(ct19_moduli_src, (int, np.integer)): + ct19_moduli_src = [ct19_moduli_src] + ct19_moduli = jnp.array( + list(ct19_moduli_src)[:ct19_num_moduli], dtype=jnp.uint64 + ) + ct19_sum = ct19_lhs.astype(jnp.uint64) + ct19_rhs.astype(jnp.uint64) + ct19 = jnp.where( + ct19_sum >= ct19_moduli, ct19_sum - ct19_moduli, ct19_sum + ).astype(jnp.uint32) + ct20_arg = _ensure_poly(v0, ct19, v0.max_level) + ct20 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) + ct21_arg = _ensure_poly(v0, ct20, v0.max_level - 1) + ct21_pt_ntt = ( + pt8.polynomial[0, 0, :, : ct21_arg.polynomial.shape[-1]] + .reshape(ct21_arg.r, ct21_arg.c, ct21_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct21_ptct = v0.ptct_mul[v0.max_level - 1] + ct21_ptct.set_plaintext(ct21_pt_ntt) + ct21 = ct21_ptct.mul(ct21_arg, use_bat=False) + ct22_arg = _ensure_poly(v0, ct19, v0.max_level) + ct22 = v0.he_rot[v0.max_level, 1].rotate(ct22_arg) + ct23_arg = _ensure_poly(v0, ct22, v0.max_level) + ct23 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct23_arg) + ct24_arg = _ensure_poly(v0, ct23, v0.max_level - 1) + ct24_pt_ntt = ( + pt9.polynomial[0, 0, :, : ct24_arg.polynomial.shape[-1]] + .reshape(ct24_arg.r, ct24_arg.c, ct24_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct24_ptct = v0.ptct_mul[v0.max_level - 1] + ct24_ptct.set_plaintext(ct24_pt_ntt) + ct24 = ct24_ptct.mul(ct24_arg, use_bat=False) + ct25_arg = _ensure_poly(v0, ct19, v0.max_level) + ct25 = v0.he_rot[v0.max_level, 2].rotate(ct25_arg) + ct26_arg = _ensure_poly(v0, ct25, v0.max_level) + ct26 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct26_arg) + ct27_arg = _ensure_poly(v0, ct26, v0.max_level - 1) + ct27_pt_ntt = ( + pt10.polynomial[0, 0, :, : ct27_arg.polynomial.shape[-1]] + .reshape(ct27_arg.r, ct27_arg.c, ct27_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct27_ptct = v0.ptct_mul[v0.max_level - 1] + ct27_ptct.set_plaintext(ct27_pt_ntt) + ct27 = ct27_ptct.mul(ct27_arg, use_bat=False) + ct28_arg = _ensure_poly(v0, ct20, v0.max_level - 1) + ct28_pt_ntt = ( + pt11.polynomial[0, 0, :, : ct28_arg.polynomial.shape[-1]] + .reshape(ct28_arg.r, ct28_arg.c, ct28_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct28_ptct = v0.ptct_mul[v0.max_level - 1] + ct28_ptct.set_plaintext(ct28_pt_ntt) + ct28 = ct28_ptct.mul(ct28_arg, use_bat=False) + ct29_arg = _ensure_poly(v0, ct23, v0.max_level - 1) + ct29_pt_ntt = ( + pt12.polynomial[0, 0, :, : ct29_arg.polynomial.shape[-1]] + .reshape(ct29_arg.r, ct29_arg.c, ct29_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct29_ptct = v0.ptct_mul[v0.max_level - 1] + ct29_ptct.set_plaintext(ct29_pt_ntt) + ct29 = ct29_ptct.mul(ct29_arg, use_bat=False) + ct30_arg = _ensure_poly(v0, ct26, v0.max_level - 1) + ct30_pt_ntt = ( + pt13.polynomial[0, 0, :, : ct30_arg.polynomial.shape[-1]] + .reshape(ct30_arg.r, ct30_arg.c, ct30_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct30_ptct = v0.ptct_mul[v0.max_level - 1] + ct30_ptct.set_plaintext(ct30_pt_ntt) + ct30 = ct30_ptct.mul(ct30_arg, use_bat=False) + ct31_lhs = ct28.polynomial if hasattr(ct28, "polynomial") else ct28 + ct31_rhs = ct29.polynomial if hasattr(ct29, "polynomial") else ct29 + ct31_lhs = ct31_lhs.reshape( + ct31_lhs.shape[0], + ct31_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct31_lhs.shape[-1], + ) + ct31_rhs = ct31_rhs.reshape( + ct31_rhs.shape[0], + ct31_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct31_rhs.shape[-1], + ) + if ct31_lhs.shape != ct31_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct31_num_moduli = ct31_lhs.shape[-1] + if hasattr(ct28, "moduli") and hasattr(ct29, "moduli"): + if ( + list(ct28.moduli)[:ct31_num_moduli] + != list(ct29.moduli)[:ct31_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct31_moduli_src = getattr( + ct28, "moduli", getattr(ct29, "moduli", v0.q_towers) + ) + if isinstance(ct31_moduli_src, (int, np.integer)): + ct31_moduli_src = [ct31_moduli_src] + ct31_moduli = jnp.array( + list(ct31_moduli_src)[:ct31_num_moduli], dtype=jnp.uint64 + ) + ct31_sum = ct31_lhs.astype(jnp.uint64) + ct31_rhs.astype(jnp.uint64) + ct31 = jnp.where( + ct31_sum >= ct31_moduli, ct31_sum - ct31_moduli, ct31_sum + ).astype(jnp.uint32) + ct32_lhs = ct31.polynomial if hasattr(ct31, "polynomial") else ct31 + ct32_rhs = ct30.polynomial if hasattr(ct30, "polynomial") else ct30 + ct32_lhs = ct32_lhs.reshape( + ct32_lhs.shape[0], + ct32_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct32_lhs.shape[-1], + ) + ct32_rhs = ct32_rhs.reshape( + ct32_rhs.shape[0], + ct32_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct32_rhs.shape[-1], + ) + if ct32_lhs.shape != ct32_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct32_num_moduli = ct32_lhs.shape[-1] + if hasattr(ct31, "moduli") and hasattr(ct30, "moduli"): + if ( + list(ct31.moduli)[:ct32_num_moduli] + != list(ct30.moduli)[:ct32_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct32_moduli_src = getattr( + ct31, "moduli", getattr(ct30, "moduli", v0.q_towers) + ) + if isinstance(ct32_moduli_src, (int, np.integer)): + ct32_moduli_src = [ct32_moduli_src] + ct32_moduli = jnp.array( + list(ct32_moduli_src)[:ct32_num_moduli], dtype=jnp.uint64 + ) + ct32_sum = ct32_lhs.astype(jnp.uint64) + ct32_rhs.astype(jnp.uint64) + ct32 = jnp.where( + ct32_sum >= ct32_moduli, ct32_sum - ct32_moduli, ct32_sum + ).astype(jnp.uint32) + ct33_arg = _ensure_poly(v0, ct32, v0.max_level - 1) + ct33 = v0.he_rot[v0.max_level - 1, 3].rotate(ct33_arg) + ct34_arg = _ensure_poly(v0, ct20, v0.max_level - 1) + ct34_pt_ntt = ( + pt14.polynomial[0, 0, :, : ct34_arg.polynomial.shape[-1]] + .reshape(ct34_arg.r, ct34_arg.c, ct34_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct34_ptct = v0.ptct_mul[v0.max_level - 1] + ct34_ptct.set_plaintext(ct34_pt_ntt) + ct34 = ct34_ptct.mul(ct34_arg, use_bat=False) + ct35_arg = _ensure_poly(v0, ct23, v0.max_level - 1) + ct35_pt_ntt = ( + pt15.polynomial[0, 0, :, : ct35_arg.polynomial.shape[-1]] + .reshape(ct35_arg.r, ct35_arg.c, ct35_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct35_ptct = v0.ptct_mul[v0.max_level - 1] + ct35_ptct.set_plaintext(ct35_pt_ntt) + ct35 = ct35_ptct.mul(ct35_arg, use_bat=False) + ct36_lhs = ct34.polynomial if hasattr(ct34, "polynomial") else ct34 + ct36_rhs = ct35.polynomial if hasattr(ct35, "polynomial") else ct35 + ct36_lhs = ct36_lhs.reshape( + ct36_lhs.shape[0], + ct36_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct36_lhs.shape[-1], + ) + ct36_rhs = ct36_rhs.reshape( + ct36_rhs.shape[0], + ct36_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct36_rhs.shape[-1], + ) + if ct36_lhs.shape != ct36_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct36_num_moduli = ct36_lhs.shape[-1] + if hasattr(ct34, "moduli") and hasattr(ct35, "moduli"): + if ( + list(ct34.moduli)[:ct36_num_moduli] + != list(ct35.moduli)[:ct36_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct36_moduli_src = getattr( + ct34, "moduli", getattr(ct35, "moduli", v0.q_towers) + ) + if isinstance(ct36_moduli_src, (int, np.integer)): + ct36_moduli_src = [ct36_moduli_src] + ct36_moduli = jnp.array( + list(ct36_moduli_src)[:ct36_num_moduli], dtype=jnp.uint64 + ) + ct36_sum = ct36_lhs.astype(jnp.uint64) + ct36_rhs.astype(jnp.uint64) + ct36 = jnp.where( + ct36_sum >= ct36_moduli, ct36_sum - ct36_moduli, ct36_sum + ).astype(jnp.uint32) + ct37_arg = _ensure_poly(v0, ct36, v0.max_level - 1) + ct37 = v0.he_rot[v0.max_level - 1, 6].rotate(ct37_arg) + ct38_lhs = ct21.polynomial if hasattr(ct21, "polynomial") else ct21 + ct38_rhs = ct24.polynomial if hasattr(ct24, "polynomial") else ct24 + ct38_lhs = ct38_lhs.reshape( + ct38_lhs.shape[0], + ct38_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct38_lhs.shape[-1], + ) + ct38_rhs = ct38_rhs.reshape( + ct38_rhs.shape[0], + ct38_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct38_rhs.shape[-1], + ) + if ct38_lhs.shape != ct38_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct38_num_moduli = ct38_lhs.shape[-1] + if hasattr(ct21, "moduli") and hasattr(ct24, "moduli"): + if ( + list(ct21.moduli)[:ct38_num_moduli] + != list(ct24.moduli)[:ct38_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct38_moduli_src = getattr( + ct21, "moduli", getattr(ct24, "moduli", v0.q_towers) + ) + if isinstance(ct38_moduli_src, (int, np.integer)): + ct38_moduli_src = [ct38_moduli_src] + ct38_moduli = jnp.array( + list(ct38_moduli_src)[:ct38_num_moduli], dtype=jnp.uint64 + ) + ct38_sum = ct38_lhs.astype(jnp.uint64) + ct38_rhs.astype(jnp.uint64) + ct38 = jnp.where( + ct38_sum >= ct38_moduli, ct38_sum - ct38_moduli, ct38_sum + ).astype(jnp.uint32) + ct39_lhs = ct27.polynomial if hasattr(ct27, "polynomial") else ct27 + ct39_rhs = ct33.polynomial if hasattr(ct33, "polynomial") else ct33 + ct39_lhs = ct39_lhs.reshape( + ct39_lhs.shape[0], + ct39_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct39_lhs.shape[-1], + ) + ct39_rhs = ct39_rhs.reshape( + ct39_rhs.shape[0], + ct39_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct39_rhs.shape[-1], + ) + if ct39_lhs.shape != ct39_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct39_num_moduli = ct39_lhs.shape[-1] + if hasattr(ct27, "moduli") and hasattr(ct33, "moduli"): + if ( + list(ct27.moduli)[:ct39_num_moduli] + != list(ct33.moduli)[:ct39_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct39_moduli_src = getattr( + ct27, "moduli", getattr(ct33, "moduli", v0.q_towers) + ) + if isinstance(ct39_moduli_src, (int, np.integer)): + ct39_moduli_src = [ct39_moduli_src] + ct39_moduli = jnp.array( + list(ct39_moduli_src)[:ct39_num_moduli], dtype=jnp.uint64 + ) + ct39_sum = ct39_lhs.astype(jnp.uint64) + ct39_rhs.astype(jnp.uint64) + ct39 = jnp.where( + ct39_sum >= ct39_moduli, ct39_sum - ct39_moduli, ct39_sum + ).astype(jnp.uint32) + ct40_lhs = ct39.polynomial if hasattr(ct39, "polynomial") else ct39 + ct40_rhs = ct37.polynomial if hasattr(ct37, "polynomial") else ct37 + ct40_lhs = ct40_lhs.reshape( + ct40_lhs.shape[0], + ct40_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct40_lhs.shape[-1], + ) + ct40_rhs = ct40_rhs.reshape( + ct40_rhs.shape[0], + ct40_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct40_rhs.shape[-1], + ) + if ct40_lhs.shape != ct40_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct40_num_moduli = ct40_lhs.shape[-1] + if hasattr(ct39, "moduli") and hasattr(ct37, "moduli"): + if ( + list(ct39.moduli)[:ct40_num_moduli] + != list(ct37.moduli)[:ct40_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct40_moduli_src = getattr( + ct39, "moduli", getattr(ct37, "moduli", v0.q_towers) + ) + if isinstance(ct40_moduli_src, (int, np.integer)): + ct40_moduli_src = [ct40_moduli_src] + ct40_moduli = jnp.array( + list(ct40_moduli_src)[:ct40_num_moduli], dtype=jnp.uint64 + ) + ct40_sum = ct40_lhs.astype(jnp.uint64) + ct40_rhs.astype(jnp.uint64) + ct40 = jnp.where( + ct40_sum >= ct40_moduli, ct40_sum - ct40_moduli, ct40_sum + ).astype(jnp.uint32) + ct41_lhs = ct38.polynomial if hasattr(ct38, "polynomial") else ct38 + ct41_rhs = ct40.polynomial if hasattr(ct40, "polynomial") else ct40 + ct41_lhs = ct41_lhs.reshape( + ct41_lhs.shape[0], + ct41_lhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct41_lhs.shape[-1], + ) + ct41_rhs = ct41_rhs.reshape( + ct41_rhs.shape[0], + ct41_rhs.shape[1], + v0._param_cache.r, + v0._param_cache.c, + ct41_rhs.shape[-1], + ) + if ct41_lhs.shape != ct41_rhs.shape: + raise ValueError("ciphertext add shape mismatch") + ct41_num_moduli = ct41_lhs.shape[-1] + if hasattr(ct38, "moduli") and hasattr(ct40, "moduli"): + if ( + list(ct38.moduli)[:ct41_num_moduli] + != list(ct40.moduli)[:ct41_num_moduli] + ): + raise ValueError("ciphertext add modulus mismatch") + ct41_moduli_src = getattr( + ct38, "moduli", getattr(ct40, "moduli", v0.q_towers) + ) + if isinstance(ct41_moduli_src, (int, np.integer)): + ct41_moduli_src = [ct41_moduli_src] + ct41_moduli = jnp.array( + list(ct41_moduli_src)[:ct41_num_moduli], dtype=jnp.uint64 + ) + ct41_sum = ct41_lhs.astype(jnp.uint64) + ct41_rhs.astype(jnp.uint64) + ct41 = jnp.where( + ct41_sum >= ct41_moduli, ct41_sum - ct41_moduli, ct41_sum + ).astype(jnp.uint32) + v20 = [None] * 1 + ct42_arg = _ensure_poly(v0, ct41, v0.max_level - 1) + ct42 = v0.he_rescale[v0.max_level - 1, v0.max_level - 2](ct42_arg) + v20[0] = ct42 + v21 = v20 + return v21 + + +def matvec_chain( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) = ( + matvec_chain__preprocessing(v0, v1) + ) + v15 = matvec_chain__preprocessed( + v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 + ) + return v15 + + +def matvec_chain__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + ct = _ensure_poly(v0, ct_raw) + v16 = [ct] + return v16 + + +def matvec_chain__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 7 + v8 = 0 + v9 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + pt_ct = _ensure_poly(v0, ct) + _num_moduli = pt_ct.polynomial.shape[-1] + _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": pt_ct.polynomial.shape[0], + "num_elements": pt_ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + pt_ct.polynomial.reshape( + pt_ct.polynomial.shape[0], + pt_ct.polynomial.shape[1], + v0.degree, + _num_moduli, + ) + ) + pt = v0.decrypt(_ct_for_dec) + v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v11 = v9.copy() + for v12 in range(0, 8): + v14 = v7 - v12 + v15 = int(v14) + v16 = v10[0, v15] + v11[v15] = v16 + return v11 + + +def matvec_identity__generate_crypto_context( + v0: np.ndarray, + v1: np.ndarray, + v2: dict, +) -> ckks.CKKSContext: + params = { + "degree": 16, + "num_slots": 8, + "batch": 1, + "r": 4, + "c": 4, + "dnum": 3, + "numEvalMult": 1, + "scaling_factor": 35184372088832, + "q_towers": [1073742881, 1073742721, 1073741441, 1073741857, 524353], + "p_towers": [1073740609, 1073739937, 1073739649], + "composite_degree": 1, + "p": 30, + "max_bits_in_word": 61, + "max_bits_value": 9223372036854775295, + "noise_scale_degree": 1, + "CKKS_M_FACTOR": 1, + "public_key": v0, + "secret_key": v1, + "evaluation_key": v2, + } + v3 = ckks.CKKSContext(params) + return v3 + + +def matvec_identity__configure_crypto_context( + v0: ckks.CKKSContext, +): + v0.program_initialization( + total_hemul_levels=1, + total_rotation_indices=[1, 2, 3, 6], + dnum=3, + r=4, + c=4, + batch=1, + ) diff --git a/matvec_8x8_jaxite.mlir b/matvec_8x8_jaxite.mlir new file mode 100644 index 0000000000..2dc3ac06d1 --- /dev/null +++ b/matvec_8x8_jaxite.mlir @@ -0,0 +1,465 @@ +!Z1073741441_i64 = !mod_arith.int<1073741441 : i64> +!Z1073742721_i64 = !mod_arith.int<1073742721 : i64> +!Z1073742881_i64 = !mod_arith.int<1073742881 : i64> +#inverse_canonical_encoding = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding1 = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding2 = #lwe.inverse_canonical_encoding +#inverse_canonical_encoding3 = #lwe.inverse_canonical_encoding +#key = #lwe.key<> +#layout = #tensor_ext.layout<"{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 8 = 0 and 0 <= i0 <= 7 and 0 <= slot <= 7 }"> +#modulus_chain_L4_C0 = #lwe.modulus_chain, current = 0> +#modulus_chain_L4_C1 = #lwe.modulus_chain, current = 1> +#modulus_chain_L4_C2 = #lwe.modulus_chain, current = 2> +#ring_f64_1_x8 = #polynomial.ring> +!rns_L0 = !rns.rns +!rns_L1 = !rns.rns +!rns_L2 = !rns.rns +#original_type = #tensor_ext.original_type, layout = #layout> +!pt = !lwe.lwe_plaintext> +!pt1 = !lwe.lwe_plaintext> +!pt2 = !lwe.lwe_plaintext> +#ring_rns_L0_1_x8 = #polynomial.ring> +#ring_rns_L1_1_x8 = #polynomial.ring> +#ring_rns_L2_1_x8 = #polynomial.ring> +#ciphertext_space_L0 = #lwe.ciphertext_space +#ciphertext_space_L1 = #lwe.ciphertext_space +#ciphertext_space_L2 = #lwe.ciphertext_space +!ct_L0 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L0, key = #key, modulus_chain = #modulus_chain_L4_C0> +!ct_L1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L4_C1> +!ct_L1_1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L4_C1> +!ct_L2 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L4_C2> +!ct_L2_1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L4_C2> +module attributes {scheme.ckks} { + func.func @matvec_identity__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) attributes {client.pack_func = {func_name = "matvec_identity"}} { + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<8xf32> + %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_1 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %from_elements = tensor.from_elements %pt : tensor<1x!pt> + %from_elements_2 = tensor.from_elements %pt_1 : tensor<1x!pt> + return %from_elements, %from_elements_2 : tensor<1x!pt>, tensor<1x!pt> + } + func.func @matvec_identity__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>) -> tensor<1x!ct_L1> attributes {client.preprocessed_func = {func_name = "matvec_identity"}} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %c0 = arith.constant 0 : index + %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> + %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> + %extracted_1 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> + %ct = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_2 = jaxiteword.mul_plain %arg0, %ct, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_3 = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_4 = jaxiteword.mul_plain %arg0, %ct_3, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_5 = jaxiteword.mul_plain %arg0, %extracted_1, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_6 = jaxiteword.add %arg0, %ct_5, %ct_2 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_7 = jaxiteword.add %arg0, %ct_6, %ct_4 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_8 = jaxiteword.rot %arg0, %ct_7, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_9 = jaxiteword.rot %arg0, %ct_6, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_10 = jaxiteword.mul_plain %arg0, %extracted_1, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_11 = jaxiteword.add %arg0, %ct_10, %ct_2 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_12 = jaxiteword.add %arg0, %ct_4, %ct_8 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_13 = jaxiteword.add %arg0, %ct_12, %ct_9 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_14 = jaxiteword.add %arg0, %ct_11, %ct_13 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %0 = tensor.empty() : tensor<1x!ct_L1> + %ct_15 = jaxiteword.mod_reduce %arg0, %ct_14 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 + %inserted = tensor.insert %ct_15 into %0[%c0] : tensor<1x!ct_L1> + return %inserted : tensor<1x!ct_L1> + } + func.func @matvec_identity(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L1> {tensor_ext.original_type = #original_type}) { + %0:2 = call @matvec_identity__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) + %1 = call @matvec_identity__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>) -> tensor<1x!ct_L1> + return %1 : tensor<1x!ct_L1> + } + func.func @matvec_identity__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_identity", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { + %1 = arith.index_cast %arg4 : i32 to index + %extracted = tensor.extract %arg2[%1] : tensor<8xf32> + %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> + scf.yield %inserted : tensor<1x8xf32> + } + %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> + %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 + %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> + return %from_elements : tensor<1x!ct_L2> + } + func.func @matvec_identity__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L1>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_identity", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %c7_i32 = arith.constant 7 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L1> + %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L1, !jaxiteword.private_key<>) -> !pt1 + %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt1) -> tensor<1x8xf32> + %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { + %2 = arith.subi %c7_i32, %arg4 : i32 + %3 = arith.index_cast %2 : i32 to index + %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> + %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> + scf.yield %inserted : tensor<8xf32> + } + return %1 : tensor<8xf32> + } + func.func @matvec_shift__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) attributes {client.pack_func = {func_name = "matvec_shift"}} { + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %cst_0 = arith.constant dense<1.000000e+00> : tensor<8xf32> + %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_1 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %from_elements = tensor.from_elements %pt : tensor<1x!pt> + %from_elements_2 = tensor.from_elements %pt_1 : tensor<1x!pt> + return %from_elements, %from_elements_2 : tensor<1x!pt>, tensor<1x!pt> + } + func.func @matvec_shift__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>) -> tensor<1x!ct_L1> attributes {client.preprocessed_func = {func_name = "matvec_shift"}} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %c0 = arith.constant 0 : index + %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> + %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> + %extracted_1 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> + %ct = jaxiteword.mul_plain %arg0, %extracted_1, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_2 = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_3 = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_4 = jaxiteword.mul_plain %arg0, %ct_3, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_5 = jaxiteword.mul_plain %arg0, %ct_2, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_6 = jaxiteword.add %arg0, %ct, %ct_5 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_7 = jaxiteword.add %arg0, %ct_6, %ct_4 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_8 = jaxiteword.rot %arg0, %ct_7, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_9 = jaxiteword.rot %arg0, %ct_6, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_10 = jaxiteword.mul_plain %arg0, %ct_2, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_11 = jaxiteword.add %arg0, %ct, %ct_10 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_12 = jaxiteword.add %arg0, %ct_4, %ct_8 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_13 = jaxiteword.add %arg0, %ct_12, %ct_9 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_14 = jaxiteword.add %arg0, %ct_11, %ct_13 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %0 = tensor.empty() : tensor<1x!ct_L1> + %ct_15 = jaxiteword.mod_reduce %arg0, %ct_14 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 + %inserted = tensor.insert %ct_15 into %0[%c0] : tensor<1x!ct_L1> + return %inserted : tensor<1x!ct_L1> + } + func.func @matvec_shift(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L1> {tensor_ext.original_type = #original_type}) { + %0:2 = call @matvec_shift__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) + %1 = call @matvec_shift__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>) -> tensor<1x!ct_L1> + return %1 : tensor<1x!ct_L1> + } + func.func @matvec_shift__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_shift", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { + %1 = arith.index_cast %arg4 : i32 to index + %extracted = tensor.extract %arg2[%1] : tensor<8xf32> + %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> + scf.yield %inserted : tensor<1x8xf32> + } + %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> + %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 + %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> + return %from_elements : tensor<1x!ct_L2> + } + func.func @matvec_shift__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L1>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_shift", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %c7_i32 = arith.constant 7 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L1> + %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L1, !jaxiteword.private_key<>) -> !pt1 + %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt1) -> tensor<1x8xf32> + %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { + %2 = arith.subi %c7_i32, %arg4 : i32 + %3 = arith.index_cast %2 : i32 to index + %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> + %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> + scf.yield %inserted : tensor<8xf32> + } + return %1 : tensor<8xf32> + } + func.func @matvec_random__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>) attributes {client.pack_func = {func_name = "matvec_random"}} { + %cst = arith.constant dense<[0.811626255, 1.44533789, 0.920695543, 1.07704544, 0.678766131, 1.3587923, 1.236010e+00, 0.777831316]> : tensor<8xf32> + %cst_0 = arith.constant dense<[1.90635717, 0.139110535, 0.653335392, 1.22558773, 0.285577029, 6.922510e-01, 1.85156107, 0.268135756]> : tensor<8xf32> + %cst_1 = arith.constant dense<[1.49078846, 1.94282877, 1.26252055, 0.188255787, 1.40004277, 1.08812928, 1.13874948, 0.472367436]> : tensor<8xf32> + %cst_2 = arith.constant dense<[0.331872642, 0.451223463, 0.185931846, 1.23745108, 1.68164098, 0.365038335, 1.25433517, 0.936289727]> : tensor<8xf32> + %cst_3 = arith.constant dense<[1.0408361, 1.94221079, 0.718127608, 0.39643541, 0.503444314, 0.655074835, 0.423995823, 0.223598033]> : tensor<8xf32> + %cst_4 = arith.constant dense<[0.165338188, 1.57275236, 0.83848685, 0.396389604, 0.445467442, 0.796087503, 0.966532945, 1.90288258]> : tensor<8xf32> + %cst_5 = arith.constant dense<[0.678060233, 1.59183431, 1.93470085, 1.82770872, 1.88504803, 0.615563154, 0.210358858, 0.448468566]> : tensor<8xf32> + %cst_6 = arith.constant dense<[1.0970372, 0.47938019, 1.63595498, 0.591681957, 1.80017197, 1.67460132, 1.74573469, 1.24211848]> : tensor<8xf32> + %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_7 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_8 = jaxiteword.encode %arg0, %cst_1 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_9 = jaxiteword.encode %arg0, %cst_2 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_10 = jaxiteword.encode %arg0, %cst_3 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_11 = jaxiteword.encode %arg0, %cst_4 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_12 = jaxiteword.encode %arg0, %cst_5 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_13 = jaxiteword.encode %arg0, %cst_6 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %from_elements = tensor.from_elements %pt : tensor<1x!pt> + %from_elements_14 = tensor.from_elements %pt_7 : tensor<1x!pt> + %from_elements_15 = tensor.from_elements %pt_8 : tensor<1x!pt> + %from_elements_16 = tensor.from_elements %pt_9 : tensor<1x!pt> + %from_elements_17 = tensor.from_elements %pt_10 : tensor<1x!pt> + %from_elements_18 = tensor.from_elements %pt_11 : tensor<1x!pt> + %from_elements_19 = tensor.from_elements %pt_12 : tensor<1x!pt> + %from_elements_20 = tensor.from_elements %pt_13 : tensor<1x!pt> + return %from_elements, %from_elements_14, %from_elements_15, %from_elements_16, %from_elements_17, %from_elements_18, %from_elements_19, %from_elements_20 : tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt> + } + func.func @matvec_random__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>, %arg5: tensor<1x!pt>, %arg6: tensor<1x!pt>, %arg7: tensor<1x!pt>, %arg8: tensor<1x!pt>, %arg9: tensor<1x!pt>, %arg10: tensor<1x!pt>) -> tensor<1x!ct_L1> attributes {client.preprocessed_func = {func_name = "matvec_random"}} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %c0 = arith.constant 0 : index + %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> + %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> + %extracted_1 = tensor.extract %arg5[%c0] : tensor<1x!pt> + %extracted_2 = tensor.extract %arg6[%c0] : tensor<1x!pt> + %extracted_3 = tensor.extract %arg7[%c0] : tensor<1x!pt> + %extracted_4 = tensor.extract %arg8[%c0] : tensor<1x!pt> + %extracted_5 = tensor.extract %arg9[%c0] : tensor<1x!pt> + %extracted_6 = tensor.extract %arg10[%c0] : tensor<1x!pt> + %extracted_7 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> + %ct = jaxiteword.mul_plain %arg0, %extracted_7, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_8 = jaxiteword.rot %arg0, %extracted_7, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_9 = jaxiteword.mul_plain %arg0, %ct_8, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_10 = jaxiteword.rot %arg0, %extracted_7, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_11 = jaxiteword.mul_plain %arg0, %ct_10, %extracted_1 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_12 = jaxiteword.mul_plain %arg0, %extracted_7, %extracted_2 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_13 = jaxiteword.mul_plain %arg0, %ct_8, %extracted_3 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_14 = jaxiteword.mul_plain %arg0, %ct_10, %extracted_4 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_15 = jaxiteword.add %arg0, %ct_12, %ct_13 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_16 = jaxiteword.add %arg0, %ct_15, %ct_14 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_17 = jaxiteword.rot %arg0, %ct_16, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_18 = jaxiteword.mul_plain %arg0, %extracted_7, %extracted_5 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_19 = jaxiteword.mul_plain %arg0, %ct_8, %extracted_6 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_20 = jaxiteword.add %arg0, %ct_18, %ct_19 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_21 = jaxiteword.rot %arg0, %ct_20, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_22 = jaxiteword.add %arg0, %ct, %ct_9 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_23 = jaxiteword.add %arg0, %ct_11, %ct_17 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_24 = jaxiteword.add %arg0, %ct_23, %ct_21 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_25 = jaxiteword.add %arg0, %ct_22, %ct_24 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %0 = tensor.empty() : tensor<1x!ct_L1> + %ct_26 = jaxiteword.mod_reduce %arg0, %ct_25 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 + %inserted = tensor.insert %ct_26 into %0[%c0] : tensor<1x!ct_L1> + return %inserted : tensor<1x!ct_L1> + } + func.func @matvec_random(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L1> {tensor_ext.original_type = #original_type}) { + %0:8 = call @matvec_random__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>) + %1 = call @matvec_random__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>) -> tensor<1x!ct_L1> + return %1 : tensor<1x!ct_L1> + } + func.func @matvec_random__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_random", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { + %1 = arith.index_cast %arg4 : i32 to index + %extracted = tensor.extract %arg2[%1] : tensor<8xf32> + %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> + scf.yield %inserted : tensor<1x8xf32> + } + %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> + %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 + %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> + return %from_elements : tensor<1x!ct_L2> + } + func.func @matvec_random__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L1>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_random", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %c7_i32 = arith.constant 7 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L1> + %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L1, !jaxiteword.private_key<>) -> !pt1 + %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt1) -> tensor<1x8xf32> + %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { + %2 = arith.subi %c7_i32, %arg4 : i32 + %3 = arith.index_cast %2 : i32 to index + %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> + %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> + scf.yield %inserted : tensor<8xf32> + } + return %1 : tensor<8xf32> + } + func.func @matvec_chain__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>) attributes {client.pack_func = {func_name = "matvec_chain"}} { + %cst = arith.constant dense<[1.340000e+00, 1.220000e+00, 1.050000e+00, 1.500000e+00, 1.010000e+00, 0.879999995, 5.000000e-01, 1.060000e+00]> : tensor<8xf32> + %cst_0 = arith.constant dense<[5.800000e-01, 5.200000e-01, 0.889999985, 8.600000e-01, 1.170000e+00, 0.819999992, 1.490000e+00, 1.410000e+00]> : tensor<8xf32> + %cst_1 = arith.constant dense<[1.260000e+00, 1.090000e+00, 1.430000e+00, 1.260000e+00, 6.100000e-01, 8.500000e-01, 6.700000e-01, 0.709999978]> : tensor<8xf32> + %cst_2 = arith.constant dense<[0.819999992, 1.330000e+00, 7.900000e-01, 7.400000e-01, 1.060000e+00, 1.340000e+00, 1.090000e+00, 6.300000e-01]> : tensor<8xf32> + %cst_3 = arith.constant dense<[1.160000e+00, 0.839999973, 1.020000e+00, 0.689999997, 6.600000e-01, 8.600000e-01, 1.190000e+00, 6.500000e-01]> : tensor<8xf32> + %cst_4 = arith.constant dense<[1.350000e+00, 1.050000e+00, 1.400000e+00, 1.070000e+00, 6.500000e-01, 5.400000e-01, 8.000000e-01, 0.899999976]> : tensor<8xf32> + %cst_5 = arith.constant dense<[0.819999992, 0.899999976, 7.400000e-01, 1.050000e+00, 1.080000e+00, 1.480000e+00, 6.000000e-01, 1.200000e+00]> : tensor<8xf32> + %cst_6 = arith.constant dense<[1.190000e+00, 1.200000e+00, 0.839999973, 1.350000e+00, 1.020000e+00, 7.600000e-01, 1.390000e+00, 1.130000e+00]> : tensor<8xf32> + %cst_7 = arith.constant dense<[1.200000e+00, 0.889999985, 1.030000e+00, 7.300000e-01, 9.300000e-01, 7.500000e-01, 0.839999973, 1.170000e+00]> : tensor<8xf32> + %cst_8 = arith.constant dense<[7.900000e-01, 0.839999973, 1.030000e+00, 7.900000e-01, 1.390000e+00, 9.800000e-01, 8.000000e-01, 9.200000e-01]> : tensor<8xf32> + %cst_9 = arith.constant dense<[7.300000e-01, 1.230000e+00, 1.130000e+00, 1.130000e+00, 1.440000e+00, 1.490000e+00, 1.020000e+00, 1.180000e+00]> : tensor<8xf32> + %cst_10 = arith.constant dense<[1.120000e+00, 1.110000e+00, 1.380000e+00, 1.050000e+00, 0.939999997, 1.350000e+00, 5.900000e-01, 1.000000e+00]> : tensor<8xf32> + %cst_11 = arith.constant dense<[6.200000e-01, 6.200000e-01, 1.010000e+00, 1.220000e+00, 5.600000e-01, 1.220000e+00, 9.300000e-01, 9.300000e-01]> : tensor<8xf32> + %cst_12 = arith.constant dense<[0.819999992, 1.330000e+00, 1.170000e+00, 9.200000e-01, 0.899999976, 1.110000e+00, 1.220000e+00, 9.900000e-01]> : tensor<8xf32> + %cst_13 = arith.constant dense<[6.800000e-01, 0.819999992, 9.300000e-01, 9.100000e-01, 1.100000e+00, 1.090000e+00, 1.480000e+00, 1.240000e+00]> : tensor<8xf32> + %cst_14 = arith.constant dense<[6.800000e-01, 8.600000e-01, 8.100000e-01, 1.370000e+00, 1.050000e+00, 1.120000e+00, 1.180000e+00, 9.800000e-01]> : tensor<8xf32> + %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_15 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_16 = jaxiteword.encode %arg0, %cst_1 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_17 = jaxiteword.encode %arg0, %cst_2 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_18 = jaxiteword.encode %arg0, %cst_3 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_19 = jaxiteword.encode %arg0, %cst_4 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_20 = jaxiteword.encode %arg0, %cst_5 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_21 = jaxiteword.encode %arg0, %cst_6 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %pt_22 = jaxiteword.encode %arg0, %cst_7 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %pt_23 = jaxiteword.encode %arg0, %cst_8 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %pt_24 = jaxiteword.encode %arg0, %cst_9 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %pt_25 = jaxiteword.encode %arg0, %cst_10 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %pt_26 = jaxiteword.encode %arg0, %cst_11 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %pt_27 = jaxiteword.encode %arg0, %cst_12 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %pt_28 = jaxiteword.encode %arg0, %cst_13 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %pt_29 = jaxiteword.encode %arg0, %cst_14 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 + %from_elements = tensor.from_elements %pt : tensor<1x!pt> + %from_elements_30 = tensor.from_elements %pt_15 : tensor<1x!pt> + %from_elements_31 = tensor.from_elements %pt_16 : tensor<1x!pt> + %from_elements_32 = tensor.from_elements %pt_17 : tensor<1x!pt> + %from_elements_33 = tensor.from_elements %pt_18 : tensor<1x!pt> + %from_elements_34 = tensor.from_elements %pt_19 : tensor<1x!pt> + %from_elements_35 = tensor.from_elements %pt_20 : tensor<1x!pt> + %from_elements_36 = tensor.from_elements %pt_21 : tensor<1x!pt> + %from_elements_37 = tensor.from_elements %pt_22, %pt_23 : tensor<2x!pt1> + %from_elements_38 = tensor.from_elements %pt_24, %pt_25 : tensor<2x!pt1> + %from_elements_39 = tensor.from_elements %pt_26, %pt_27 : tensor<2x!pt1> + %from_elements_40 = tensor.from_elements %pt_28, %pt_29 : tensor<2x!pt1> + return %from_elements, %from_elements_30, %from_elements_31, %from_elements_32, %from_elements_33, %from_elements_34, %from_elements_35, %from_elements_36, %from_elements_37, %from_elements_38, %from_elements_39, %from_elements_40 : tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1> + } + func.func @matvec_chain__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>, %arg5: tensor<1x!pt>, %arg6: tensor<1x!pt>, %arg7: tensor<1x!pt>, %arg8: tensor<1x!pt>, %arg9: tensor<1x!pt>, %arg10: tensor<1x!pt>, %arg11: tensor<2x!pt1>, %arg12: tensor<2x!pt1>, %arg13: tensor<2x!pt1>, %arg14: tensor<2x!pt1>) -> tensor<1x!ct_L0> attributes {client.preprocessed_func = {func_name = "matvec_chain"}} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %c0 = arith.constant 0 : index + %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> + %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> + %extracted_1 = tensor.extract %arg5[%c0] : tensor<1x!pt> + %extracted_2 = tensor.extract %arg6[%c0] : tensor<1x!pt> + %extracted_3 = tensor.extract %arg7[%c0] : tensor<1x!pt> + %extracted_4 = tensor.extract %arg8[%c0] : tensor<1x!pt> + %extracted_5 = tensor.extract %arg9[%c0] : tensor<1x!pt> + %extracted_6 = tensor.extract %arg10[%c0] : tensor<1x!pt> + %extracted_7 = tensor.extract %arg11[%c0] : tensor<2x!pt1> + %extracted_8 = tensor.extract %arg11[%c1] : tensor<2x!pt1> + %extracted_9 = tensor.extract %arg12[%c0] : tensor<2x!pt1> + %extracted_10 = tensor.extract %arg12[%c1] : tensor<2x!pt1> + %extracted_11 = tensor.extract %arg13[%c0] : tensor<2x!pt1> + %extracted_12 = tensor.extract %arg13[%c1] : tensor<2x!pt1> + %extracted_13 = tensor.extract %arg14[%c0] : tensor<2x!pt1> + %extracted_14 = tensor.extract %arg14[%c1] : tensor<2x!pt1> + %extracted_15 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> + %ct = jaxiteword.mul_plain %arg0, %extracted_15, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_16 = jaxiteword.rot %arg0, %extracted_15, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_17 = jaxiteword.mul_plain %arg0, %ct_16, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_18 = jaxiteword.rot %arg0, %extracted_15, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 + %ct_19 = jaxiteword.mul_plain %arg0, %ct_18, %extracted_1 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_20 = jaxiteword.mul_plain %arg0, %extracted_15, %extracted_2 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_21 = jaxiteword.mul_plain %arg0, %ct_16, %extracted_3 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_22 = jaxiteword.mul_plain %arg0, %ct_18, %extracted_4 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_23 = jaxiteword.add %arg0, %ct_20, %ct_21 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_24 = jaxiteword.add %arg0, %ct_23, %ct_22 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_25 = jaxiteword.rot %arg0, %ct_24, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_26 = jaxiteword.mul_plain %arg0, %extracted_15, %extracted_5 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_27 = jaxiteword.mul_plain %arg0, %ct_16, %extracted_6 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 + %ct_28 = jaxiteword.add %arg0, %ct_26, %ct_27 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_29 = jaxiteword.rot %arg0, %ct_28, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_30 = jaxiteword.add %arg0, %ct, %ct_17 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_31 = jaxiteword.add %arg0, %ct_19, %ct_25 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_32 = jaxiteword.add %arg0, %ct_31, %ct_29 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_33 = jaxiteword.add %arg0, %ct_30, %ct_32 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 + %ct_34 = jaxiteword.mod_reduce %arg0, %ct_33 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 + %ct_35 = jaxiteword.mul_plain %arg0, %ct_34, %extracted_7 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_36 = jaxiteword.rot %arg0, %ct_33, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_37 = jaxiteword.mod_reduce %arg0, %ct_36 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 + %ct_38 = jaxiteword.mul_plain %arg0, %ct_37, %extracted_8 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_39 = jaxiteword.rot %arg0, %ct_33, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 + %ct_40 = jaxiteword.mod_reduce %arg0, %ct_39 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 + %ct_41 = jaxiteword.mul_plain %arg0, %ct_40, %extracted_9 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_42 = jaxiteword.mul_plain %arg0, %ct_34, %extracted_10 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_43 = jaxiteword.mul_plain %arg0, %ct_37, %extracted_11 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_44 = jaxiteword.mul_plain %arg0, %ct_40, %extracted_12 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_45 = jaxiteword.add %arg0, %ct_42, %ct_43 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %ct_46 = jaxiteword.add %arg0, %ct_45, %ct_44 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %ct_47 = jaxiteword.rot %arg0, %ct_46, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L1_1, !jaxiteword.eval_key<>) -> !ct_L1_1 + %ct_48 = jaxiteword.mul_plain %arg0, %ct_34, %extracted_13 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_49 = jaxiteword.mul_plain %arg0, %ct_37, %extracted_14 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 + %ct_50 = jaxiteword.add %arg0, %ct_48, %ct_49 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %ct_51 = jaxiteword.rot %arg0, %ct_50, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L1_1, !jaxiteword.eval_key<>) -> !ct_L1_1 + %ct_52 = jaxiteword.add %arg0, %ct_35, %ct_38 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %ct_53 = jaxiteword.add %arg0, %ct_41, %ct_47 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %ct_54 = jaxiteword.add %arg0, %ct_53, %ct_51 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %ct_55 = jaxiteword.add %arg0, %ct_52, %ct_54 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %0 = tensor.empty() : tensor<1x!ct_L0> + %ct_56 = jaxiteword.mod_reduce %arg0, %ct_55 : (!jaxiteword.crypto_context<>, !ct_L1_1) -> !ct_L0 + %inserted = tensor.insert %ct_56 into %0[%c0] : tensor<1x!ct_L0> + return %inserted : tensor<1x!ct_L0> + } + func.func @matvec_chain(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L0> {tensor_ext.original_type = #original_type}) { + %0:12 = call @matvec_chain__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>) + %1 = call @matvec_chain__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>) -> tensor<1x!ct_L0> + return %1 : tensor<1x!ct_L0> + } + func.func @matvec_chain__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_chain", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { + %1 = arith.index_cast %arg4 : i32 to index + %extracted = tensor.extract %arg2[%1] : tensor<8xf32> + %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> + scf.yield %inserted : tensor<1x8xf32> + } + %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> + %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt + %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 + %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> + return %from_elements : tensor<1x!ct_L2> + } + func.func @matvec_chain__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L0>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_chain", index = 0 : i64}} { + %c0 = arith.constant 0 : index + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %c7_i32 = arith.constant 7 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> + %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L0> + %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L0, !jaxiteword.private_key<>) -> !pt2 + %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt2) -> tensor<1x8xf32> + %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { + %2 = arith.subi %c7_i32, %arg4 : i32 + %3 = arith.index_cast %2 : i32 to index + %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> + %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> + scf.yield %inserted : tensor<8xf32> + } + return %1 : tensor<8xf32> + } + func.func @matvec_identity__generate_crypto_context(%arg0: !jaxiteword.public_key<>, %arg1: !jaxiteword.private_key<>, %arg2: !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> { + %0 = jaxiteword.gen_params %arg0, %arg1, %arg2 {batch = 1 : i32, c = 4 : i32, compositeDegree = 1 : i32, degree = 16 : i64, dnum = 3 : i32, numEvalMult = 1 : i32, numSlots = 8 : i64, pTowers = array, qTowers = array, r = 4 : i32, scalingFactor = 0x42C0000000000000 : f64} : (!jaxiteword.public_key<>, !jaxiteword.private_key<>, !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> + return %0 : !jaxiteword.crypto_context<> + } + func.func @matvec_identity__configure_crypto_context(%arg0: !jaxiteword.crypto_context<>) { + jaxiteword.program_initialization %arg0 {batch = 1 : i32, c = 4 : i32, dnum = 3 : i32, r = 4 : i32, totalHemulLevels = 1 : i64, totalRotationIndices = array} : (!jaxiteword.crypto_context<>) -> () + return + } +} diff --git a/matvec_8x8_jaxiteword.mlir b/matvec_8x8_jaxiteword.mlir new file mode 100644 index 0000000000..e69de29bb2 diff --git a/matvec_8x8_jaxiteword.py b/matvec_8x8_jaxiteword.py new file mode 100644 index 0000000000..03cc11bf62 --- /dev/null +++ b/matvec_8x8_jaxiteword.py @@ -0,0 +1,11122 @@ +import jax +import jax.numpy as jnp +import numpy as np +from ciphertext import Ciphertext +from polynomial import Polynomial +import ckks_ctx as ckks + + +def _assign_layout_15335824159471298539( + v0: np.ndarray, +) -> np.ndarray: + v1 = 8 + v2 = np.full( + ( + 8, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v3 = 0 + v4 = 1 + v5 = v2.copy() + for v6 in range(0, 8): + for v9 in range(0, 8): + v11 = v6 + v9 + v12 = v11 % v1 + v13 = int(v9) + v14 = int(v12) + v15 = v0[v13, v14] + v16 = int(v6) + v5[v16, v13] = v15 + return v5 + + +def matvec_identity__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> ( + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +): + v2 = np.array( + [ + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + ], + dtype=np.float32, + ).reshape(8, 8) + v3 = _assign_layout_15335824159471298539(v2) + v4 = v3[3 : 3 + 1, 0 : 0 + 5] + v5 = v3[3 : 3 + 1, 5 : 5 + 3] + v6 = np.zeros( + ( + 1, + 8, + ), + dtype=np.float32, + ) + v7 = v6.copy() + v7[0 : 0 + 1, 3 : 3 + 5] = v4 + v8 = v7.copy() + v8[0 : 0 + 1, 0 : 0 + 3] = v5 + v9 = v3[4 : 4 + 1, 0 : 0 + 5] + v10 = v3[4 : 4 + 1, 5 : 5 + 3] + v11 = v6.copy() + v11[0 : 0 + 1, 3 : 3 + 5] = v9 + v12 = v11.copy() + v12[0 : 0 + 1, 0 : 0 + 3] = v10 + v13 = v3[5 : 5 + 1, 0 : 0 + 5] + v14 = v3[5 : 5 + 1, 5 : 5 + 3] + v15 = v6.copy() + v15[0 : 0 + 1, 3 : 3 + 5] = v13 + v16 = v15.copy() + v16[0 : 0 + 1, 0 : 0 + 3] = v14 + v17 = v3[6 : 6 + 1, 0 : 0 + 2] + v18 = v3[6 : 6 + 1, 2 : 2 + 6] + v19 = v6.copy() + v19[0 : 0 + 1, 6 : 6 + 2] = v17 + v20 = v19.copy() + v20[0 : 0 + 1, 0 : 0 + 6] = v18 + v21 = v3[7 : 7 + 1, 0 : 0 + 2] + v22 = v3[7 : 7 + 1, 2 : 2 + 6] + v23 = v6.copy() + v23[0 : 0 + 1, 6 : 6 + 2] = v21 + v24 = v23.copy() + v24[0 : 0 + 1, 0 : 0 + 6] = v22 + v25 = v3[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v25) + v26 = v3[1 : 1 + 1, 0 : 0 + 8].reshape(8) + pt1 = v0.encode(v26) + v27 = v3[2 : 2 + 1, 0 : 0 + 8].reshape(8) + pt2 = v0.encode(v27) + v28 = v8[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt3 = v0.encode(v28) + v29 = v12[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt4 = v0.encode(v29) + v30 = v16[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt5 = v0.encode(v30) + v31 = v20[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt6 = v0.encode(v31) + v32 = v24[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt7 = v0.encode(v32) + v33 = [pt] + v34 = [pt1] + v35 = [pt2] + v36 = [pt3] + v37 = [pt4] + v38 = [pt5] + v39 = [pt6] + v40 = [pt7] + return (v33, v34, v35, v36, v37, v38, v39, v40) + + +def matvec_identity__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, + v5: np.ndarray, + v6: np.ndarray, + v7: np.ndarray, + v8: np.ndarray, + v9: np.ndarray, + v10: np.ndarray, +) -> np.ndarray: + v11 = 1 + v12 = 2 + v13 = 3 + v14 = 6 + v15 = 0 + pt = v3[0] + pt1 = v4[0] + pt2 = v5[0] + pt3 = v6[0] + pt4 = v7[0] + pt5 = v8[0] + pt6 = v9[0] + pt7 = v10[0] + ct = v2[0] + _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct1_arg_m_in = _ct1_arg_data.shape[-1] + _ct1_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_arg_m_in + ) + _ct1_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_arg_r) + ) + _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct1_arg_moduli, (int, np.integer)): + _ct1_arg_moduli = [int(_ct1_arg_moduli)] + ct1_arg = Polynomial( + { + "batch": _ct1_arg_data.shape[0], + "num_elements": _ct1_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_arg_m, + "precision": 32, + "degree_layout": (_ct1_arg_r, _ct1_arg_c), + }, + {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, + ) + ct1_arg.polynomial = _ct1_arg_data.reshape( + _ct1_arg_data.shape[0], + _ct1_arg_data.shape[1], + _ct1_arg_r, + _ct1_arg_c, + _ct1_arg_m_in, + )[..., :_ct1_arg_m].copy() + ct1_arg.batch = ct1_arg.polynomial.shape[0] + ct1_arg.num_elements = ct1_arg.polynomial.shape[1] + ct1_arg.num_moduli = _ct1_arg_m + ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) + ct1_arg.r = _ct1_arg_r + ct1_arg.c = _ct1_arg_c + ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] + ct1_arg.moduli_array = jnp.array( + ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) + ) + ct1_pt_ntt = ( + pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] + .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct1_ptct = v0.ptct_mul[v0.max_level] + ct1_ptct.set_plaintext(ct1_pt_ntt) + ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) + _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw + _ct1_m_in = _ct1_data.shape[-1] + _ct1_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_m_in + ) + _ct1_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_r) + ) + _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) + if isinstance(_ct1_moduli, (int, np.integer)): + _ct1_moduli = [int(_ct1_moduli)] + ct1 = Polynomial( + { + "batch": _ct1_data.shape[0], + "num_elements": _ct1_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_m, + "precision": 32, + "degree_layout": (_ct1_r, _ct1_c), + }, + {"moduli": list(_ct1_moduli)[:_ct1_m]}, + ) + ct1.polynomial = _ct1_data.reshape( + _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in + )[..., :_ct1_m].copy() + ct1.batch = ct1.polynomial.shape[0] + ct1.num_elements = ct1.polynomial.shape[1] + ct1.num_moduli = _ct1_m + ct1.degree_layout = (_ct1_r, _ct1_c) + ct1.r = _ct1_r + ct1.c = _ct1_c + ct1.moduli = list(_ct1_moduli)[:_ct1_m] + ct1.moduli_array = jnp.array( + ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) + ) + _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct2_arg_m_in = _ct2_arg_data.shape[-1] + _ct2_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_arg_m_in + ) + _ct2_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_arg_r) + ) + _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct2_arg_moduli, (int, np.integer)): + _ct2_arg_moduli = [int(_ct2_arg_moduli)] + ct2_arg = Polynomial( + { + "batch": _ct2_arg_data.shape[0], + "num_elements": _ct2_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_arg_m, + "precision": 32, + "degree_layout": (_ct2_arg_r, _ct2_arg_c), + }, + {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, + ) + ct2_arg.polynomial = _ct2_arg_data.reshape( + _ct2_arg_data.shape[0], + _ct2_arg_data.shape[1], + _ct2_arg_r, + _ct2_arg_c, + _ct2_arg_m_in, + )[..., :_ct2_arg_m].copy() + ct2_arg.batch = ct2_arg.polynomial.shape[0] + ct2_arg.num_elements = ct2_arg.polynomial.shape[1] + ct2_arg.num_moduli = _ct2_arg_m + ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) + ct2_arg.r = _ct2_arg_r + ct2_arg.c = _ct2_arg_c + ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] + ct2_arg.moduli_array = jnp.array( + ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) + ) + ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) + _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw + _ct2_m_in = _ct2_data.shape[-1] + _ct2_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_m_in + ) + _ct2_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_r) + ) + _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) + if isinstance(_ct2_moduli, (int, np.integer)): + _ct2_moduli = [int(_ct2_moduli)] + ct2 = Polynomial( + { + "batch": _ct2_data.shape[0], + "num_elements": _ct2_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_m, + "precision": 32, + "degree_layout": (_ct2_r, _ct2_c), + }, + {"moduli": list(_ct2_moduli)[:_ct2_m]}, + ) + ct2.polynomial = _ct2_data.reshape( + _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in + )[..., :_ct2_m].copy() + ct2.batch = ct2.polynomial.shape[0] + ct2.num_elements = ct2.polynomial.shape[1] + ct2.num_moduli = _ct2_m + ct2.degree_layout = (_ct2_r, _ct2_c) + ct2.r = _ct2_r + ct2.c = _ct2_c + ct2.moduli = list(_ct2_moduli)[:_ct2_m] + ct2.moduli_array = jnp.array( + ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) + ) + _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct3_arg_m_in = _ct3_arg_data.shape[-1] + _ct3_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_arg_m_in + ) + _ct3_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_arg_r) + ) + _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct3_arg_moduli, (int, np.integer)): + _ct3_arg_moduli = [int(_ct3_arg_moduli)] + ct3_arg = Polynomial( + { + "batch": _ct3_arg_data.shape[0], + "num_elements": _ct3_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_arg_m, + "precision": 32, + "degree_layout": (_ct3_arg_r, _ct3_arg_c), + }, + {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, + ) + ct3_arg.polynomial = _ct3_arg_data.reshape( + _ct3_arg_data.shape[0], + _ct3_arg_data.shape[1], + _ct3_arg_r, + _ct3_arg_c, + _ct3_arg_m_in, + )[..., :_ct3_arg_m].copy() + ct3_arg.batch = ct3_arg.polynomial.shape[0] + ct3_arg.num_elements = ct3_arg.polynomial.shape[1] + ct3_arg.num_moduli = _ct3_arg_m + ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) + ct3_arg.r = _ct3_arg_r + ct3_arg.c = _ct3_arg_c + ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] + ct3_arg.moduli_array = jnp.array( + ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) + ) + ct3_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] + .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct3_ptct = v0.ptct_mul[v0.max_level] + ct3_ptct.set_plaintext(ct3_pt_ntt) + ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) + _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw + _ct3_m_in = _ct3_data.shape[-1] + _ct3_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_m_in + ) + _ct3_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_r) + ) + _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) + if isinstance(_ct3_moduli, (int, np.integer)): + _ct3_moduli = [int(_ct3_moduli)] + ct3 = Polynomial( + { + "batch": _ct3_data.shape[0], + "num_elements": _ct3_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_m, + "precision": 32, + "degree_layout": (_ct3_r, _ct3_c), + }, + {"moduli": list(_ct3_moduli)[:_ct3_m]}, + ) + ct3.polynomial = _ct3_data.reshape( + _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in + )[..., :_ct3_m].copy() + ct3.batch = ct3.polynomial.shape[0] + ct3.num_elements = ct3.polynomial.shape[1] + ct3.num_moduli = _ct3_m + ct3.degree_layout = (_ct3_r, _ct3_c) + ct3.r = _ct3_r + ct3.c = _ct3_c + ct3.moduli = list(_ct3_moduli)[:_ct3_m] + ct3.moduli_array = jnp.array( + ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) + ) + _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct4_arg_m_in = _ct4_arg_data.shape[-1] + _ct4_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_arg_m_in + ) + _ct4_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_arg_r) + ) + _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct4_arg_moduli, (int, np.integer)): + _ct4_arg_moduli = [int(_ct4_arg_moduli)] + ct4_arg = Polynomial( + { + "batch": _ct4_arg_data.shape[0], + "num_elements": _ct4_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_arg_m, + "precision": 32, + "degree_layout": (_ct4_arg_r, _ct4_arg_c), + }, + {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, + ) + ct4_arg.polynomial = _ct4_arg_data.reshape( + _ct4_arg_data.shape[0], + _ct4_arg_data.shape[1], + _ct4_arg_r, + _ct4_arg_c, + _ct4_arg_m_in, + )[..., :_ct4_arg_m].copy() + ct4_arg.batch = ct4_arg.polynomial.shape[0] + ct4_arg.num_elements = ct4_arg.polynomial.shape[1] + ct4_arg.num_moduli = _ct4_arg_m + ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) + ct4_arg.r = _ct4_arg_r + ct4_arg.c = _ct4_arg_c + ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] + ct4_arg.moduli_array = jnp.array( + ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) + ) + ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) + _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw + _ct4_m_in = _ct4_data.shape[-1] + _ct4_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_m_in + ) + _ct4_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_r) + ) + _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) + if isinstance(_ct4_moduli, (int, np.integer)): + _ct4_moduli = [int(_ct4_moduli)] + ct4 = Polynomial( + { + "batch": _ct4_data.shape[0], + "num_elements": _ct4_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_m, + "precision": 32, + "degree_layout": (_ct4_r, _ct4_c), + }, + {"moduli": list(_ct4_moduli)[:_ct4_m]}, + ) + ct4.polynomial = _ct4_data.reshape( + _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in + )[..., :_ct4_m].copy() + ct4.batch = ct4.polynomial.shape[0] + ct4.num_elements = ct4.polynomial.shape[1] + ct4.num_moduli = _ct4_m + ct4.degree_layout = (_ct4_r, _ct4_c) + ct4.r = _ct4_r + ct4.c = _ct4_c + ct4.moduli = list(_ct4_moduli)[:_ct4_m] + ct4.moduli_array = jnp.array( + ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) + ) + _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct5_arg_m_in = _ct5_arg_data.shape[-1] + _ct5_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_arg_m_in + ) + _ct5_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_arg_r) + ) + _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct5_arg_moduli, (int, np.integer)): + _ct5_arg_moduli = [int(_ct5_arg_moduli)] + ct5_arg = Polynomial( + { + "batch": _ct5_arg_data.shape[0], + "num_elements": _ct5_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_arg_m, + "precision": 32, + "degree_layout": (_ct5_arg_r, _ct5_arg_c), + }, + {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, + ) + ct5_arg.polynomial = _ct5_arg_data.reshape( + _ct5_arg_data.shape[0], + _ct5_arg_data.shape[1], + _ct5_arg_r, + _ct5_arg_c, + _ct5_arg_m_in, + )[..., :_ct5_arg_m].copy() + ct5_arg.batch = ct5_arg.polynomial.shape[0] + ct5_arg.num_elements = ct5_arg.polynomial.shape[1] + ct5_arg.num_moduli = _ct5_arg_m + ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) + ct5_arg.r = _ct5_arg_r + ct5_arg.c = _ct5_arg_c + ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] + ct5_arg.moduli_array = jnp.array( + ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) + ) + ct5_pt_ntt = ( + pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) + _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw + _ct5_m_in = _ct5_data.shape[-1] + _ct5_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_m_in + ) + _ct5_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_r) + ) + _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) + if isinstance(_ct5_moduli, (int, np.integer)): + _ct5_moduli = [int(_ct5_moduli)] + ct5 = Polynomial( + { + "batch": _ct5_data.shape[0], + "num_elements": _ct5_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_m, + "precision": 32, + "degree_layout": (_ct5_r, _ct5_c), + }, + {"moduli": list(_ct5_moduli)[:_ct5_m]}, + ) + ct5.polynomial = _ct5_data.reshape( + _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in + )[..., :_ct5_m].copy() + ct5.batch = ct5.polynomial.shape[0] + ct5.num_elements = ct5.polynomial.shape[1] + ct5.num_moduli = _ct5_m + ct5.degree_layout = (_ct5_r, _ct5_c) + ct5.r = _ct5_r + ct5.c = _ct5_c + ct5.moduli = list(_ct5_moduli)[:_ct5_m] + ct5.moduli_array = jnp.array( + ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) + ) + _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct6_arg_m_in = _ct6_arg_data.shape[-1] + _ct6_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_arg_m_in + ) + _ct6_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_arg_r) + ) + _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct6_arg_moduli, (int, np.integer)): + _ct6_arg_moduli = [int(_ct6_arg_moduli)] + ct6_arg = Polynomial( + { + "batch": _ct6_arg_data.shape[0], + "num_elements": _ct6_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_arg_m, + "precision": 32, + "degree_layout": (_ct6_arg_r, _ct6_arg_c), + }, + {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, + ) + ct6_arg.polynomial = _ct6_arg_data.reshape( + _ct6_arg_data.shape[0], + _ct6_arg_data.shape[1], + _ct6_arg_r, + _ct6_arg_c, + _ct6_arg_m_in, + )[..., :_ct6_arg_m].copy() + ct6_arg.batch = ct6_arg.polynomial.shape[0] + ct6_arg.num_elements = ct6_arg.polynomial.shape[1] + ct6_arg.num_moduli = _ct6_arg_m + ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) + ct6_arg.r = _ct6_arg_r + ct6_arg.c = _ct6_arg_c + ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] + ct6_arg.moduli_array = jnp.array( + ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) + ) + ct6_pt_ntt = ( + pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] + .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct6_ptct = v0.ptct_mul[v0.max_level] + ct6_ptct.set_plaintext(ct6_pt_ntt) + ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) + _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw + _ct6_m_in = _ct6_data.shape[-1] + _ct6_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_m_in + ) + _ct6_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_r) + ) + _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) + if isinstance(_ct6_moduli, (int, np.integer)): + _ct6_moduli = [int(_ct6_moduli)] + ct6 = Polynomial( + { + "batch": _ct6_data.shape[0], + "num_elements": _ct6_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_m, + "precision": 32, + "degree_layout": (_ct6_r, _ct6_c), + }, + {"moduli": list(_ct6_moduli)[:_ct6_m]}, + ) + ct6.polynomial = _ct6_data.reshape( + _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in + )[..., :_ct6_m].copy() + ct6.batch = ct6.polynomial.shape[0] + ct6.num_elements = ct6.polynomial.shape[1] + ct6.num_moduli = _ct6_m + ct6.degree_layout = (_ct6_r, _ct6_c) + ct6.r = _ct6_r + ct6.c = _ct6_c + ct6.moduli = list(_ct6_moduli)[:_ct6_m] + ct6.moduli_array = jnp.array( + ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) + ) + _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct7_arg_m_in = _ct7_arg_data.shape[-1] + _ct7_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_arg_m_in + ) + _ct7_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_arg_r) + ) + _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct7_arg_moduli, (int, np.integer)): + _ct7_arg_moduli = [int(_ct7_arg_moduli)] + ct7_arg = Polynomial( + { + "batch": _ct7_arg_data.shape[0], + "num_elements": _ct7_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_arg_m, + "precision": 32, + "degree_layout": (_ct7_arg_r, _ct7_arg_c), + }, + {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, + ) + ct7_arg.polynomial = _ct7_arg_data.reshape( + _ct7_arg_data.shape[0], + _ct7_arg_data.shape[1], + _ct7_arg_r, + _ct7_arg_c, + _ct7_arg_m_in, + )[..., :_ct7_arg_m].copy() + ct7_arg.batch = ct7_arg.polynomial.shape[0] + ct7_arg.num_elements = ct7_arg.polynomial.shape[1] + ct7_arg.num_moduli = _ct7_arg_m + ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) + ct7_arg.r = _ct7_arg_r + ct7_arg.c = _ct7_arg_c + ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] + ct7_arg.moduli_array = jnp.array( + ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) + ) + ct7_pt_ntt = ( + pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] + .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct7_ptct = v0.ptct_mul[v0.max_level] + ct7_ptct.set_plaintext(ct7_pt_ntt) + ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) + _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw + _ct7_m_in = _ct7_data.shape[-1] + _ct7_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_m_in + ) + _ct7_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_r) + ) + _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) + if isinstance(_ct7_moduli, (int, np.integer)): + _ct7_moduli = [int(_ct7_moduli)] + ct7 = Polynomial( + { + "batch": _ct7_data.shape[0], + "num_elements": _ct7_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_m, + "precision": 32, + "degree_layout": (_ct7_r, _ct7_c), + }, + {"moduli": list(_ct7_moduli)[:_ct7_m]}, + ) + ct7.polynomial = _ct7_data.reshape( + _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in + )[..., :_ct7_m].copy() + ct7.batch = ct7.polynomial.shape[0] + ct7.num_elements = ct7.polynomial.shape[1] + ct7.num_moduli = _ct7_m + ct7.degree_layout = (_ct7_r, _ct7_c) + ct7.r = _ct7_r + ct7.c = _ct7_c + ct7.moduli = list(_ct7_moduli)[:_ct7_m] + ct7.moduli_array = jnp.array( + ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) + ) + _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct8_arg_m_in = _ct8_arg_data.shape[-1] + _ct8_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_arg_m_in + ) + _ct8_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_arg_r) + ) + _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct8_arg_moduli, (int, np.integer)): + _ct8_arg_moduli = [int(_ct8_arg_moduli)] + ct8_arg = Polynomial( + { + "batch": _ct8_arg_data.shape[0], + "num_elements": _ct8_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_arg_m, + "precision": 32, + "degree_layout": (_ct8_arg_r, _ct8_arg_c), + }, + {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, + ) + ct8_arg.polynomial = _ct8_arg_data.reshape( + _ct8_arg_data.shape[0], + _ct8_arg_data.shape[1], + _ct8_arg_r, + _ct8_arg_c, + _ct8_arg_m_in, + )[..., :_ct8_arg_m].copy() + ct8_arg.batch = ct8_arg.polynomial.shape[0] + ct8_arg.num_elements = ct8_arg.polynomial.shape[1] + ct8_arg.num_moduli = _ct8_arg_m + ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) + ct8_arg.r = _ct8_arg_r + ct8_arg.c = _ct8_arg_c + ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] + ct8_arg.moduli_array = jnp.array( + ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) + ) + ct8_pt_ntt = ( + pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] + .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct8_ptct = v0.ptct_mul[v0.max_level] + ct8_ptct.set_plaintext(ct8_pt_ntt) + ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) + _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw + _ct8_m_in = _ct8_data.shape[-1] + _ct8_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_m_in + ) + _ct8_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_r) + ) + _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) + if isinstance(_ct8_moduli, (int, np.integer)): + _ct8_moduli = [int(_ct8_moduli)] + ct8 = Polynomial( + { + "batch": _ct8_data.shape[0], + "num_elements": _ct8_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_m, + "precision": 32, + "degree_layout": (_ct8_r, _ct8_c), + }, + {"moduli": list(_ct8_moduli)[:_ct8_m]}, + ) + ct8.polynomial = _ct8_data.reshape( + _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in + )[..., :_ct8_m].copy() + ct8.batch = ct8.polynomial.shape[0] + ct8.num_elements = ct8.polynomial.shape[1] + ct8.num_moduli = _ct8_m + ct8.degree_layout = (_ct8_r, _ct8_c) + ct8.r = _ct8_r + ct8.c = _ct8_c + ct8.moduli = list(_ct8_moduli)[:_ct8_m] + ct8.moduli_array = jnp.array( + ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) + ) + _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + _ct9_m_in = _ct9_data.shape[-1] + _ct9_m = _ct9_m_in + _ct9_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_r) + ) + _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) + if isinstance(_ct9_moduli, (int, np.integer)): + _ct9_moduli = [int(_ct9_moduli)] + ct9 = Polynomial( + { + "batch": _ct9_data.shape[0], + "num_elements": _ct9_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_m, + "precision": 32, + "degree_layout": (_ct9_r, _ct9_c), + }, + {"moduli": list(_ct9_moduli)[:_ct9_m]}, + ) + ct9.polynomial = _ct9_data.reshape( + _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in + )[..., :_ct9_m].copy() + ct9.batch = ct9.polynomial.shape[0] + ct9.num_elements = ct9.polynomial.shape[1] + ct9.num_moduli = _ct9_m + ct9.degree_layout = (_ct9_r, _ct9_c) + ct9.r = _ct9_r + ct9.c = _ct9_c + ct9.moduli = list(_ct9_moduli)[:_ct9_m] + ct9.moduli_array = jnp.array( + ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) + ) + _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 + _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] + _ct9_rhs_m = _ct9_rhs_m_in + _ct9_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_rhs_r) + ) + _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) + if isinstance(_ct9_rhs_moduli, (int, np.integer)): + _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] + ct9_rhs = Polynomial( + { + "batch": _ct9_rhs_data.shape[0], + "num_elements": _ct9_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_rhs_m, + "precision": 32, + "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), + }, + {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, + ) + ct9_rhs.polynomial = _ct9_rhs_data.reshape( + _ct9_rhs_data.shape[0], + _ct9_rhs_data.shape[1], + _ct9_rhs_r, + _ct9_rhs_c, + _ct9_rhs_m_in, + )[..., :_ct9_rhs_m].copy() + ct9_rhs.batch = ct9_rhs.polynomial.shape[0] + ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] + ct9_rhs.num_moduli = _ct9_rhs_m + ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) + ct9_rhs.r = _ct9_rhs_r + ct9_rhs.c = _ct9_rhs_c + ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] + ct9_rhs.moduli_array = jnp.array( + ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) + ) + ct9.add(ct9_rhs) + _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) + ct9.polynomial = jnp.where( + ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial + ) + _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + _ct10_m_in = _ct10_data.shape[-1] + _ct10_m = _ct10_m_in + _ct10_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_r) + ) + _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) + if isinstance(_ct10_moduli, (int, np.integer)): + _ct10_moduli = [int(_ct10_moduli)] + ct10 = Polynomial( + { + "batch": _ct10_data.shape[0], + "num_elements": _ct10_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_m, + "precision": 32, + "degree_layout": (_ct10_r, _ct10_c), + }, + {"moduli": list(_ct10_moduli)[:_ct10_m]}, + ) + ct10.polynomial = _ct10_data.reshape( + _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in + )[..., :_ct10_m].copy() + ct10.batch = ct10.polynomial.shape[0] + ct10.num_elements = ct10.polynomial.shape[1] + ct10.num_moduli = _ct10_m + ct10.degree_layout = (_ct10_r, _ct10_c) + ct10.r = _ct10_r + ct10.c = _ct10_c + ct10.moduli = list(_ct10_moduli)[:_ct10_m] + ct10.moduli_array = jnp.array( + ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) + ) + _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] + _ct10_rhs_m = _ct10_rhs_m_in + _ct10_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_rhs_r) + ) + _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) + if isinstance(_ct10_rhs_moduli, (int, np.integer)): + _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] + ct10_rhs = Polynomial( + { + "batch": _ct10_rhs_data.shape[0], + "num_elements": _ct10_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_rhs_m, + "precision": 32, + "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), + }, + {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, + ) + ct10_rhs.polynomial = _ct10_rhs_data.reshape( + _ct10_rhs_data.shape[0], + _ct10_rhs_data.shape[1], + _ct10_rhs_r, + _ct10_rhs_c, + _ct10_rhs_m_in, + )[..., :_ct10_rhs_m].copy() + ct10_rhs.batch = ct10_rhs.polynomial.shape[0] + ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] + ct10_rhs.num_moduli = _ct10_rhs_m + ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) + ct10_rhs.r = _ct10_rhs_r + ct10_rhs.c = _ct10_rhs_c + ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] + ct10_rhs.moduli_array = jnp.array( + ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) + ) + ct10.add(ct10_rhs) + _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) + ct10.polynomial = jnp.where( + ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial + ) + _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 + _ct11_arg_m_in = _ct11_arg_data.shape[-1] + _ct11_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_arg_m_in + ) + _ct11_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_arg_r) + ) + _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) + if isinstance(_ct11_arg_moduli, (int, np.integer)): + _ct11_arg_moduli = [int(_ct11_arg_moduli)] + ct11_arg = Polynomial( + { + "batch": _ct11_arg_data.shape[0], + "num_elements": _ct11_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_arg_m, + "precision": 32, + "degree_layout": (_ct11_arg_r, _ct11_arg_c), + }, + {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, + ) + ct11_arg.polynomial = _ct11_arg_data.reshape( + _ct11_arg_data.shape[0], + _ct11_arg_data.shape[1], + _ct11_arg_r, + _ct11_arg_c, + _ct11_arg_m_in, + )[..., :_ct11_arg_m].copy() + ct11_arg.batch = ct11_arg.polynomial.shape[0] + ct11_arg.num_elements = ct11_arg.polynomial.shape[1] + ct11_arg.num_moduli = _ct11_arg_m + ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) + ct11_arg.r = _ct11_arg_r + ct11_arg.c = _ct11_arg_c + ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] + ct11_arg.moduli_array = jnp.array( + ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) + ) + ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) + _ct11_data = ( + ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw + ) + _ct11_m_in = _ct11_data.shape[-1] + _ct11_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_m_in + ) + _ct11_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_r) + ) + _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) + if isinstance(_ct11_moduli, (int, np.integer)): + _ct11_moduli = [int(_ct11_moduli)] + ct11 = Polynomial( + { + "batch": _ct11_data.shape[0], + "num_elements": _ct11_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_m, + "precision": 32, + "degree_layout": (_ct11_r, _ct11_c), + }, + {"moduli": list(_ct11_moduli)[:_ct11_m]}, + ) + ct11.polynomial = _ct11_data.reshape( + _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in + )[..., :_ct11_m].copy() + ct11.batch = ct11.polynomial.shape[0] + ct11.num_elements = ct11.polynomial.shape[1] + ct11.num_moduli = _ct11_m + ct11.degree_layout = (_ct11_r, _ct11_c) + ct11.r = _ct11_r + ct11.c = _ct11_c + ct11.moduli = list(_ct11_moduli)[:_ct11_m] + ct11.moduli_array = jnp.array( + ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) + ) + _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct12_arg_m_in = _ct12_arg_data.shape[-1] + _ct12_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_arg_m_in + ) + _ct12_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_arg_r) + ) + _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct12_arg_moduli, (int, np.integer)): + _ct12_arg_moduli = [int(_ct12_arg_moduli)] + ct12_arg = Polynomial( + { + "batch": _ct12_arg_data.shape[0], + "num_elements": _ct12_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_arg_m, + "precision": 32, + "degree_layout": (_ct12_arg_r, _ct12_arg_c), + }, + {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, + ) + ct12_arg.polynomial = _ct12_arg_data.reshape( + _ct12_arg_data.shape[0], + _ct12_arg_data.shape[1], + _ct12_arg_r, + _ct12_arg_c, + _ct12_arg_m_in, + )[..., :_ct12_arg_m].copy() + ct12_arg.batch = ct12_arg.polynomial.shape[0] + ct12_arg.num_elements = ct12_arg.polynomial.shape[1] + ct12_arg.num_moduli = _ct12_arg_m + ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) + ct12_arg.r = _ct12_arg_r + ct12_arg.c = _ct12_arg_c + ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] + ct12_arg.moduli_array = jnp.array( + ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) + ) + ct12_pt_ntt = ( + pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] + .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct12_ptct = v0.ptct_mul[v0.max_level] + ct12_ptct.set_plaintext(ct12_pt_ntt) + ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) + _ct12_data = ( + ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw + ) + _ct12_m_in = _ct12_data.shape[-1] + _ct12_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_m_in + ) + _ct12_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_r) + ) + _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) + if isinstance(_ct12_moduli, (int, np.integer)): + _ct12_moduli = [int(_ct12_moduli)] + ct12 = Polynomial( + { + "batch": _ct12_data.shape[0], + "num_elements": _ct12_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_m, + "precision": 32, + "degree_layout": (_ct12_r, _ct12_c), + }, + {"moduli": list(_ct12_moduli)[:_ct12_m]}, + ) + ct12.polynomial = _ct12_data.reshape( + _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in + )[..., :_ct12_m].copy() + ct12.batch = ct12.polynomial.shape[0] + ct12.num_elements = ct12.polynomial.shape[1] + ct12.num_moduli = _ct12_m + ct12.degree_layout = (_ct12_r, _ct12_c) + ct12.r = _ct12_r + ct12.c = _ct12_c + ct12.moduli = list(_ct12_moduli)[:_ct12_m] + ct12.moduli_array = jnp.array( + ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) + ) + _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct13_arg_m_in = _ct13_arg_data.shape[-1] + _ct13_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_arg_m_in + ) + _ct13_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_arg_r) + ) + _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct13_arg_moduli, (int, np.integer)): + _ct13_arg_moduli = [int(_ct13_arg_moduli)] + ct13_arg = Polynomial( + { + "batch": _ct13_arg_data.shape[0], + "num_elements": _ct13_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_arg_m, + "precision": 32, + "degree_layout": (_ct13_arg_r, _ct13_arg_c), + }, + {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, + ) + ct13_arg.polynomial = _ct13_arg_data.reshape( + _ct13_arg_data.shape[0], + _ct13_arg_data.shape[1], + _ct13_arg_r, + _ct13_arg_c, + _ct13_arg_m_in, + )[..., :_ct13_arg_m].copy() + ct13_arg.batch = ct13_arg.polynomial.shape[0] + ct13_arg.num_elements = ct13_arg.polynomial.shape[1] + ct13_arg.num_moduli = _ct13_arg_m + ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) + ct13_arg.r = _ct13_arg_r + ct13_arg.c = _ct13_arg_c + ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] + ct13_arg.moduli_array = jnp.array( + ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) + ) + ct13_pt_ntt = ( + pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] + .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct13_ptct = v0.ptct_mul[v0.max_level] + ct13_ptct.set_plaintext(ct13_pt_ntt) + ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) + _ct13_data = ( + ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw + ) + _ct13_m_in = _ct13_data.shape[-1] + _ct13_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_m_in + ) + _ct13_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_r) + ) + _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) + if isinstance(_ct13_moduli, (int, np.integer)): + _ct13_moduli = [int(_ct13_moduli)] + ct13 = Polynomial( + { + "batch": _ct13_data.shape[0], + "num_elements": _ct13_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_m, + "precision": 32, + "degree_layout": (_ct13_r, _ct13_c), + }, + {"moduli": list(_ct13_moduli)[:_ct13_m]}, + ) + ct13.polynomial = _ct13_data.reshape( + _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in + )[..., :_ct13_m].copy() + ct13.batch = ct13.polynomial.shape[0] + ct13.num_elements = ct13.polynomial.shape[1] + ct13.num_moduli = _ct13_m + ct13.degree_layout = (_ct13_r, _ct13_c) + ct13.r = _ct13_r + ct13.c = _ct13_c + ct13.moduli = list(_ct13_moduli)[:_ct13_m] + ct13.moduli_array = jnp.array( + ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) + ) + _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + _ct14_m_in = _ct14_data.shape[-1] + _ct14_m = _ct14_m_in + _ct14_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_r) + ) + _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) + if isinstance(_ct14_moduli, (int, np.integer)): + _ct14_moduli = [int(_ct14_moduli)] + ct14 = Polynomial( + { + "batch": _ct14_data.shape[0], + "num_elements": _ct14_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_m, + "precision": 32, + "degree_layout": (_ct14_r, _ct14_c), + }, + {"moduli": list(_ct14_moduli)[:_ct14_m]}, + ) + ct14.polynomial = _ct14_data.reshape( + _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in + )[..., :_ct14_m].copy() + ct14.batch = ct14.polynomial.shape[0] + ct14.num_elements = ct14.polynomial.shape[1] + ct14.num_moduli = _ct14_m + ct14.degree_layout = (_ct14_r, _ct14_c) + ct14.r = _ct14_r + ct14.c = _ct14_c + ct14.moduli = list(_ct14_moduli)[:_ct14_m] + ct14.moduli_array = jnp.array( + ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) + ) + _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] + _ct14_rhs_m = _ct14_rhs_m_in + _ct14_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_rhs_r) + ) + _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) + if isinstance(_ct14_rhs_moduli, (int, np.integer)): + _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] + ct14_rhs = Polynomial( + { + "batch": _ct14_rhs_data.shape[0], + "num_elements": _ct14_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_rhs_m, + "precision": 32, + "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), + }, + {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, + ) + ct14_rhs.polynomial = _ct14_rhs_data.reshape( + _ct14_rhs_data.shape[0], + _ct14_rhs_data.shape[1], + _ct14_rhs_r, + _ct14_rhs_c, + _ct14_rhs_m_in, + )[..., :_ct14_rhs_m].copy() + ct14_rhs.batch = ct14_rhs.polynomial.shape[0] + ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] + ct14_rhs.num_moduli = _ct14_rhs_m + ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) + ct14_rhs.r = _ct14_rhs_r + ct14_rhs.c = _ct14_rhs_c + ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] + ct14_rhs.moduli_array = jnp.array( + ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) + ) + ct14.add(ct14_rhs) + _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) + ct14.polynomial = jnp.where( + ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial + ) + _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 + _ct15_arg_m_in = _ct15_arg_data.shape[-1] + _ct15_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_arg_m_in + ) + _ct15_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_arg_r) + ) + _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) + if isinstance(_ct15_arg_moduli, (int, np.integer)): + _ct15_arg_moduli = [int(_ct15_arg_moduli)] + ct15_arg = Polynomial( + { + "batch": _ct15_arg_data.shape[0], + "num_elements": _ct15_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_arg_m, + "precision": 32, + "degree_layout": (_ct15_arg_r, _ct15_arg_c), + }, + {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, + ) + ct15_arg.polynomial = _ct15_arg_data.reshape( + _ct15_arg_data.shape[0], + _ct15_arg_data.shape[1], + _ct15_arg_r, + _ct15_arg_c, + _ct15_arg_m_in, + )[..., :_ct15_arg_m].copy() + ct15_arg.batch = ct15_arg.polynomial.shape[0] + ct15_arg.num_elements = ct15_arg.polynomial.shape[1] + ct15_arg.num_moduli = _ct15_arg_m + ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) + ct15_arg.r = _ct15_arg_r + ct15_arg.c = _ct15_arg_c + ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] + ct15_arg.moduli_array = jnp.array( + ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) + ) + ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) + _ct15_data = ( + ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw + ) + _ct15_m_in = _ct15_data.shape[-1] + _ct15_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_m_in + ) + _ct15_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_r) + ) + _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) + if isinstance(_ct15_moduli, (int, np.integer)): + _ct15_moduli = [int(_ct15_moduli)] + ct15 = Polynomial( + { + "batch": _ct15_data.shape[0], + "num_elements": _ct15_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_m, + "precision": 32, + "degree_layout": (_ct15_r, _ct15_c), + }, + {"moduli": list(_ct15_moduli)[:_ct15_m]}, + ) + ct15.polynomial = _ct15_data.reshape( + _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in + )[..., :_ct15_m].copy() + ct15.batch = ct15.polynomial.shape[0] + ct15.num_elements = ct15.polynomial.shape[1] + ct15.num_moduli = _ct15_m + ct15.degree_layout = (_ct15_r, _ct15_c) + ct15.r = _ct15_r + ct15.c = _ct15_c + ct15.moduli = list(_ct15_moduli)[:_ct15_m] + ct15.moduli_array = jnp.array( + ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) + ) + _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + _ct16_m_in = _ct16_data.shape[-1] + _ct16_m = _ct16_m_in + _ct16_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_r) + ) + _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) + if isinstance(_ct16_moduli, (int, np.integer)): + _ct16_moduli = [int(_ct16_moduli)] + ct16 = Polynomial( + { + "batch": _ct16_data.shape[0], + "num_elements": _ct16_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_m, + "precision": 32, + "degree_layout": (_ct16_r, _ct16_c), + }, + {"moduli": list(_ct16_moduli)[:_ct16_m]}, + ) + ct16.polynomial = _ct16_data.reshape( + _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in + )[..., :_ct16_m].copy() + ct16.batch = ct16.polynomial.shape[0] + ct16.num_elements = ct16.polynomial.shape[1] + ct16.num_moduli = _ct16_m + ct16.degree_layout = (_ct16_r, _ct16_c) + ct16.r = _ct16_r + ct16.c = _ct16_c + ct16.moduli = list(_ct16_moduli)[:_ct16_m] + ct16.moduli_array = jnp.array( + ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) + ) + _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 + _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] + _ct16_rhs_m = _ct16_rhs_m_in + _ct16_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_rhs_r) + ) + _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) + if isinstance(_ct16_rhs_moduli, (int, np.integer)): + _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] + ct16_rhs = Polynomial( + { + "batch": _ct16_rhs_data.shape[0], + "num_elements": _ct16_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_rhs_m, + "precision": 32, + "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), + }, + {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, + ) + ct16_rhs.polynomial = _ct16_rhs_data.reshape( + _ct16_rhs_data.shape[0], + _ct16_rhs_data.shape[1], + _ct16_rhs_r, + _ct16_rhs_c, + _ct16_rhs_m_in, + )[..., :_ct16_rhs_m].copy() + ct16_rhs.batch = ct16_rhs.polynomial.shape[0] + ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] + ct16_rhs.num_moduli = _ct16_rhs_m + ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) + ct16_rhs.r = _ct16_rhs_r + ct16_rhs.c = _ct16_rhs_c + ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] + ct16_rhs.moduli_array = jnp.array( + ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) + ) + ct16.add(ct16_rhs) + _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) + ct16.polynomial = jnp.where( + ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial + ) + _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + _ct17_m_in = _ct17_data.shape[-1] + _ct17_m = _ct17_m_in + _ct17_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_r) + ) + _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) + if isinstance(_ct17_moduli, (int, np.integer)): + _ct17_moduli = [int(_ct17_moduli)] + ct17 = Polynomial( + { + "batch": _ct17_data.shape[0], + "num_elements": _ct17_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_m, + "precision": 32, + "degree_layout": (_ct17_r, _ct17_c), + }, + {"moduli": list(_ct17_moduli)[:_ct17_m]}, + ) + ct17.polynomial = _ct17_data.reshape( + _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in + )[..., :_ct17_m].copy() + ct17.batch = ct17.polynomial.shape[0] + ct17.num_elements = ct17.polynomial.shape[1] + ct17.num_moduli = _ct17_m + ct17.degree_layout = (_ct17_r, _ct17_c) + ct17.r = _ct17_r + ct17.c = _ct17_c + ct17.moduli = list(_ct17_moduli)[:_ct17_m] + ct17.moduli_array = jnp.array( + ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) + ) + _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] + _ct17_rhs_m = _ct17_rhs_m_in + _ct17_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_rhs_r) + ) + _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) + if isinstance(_ct17_rhs_moduli, (int, np.integer)): + _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] + ct17_rhs = Polynomial( + { + "batch": _ct17_rhs_data.shape[0], + "num_elements": _ct17_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_rhs_m, + "precision": 32, + "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), + }, + {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, + ) + ct17_rhs.polynomial = _ct17_rhs_data.reshape( + _ct17_rhs_data.shape[0], + _ct17_rhs_data.shape[1], + _ct17_rhs_r, + _ct17_rhs_c, + _ct17_rhs_m_in, + )[..., :_ct17_rhs_m].copy() + ct17_rhs.batch = ct17_rhs.polynomial.shape[0] + ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] + ct17_rhs.num_moduli = _ct17_rhs_m + ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) + ct17_rhs.r = _ct17_rhs_r + ct17_rhs.c = _ct17_rhs_c + ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] + ct17_rhs.moduli_array = jnp.array( + ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) + ) + ct17.add(ct17_rhs) + _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) + ct17.polynomial = jnp.where( + ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial + ) + _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 + _ct18_m_in = _ct18_data.shape[-1] + _ct18_m = _ct18_m_in + _ct18_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_r) + ) + _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) + if isinstance(_ct18_moduli, (int, np.integer)): + _ct18_moduli = [int(_ct18_moduli)] + ct18 = Polynomial( + { + "batch": _ct18_data.shape[0], + "num_elements": _ct18_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_m, + "precision": 32, + "degree_layout": (_ct18_r, _ct18_c), + }, + {"moduli": list(_ct18_moduli)[:_ct18_m]}, + ) + ct18.polynomial = _ct18_data.reshape( + _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in + )[..., :_ct18_m].copy() + ct18.batch = ct18.polynomial.shape[0] + ct18.num_elements = ct18.polynomial.shape[1] + ct18.num_moduli = _ct18_m + ct18.degree_layout = (_ct18_r, _ct18_c) + ct18.r = _ct18_r + ct18.c = _ct18_c + ct18.moduli = list(_ct18_moduli)[:_ct18_m] + ct18.moduli_array = jnp.array( + ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) + ) + _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 + _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] + _ct18_rhs_m = _ct18_rhs_m_in + _ct18_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_rhs_r) + ) + _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) + if isinstance(_ct18_rhs_moduli, (int, np.integer)): + _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] + ct18_rhs = Polynomial( + { + "batch": _ct18_rhs_data.shape[0], + "num_elements": _ct18_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_rhs_m, + "precision": 32, + "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), + }, + {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, + ) + ct18_rhs.polynomial = _ct18_rhs_data.reshape( + _ct18_rhs_data.shape[0], + _ct18_rhs_data.shape[1], + _ct18_rhs_r, + _ct18_rhs_c, + _ct18_rhs_m_in, + )[..., :_ct18_rhs_m].copy() + ct18_rhs.batch = ct18_rhs.polynomial.shape[0] + ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] + ct18_rhs.num_moduli = _ct18_rhs_m + ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) + ct18_rhs.r = _ct18_rhs_r + ct18_rhs.c = _ct18_rhs_c + ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] + ct18_rhs.moduli_array = jnp.array( + ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) + ) + ct18.add(ct18_rhs) + _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) + ct18.polynomial = jnp.where( + ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial + ) + _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 + _ct19_m_in = _ct19_data.shape[-1] + _ct19_m = _ct19_m_in + _ct19_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_r) + ) + _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) + if isinstance(_ct19_moduli, (int, np.integer)): + _ct19_moduli = [int(_ct19_moduli)] + ct19 = Polynomial( + { + "batch": _ct19_data.shape[0], + "num_elements": _ct19_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_m, + "precision": 32, + "degree_layout": (_ct19_r, _ct19_c), + }, + {"moduli": list(_ct19_moduli)[:_ct19_m]}, + ) + ct19.polynomial = _ct19_data.reshape( + _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in + )[..., :_ct19_m].copy() + ct19.batch = ct19.polynomial.shape[0] + ct19.num_elements = ct19.polynomial.shape[1] + ct19.num_moduli = _ct19_m + ct19.degree_layout = (_ct19_r, _ct19_c) + ct19.r = _ct19_r + ct19.c = _ct19_c + ct19.moduli = list(_ct19_moduli)[:_ct19_m] + ct19.moduli_array = jnp.array( + ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) + ) + _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 + _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] + _ct19_rhs_m = _ct19_rhs_m_in + _ct19_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_rhs_r) + ) + _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) + if isinstance(_ct19_rhs_moduli, (int, np.integer)): + _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] + ct19_rhs = Polynomial( + { + "batch": _ct19_rhs_data.shape[0], + "num_elements": _ct19_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_rhs_m, + "precision": 32, + "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), + }, + {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, + ) + ct19_rhs.polynomial = _ct19_rhs_data.reshape( + _ct19_rhs_data.shape[0], + _ct19_rhs_data.shape[1], + _ct19_rhs_r, + _ct19_rhs_c, + _ct19_rhs_m_in, + )[..., :_ct19_rhs_m].copy() + ct19_rhs.batch = ct19_rhs.polynomial.shape[0] + ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] + ct19_rhs.num_moduli = _ct19_rhs_m + ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) + ct19_rhs.r = _ct19_rhs_r + ct19_rhs.c = _ct19_rhs_c + ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] + ct19_rhs.moduli_array = jnp.array( + ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) + ) + ct19.add(ct19_rhs) + _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) + ct19.polynomial = jnp.where( + ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial + ) + v16 = [None] * 1 + _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 + _ct20_arg_m_in = _ct20_arg_data.shape[-1] + _ct20_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct20_arg_m_in + ) + _ct20_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_arg_r) + ) + _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) + if isinstance(_ct20_arg_moduli, (int, np.integer)): + _ct20_arg_moduli = [int(_ct20_arg_moduli)] + ct20_arg = Polynomial( + { + "batch": _ct20_arg_data.shape[0], + "num_elements": _ct20_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_arg_m, + "precision": 32, + "degree_layout": (_ct20_arg_r, _ct20_arg_c), + }, + {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, + ) + ct20_arg.polynomial = _ct20_arg_data.reshape( + _ct20_arg_data.shape[0], + _ct20_arg_data.shape[1], + _ct20_arg_r, + _ct20_arg_c, + _ct20_arg_m_in, + )[..., :_ct20_arg_m].copy() + ct20_arg.batch = ct20_arg.polynomial.shape[0] + ct20_arg.num_elements = ct20_arg.polynomial.shape[1] + ct20_arg.num_moduli = _ct20_arg_m + ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) + ct20_arg.r = _ct20_arg_r + ct20_arg.c = _ct20_arg_c + ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] + ct20_arg.moduli_array = jnp.array( + ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) + ) + ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) + _ct20_data = ( + ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw + ) + _ct20_m_in = _ct20_data.shape[-1] + _ct20_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct20_m_in + ) + _ct20_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_r) + ) + _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) + if isinstance(_ct20_moduli, (int, np.integer)): + _ct20_moduli = [int(_ct20_moduli)] + ct20 = Polynomial( + { + "batch": _ct20_data.shape[0], + "num_elements": _ct20_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_m, + "precision": 32, + "degree_layout": (_ct20_r, _ct20_c), + }, + {"moduli": list(_ct20_moduli)[:_ct20_m]}, + ) + ct20.polynomial = _ct20_data.reshape( + _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in + )[..., :_ct20_m].copy() + ct20.batch = ct20.polynomial.shape[0] + ct20.num_elements = ct20.polynomial.shape[1] + ct20.num_moduli = _ct20_m + ct20.degree_layout = (_ct20_r, _ct20_c) + ct20.r = _ct20_r + ct20.c = _ct20_c + ct20.moduli = list(_ct20_moduli)[:_ct20_m] + ct20.moduli_array = jnp.array( + ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) + ) + v16[0] = ct20 + v17 = v16 + return v17 + + +def matvec_identity( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_identity__preprocessing(v0, v1) + v11 = matvec_identity__preprocessed( + v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10 + ) + return v11 + + +def matvec_identity__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw + _ct_m_in = _ct_data.shape[-1] + _ct_m = _ct_m_in + _ct_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct_r) + ) + _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) + if isinstance(_ct_moduli, (int, np.integer)): + _ct_moduli = [int(_ct_moduli)] + ct = Polynomial( + { + "batch": _ct_data.shape[0], + "num_elements": _ct_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct_m, + "precision": 32, + "degree_layout": (_ct_r, _ct_c), + }, + {"moduli": list(_ct_moduli)[:_ct_m]}, + ) + ct.polynomial = _ct_data.reshape( + _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in + )[..., :_ct_m].copy() + ct.batch = ct.polynomial.shape[0] + ct.num_elements = ct.polynomial.shape[1] + ct.num_moduli = _ct_m + ct.degree_layout = (_ct_r, _ct_c) + ct.r = _ct_r + ct.c = _ct_c + ct.moduli = list(_ct_moduli)[:_ct_m] + ct.moduli_array = jnp.array( + ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) + ) + v16 = [ct] + return v16 + + +def matvec_identity__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 0 + v8 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + _num_moduli = ct.polynomial.shape[-1] + _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": ct.polynomial.shape[0], + "num_elements": ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + ct.polynomial.reshape( + ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli + ) + ) + pt = v0.decrypt(_ct_for_dec) + v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v10 = v8.copy() + for v11 in range(0, 8): + v13 = int(v11) + v14 = v9[0, v13] + v10[v13] = v14 + return v10 + + +def matvec_shift__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> ( + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +): + v2 = np.array( + [ + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 1.000000e00, + 1.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + 0.000000e00, + ], + dtype=np.float32, + ).reshape(8, 8) + v3 = _assign_layout_15335824159471298539(v2) + v4 = v3[3 : 3 + 1, 0 : 0 + 5] + v5 = v3[3 : 3 + 1, 5 : 5 + 3] + v6 = np.zeros( + ( + 1, + 8, + ), + dtype=np.float32, + ) + v7 = v6.copy() + v7[0 : 0 + 1, 3 : 3 + 5] = v4 + v8 = v7.copy() + v8[0 : 0 + 1, 0 : 0 + 3] = v5 + v9 = v3[4 : 4 + 1, 0 : 0 + 5] + v10 = v3[4 : 4 + 1, 5 : 5 + 3] + v11 = v6.copy() + v11[0 : 0 + 1, 3 : 3 + 5] = v9 + v12 = v11.copy() + v12[0 : 0 + 1, 0 : 0 + 3] = v10 + v13 = v3[5 : 5 + 1, 0 : 0 + 5] + v14 = v3[5 : 5 + 1, 5 : 5 + 3] + v15 = v6.copy() + v15[0 : 0 + 1, 3 : 3 + 5] = v13 + v16 = v15.copy() + v16[0 : 0 + 1, 0 : 0 + 3] = v14 + v17 = v3[6 : 6 + 1, 0 : 0 + 2] + v18 = v3[6 : 6 + 1, 2 : 2 + 6] + v19 = v6.copy() + v19[0 : 0 + 1, 6 : 6 + 2] = v17 + v20 = v19.copy() + v20[0 : 0 + 1, 0 : 0 + 6] = v18 + v21 = v3[7 : 7 + 1, 0 : 0 + 2] + v22 = v3[7 : 7 + 1, 2 : 2 + 6] + v23 = v6.copy() + v23[0 : 0 + 1, 6 : 6 + 2] = v21 + v24 = v23.copy() + v24[0 : 0 + 1, 0 : 0 + 6] = v22 + v25 = v3[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v25) + v26 = v3[1 : 1 + 1, 0 : 0 + 8].reshape(8) + pt1 = v0.encode(v26) + v27 = v3[2 : 2 + 1, 0 : 0 + 8].reshape(8) + pt2 = v0.encode(v27) + v28 = v8[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt3 = v0.encode(v28) + v29 = v12[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt4 = v0.encode(v29) + v30 = v16[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt5 = v0.encode(v30) + v31 = v20[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt6 = v0.encode(v31) + v32 = v24[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt7 = v0.encode(v32) + v33 = [pt] + v34 = [pt1] + v35 = [pt2] + v36 = [pt3] + v37 = [pt4] + v38 = [pt5] + v39 = [pt6] + v40 = [pt7] + return (v33, v34, v35, v36, v37, v38, v39, v40) + + +def matvec_shift__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, + v5: np.ndarray, + v6: np.ndarray, + v7: np.ndarray, + v8: np.ndarray, + v9: np.ndarray, + v10: np.ndarray, +) -> np.ndarray: + v11 = 1 + v12 = 2 + v13 = 3 + v14 = 6 + v15 = 0 + pt = v3[0] + pt1 = v4[0] + pt2 = v5[0] + pt3 = v6[0] + pt4 = v7[0] + pt5 = v8[0] + pt6 = v9[0] + pt7 = v10[0] + ct = v2[0] + _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct1_arg_m_in = _ct1_arg_data.shape[-1] + _ct1_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_arg_m_in + ) + _ct1_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_arg_r) + ) + _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct1_arg_moduli, (int, np.integer)): + _ct1_arg_moduli = [int(_ct1_arg_moduli)] + ct1_arg = Polynomial( + { + "batch": _ct1_arg_data.shape[0], + "num_elements": _ct1_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_arg_m, + "precision": 32, + "degree_layout": (_ct1_arg_r, _ct1_arg_c), + }, + {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, + ) + ct1_arg.polynomial = _ct1_arg_data.reshape( + _ct1_arg_data.shape[0], + _ct1_arg_data.shape[1], + _ct1_arg_r, + _ct1_arg_c, + _ct1_arg_m_in, + )[..., :_ct1_arg_m].copy() + ct1_arg.batch = ct1_arg.polynomial.shape[0] + ct1_arg.num_elements = ct1_arg.polynomial.shape[1] + ct1_arg.num_moduli = _ct1_arg_m + ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) + ct1_arg.r = _ct1_arg_r + ct1_arg.c = _ct1_arg_c + ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] + ct1_arg.moduli_array = jnp.array( + ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) + ) + ct1_pt_ntt = ( + pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] + .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct1_ptct = v0.ptct_mul[v0.max_level] + ct1_ptct.set_plaintext(ct1_pt_ntt) + ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) + _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw + _ct1_m_in = _ct1_data.shape[-1] + _ct1_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_m_in + ) + _ct1_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_r) + ) + _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) + if isinstance(_ct1_moduli, (int, np.integer)): + _ct1_moduli = [int(_ct1_moduli)] + ct1 = Polynomial( + { + "batch": _ct1_data.shape[0], + "num_elements": _ct1_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_m, + "precision": 32, + "degree_layout": (_ct1_r, _ct1_c), + }, + {"moduli": list(_ct1_moduli)[:_ct1_m]}, + ) + ct1.polynomial = _ct1_data.reshape( + _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in + )[..., :_ct1_m].copy() + ct1.batch = ct1.polynomial.shape[0] + ct1.num_elements = ct1.polynomial.shape[1] + ct1.num_moduli = _ct1_m + ct1.degree_layout = (_ct1_r, _ct1_c) + ct1.r = _ct1_r + ct1.c = _ct1_c + ct1.moduli = list(_ct1_moduli)[:_ct1_m] + ct1.moduli_array = jnp.array( + ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) + ) + _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct2_arg_m_in = _ct2_arg_data.shape[-1] + _ct2_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_arg_m_in + ) + _ct2_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_arg_r) + ) + _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct2_arg_moduli, (int, np.integer)): + _ct2_arg_moduli = [int(_ct2_arg_moduli)] + ct2_arg = Polynomial( + { + "batch": _ct2_arg_data.shape[0], + "num_elements": _ct2_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_arg_m, + "precision": 32, + "degree_layout": (_ct2_arg_r, _ct2_arg_c), + }, + {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, + ) + ct2_arg.polynomial = _ct2_arg_data.reshape( + _ct2_arg_data.shape[0], + _ct2_arg_data.shape[1], + _ct2_arg_r, + _ct2_arg_c, + _ct2_arg_m_in, + )[..., :_ct2_arg_m].copy() + ct2_arg.batch = ct2_arg.polynomial.shape[0] + ct2_arg.num_elements = ct2_arg.polynomial.shape[1] + ct2_arg.num_moduli = _ct2_arg_m + ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) + ct2_arg.r = _ct2_arg_r + ct2_arg.c = _ct2_arg_c + ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] + ct2_arg.moduli_array = jnp.array( + ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) + ) + ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) + _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw + _ct2_m_in = _ct2_data.shape[-1] + _ct2_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_m_in + ) + _ct2_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_r) + ) + _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) + if isinstance(_ct2_moduli, (int, np.integer)): + _ct2_moduli = [int(_ct2_moduli)] + ct2 = Polynomial( + { + "batch": _ct2_data.shape[0], + "num_elements": _ct2_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_m, + "precision": 32, + "degree_layout": (_ct2_r, _ct2_c), + }, + {"moduli": list(_ct2_moduli)[:_ct2_m]}, + ) + ct2.polynomial = _ct2_data.reshape( + _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in + )[..., :_ct2_m].copy() + ct2.batch = ct2.polynomial.shape[0] + ct2.num_elements = ct2.polynomial.shape[1] + ct2.num_moduli = _ct2_m + ct2.degree_layout = (_ct2_r, _ct2_c) + ct2.r = _ct2_r + ct2.c = _ct2_c + ct2.moduli = list(_ct2_moduli)[:_ct2_m] + ct2.moduli_array = jnp.array( + ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) + ) + _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct3_arg_m_in = _ct3_arg_data.shape[-1] + _ct3_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_arg_m_in + ) + _ct3_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_arg_r) + ) + _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct3_arg_moduli, (int, np.integer)): + _ct3_arg_moduli = [int(_ct3_arg_moduli)] + ct3_arg = Polynomial( + { + "batch": _ct3_arg_data.shape[0], + "num_elements": _ct3_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_arg_m, + "precision": 32, + "degree_layout": (_ct3_arg_r, _ct3_arg_c), + }, + {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, + ) + ct3_arg.polynomial = _ct3_arg_data.reshape( + _ct3_arg_data.shape[0], + _ct3_arg_data.shape[1], + _ct3_arg_r, + _ct3_arg_c, + _ct3_arg_m_in, + )[..., :_ct3_arg_m].copy() + ct3_arg.batch = ct3_arg.polynomial.shape[0] + ct3_arg.num_elements = ct3_arg.polynomial.shape[1] + ct3_arg.num_moduli = _ct3_arg_m + ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) + ct3_arg.r = _ct3_arg_r + ct3_arg.c = _ct3_arg_c + ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] + ct3_arg.moduli_array = jnp.array( + ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) + ) + ct3_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] + .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct3_ptct = v0.ptct_mul[v0.max_level] + ct3_ptct.set_plaintext(ct3_pt_ntt) + ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) + _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw + _ct3_m_in = _ct3_data.shape[-1] + _ct3_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_m_in + ) + _ct3_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_r) + ) + _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) + if isinstance(_ct3_moduli, (int, np.integer)): + _ct3_moduli = [int(_ct3_moduli)] + ct3 = Polynomial( + { + "batch": _ct3_data.shape[0], + "num_elements": _ct3_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_m, + "precision": 32, + "degree_layout": (_ct3_r, _ct3_c), + }, + {"moduli": list(_ct3_moduli)[:_ct3_m]}, + ) + ct3.polynomial = _ct3_data.reshape( + _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in + )[..., :_ct3_m].copy() + ct3.batch = ct3.polynomial.shape[0] + ct3.num_elements = ct3.polynomial.shape[1] + ct3.num_moduli = _ct3_m + ct3.degree_layout = (_ct3_r, _ct3_c) + ct3.r = _ct3_r + ct3.c = _ct3_c + ct3.moduli = list(_ct3_moduli)[:_ct3_m] + ct3.moduli_array = jnp.array( + ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) + ) + _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct4_arg_m_in = _ct4_arg_data.shape[-1] + _ct4_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_arg_m_in + ) + _ct4_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_arg_r) + ) + _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct4_arg_moduli, (int, np.integer)): + _ct4_arg_moduli = [int(_ct4_arg_moduli)] + ct4_arg = Polynomial( + { + "batch": _ct4_arg_data.shape[0], + "num_elements": _ct4_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_arg_m, + "precision": 32, + "degree_layout": (_ct4_arg_r, _ct4_arg_c), + }, + {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, + ) + ct4_arg.polynomial = _ct4_arg_data.reshape( + _ct4_arg_data.shape[0], + _ct4_arg_data.shape[1], + _ct4_arg_r, + _ct4_arg_c, + _ct4_arg_m_in, + )[..., :_ct4_arg_m].copy() + ct4_arg.batch = ct4_arg.polynomial.shape[0] + ct4_arg.num_elements = ct4_arg.polynomial.shape[1] + ct4_arg.num_moduli = _ct4_arg_m + ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) + ct4_arg.r = _ct4_arg_r + ct4_arg.c = _ct4_arg_c + ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] + ct4_arg.moduli_array = jnp.array( + ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) + ) + ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) + _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw + _ct4_m_in = _ct4_data.shape[-1] + _ct4_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_m_in + ) + _ct4_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_r) + ) + _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) + if isinstance(_ct4_moduli, (int, np.integer)): + _ct4_moduli = [int(_ct4_moduli)] + ct4 = Polynomial( + { + "batch": _ct4_data.shape[0], + "num_elements": _ct4_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_m, + "precision": 32, + "degree_layout": (_ct4_r, _ct4_c), + }, + {"moduli": list(_ct4_moduli)[:_ct4_m]}, + ) + ct4.polynomial = _ct4_data.reshape( + _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in + )[..., :_ct4_m].copy() + ct4.batch = ct4.polynomial.shape[0] + ct4.num_elements = ct4.polynomial.shape[1] + ct4.num_moduli = _ct4_m + ct4.degree_layout = (_ct4_r, _ct4_c) + ct4.r = _ct4_r + ct4.c = _ct4_c + ct4.moduli = list(_ct4_moduli)[:_ct4_m] + ct4.moduli_array = jnp.array( + ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) + ) + _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct5_arg_m_in = _ct5_arg_data.shape[-1] + _ct5_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_arg_m_in + ) + _ct5_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_arg_r) + ) + _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct5_arg_moduli, (int, np.integer)): + _ct5_arg_moduli = [int(_ct5_arg_moduli)] + ct5_arg = Polynomial( + { + "batch": _ct5_arg_data.shape[0], + "num_elements": _ct5_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_arg_m, + "precision": 32, + "degree_layout": (_ct5_arg_r, _ct5_arg_c), + }, + {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, + ) + ct5_arg.polynomial = _ct5_arg_data.reshape( + _ct5_arg_data.shape[0], + _ct5_arg_data.shape[1], + _ct5_arg_r, + _ct5_arg_c, + _ct5_arg_m_in, + )[..., :_ct5_arg_m].copy() + ct5_arg.batch = ct5_arg.polynomial.shape[0] + ct5_arg.num_elements = ct5_arg.polynomial.shape[1] + ct5_arg.num_moduli = _ct5_arg_m + ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) + ct5_arg.r = _ct5_arg_r + ct5_arg.c = _ct5_arg_c + ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] + ct5_arg.moduli_array = jnp.array( + ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) + ) + ct5_pt_ntt = ( + pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) + _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw + _ct5_m_in = _ct5_data.shape[-1] + _ct5_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_m_in + ) + _ct5_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_r) + ) + _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) + if isinstance(_ct5_moduli, (int, np.integer)): + _ct5_moduli = [int(_ct5_moduli)] + ct5 = Polynomial( + { + "batch": _ct5_data.shape[0], + "num_elements": _ct5_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_m, + "precision": 32, + "degree_layout": (_ct5_r, _ct5_c), + }, + {"moduli": list(_ct5_moduli)[:_ct5_m]}, + ) + ct5.polynomial = _ct5_data.reshape( + _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in + )[..., :_ct5_m].copy() + ct5.batch = ct5.polynomial.shape[0] + ct5.num_elements = ct5.polynomial.shape[1] + ct5.num_moduli = _ct5_m + ct5.degree_layout = (_ct5_r, _ct5_c) + ct5.r = _ct5_r + ct5.c = _ct5_c + ct5.moduli = list(_ct5_moduli)[:_ct5_m] + ct5.moduli_array = jnp.array( + ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) + ) + _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct6_arg_m_in = _ct6_arg_data.shape[-1] + _ct6_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_arg_m_in + ) + _ct6_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_arg_r) + ) + _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct6_arg_moduli, (int, np.integer)): + _ct6_arg_moduli = [int(_ct6_arg_moduli)] + ct6_arg = Polynomial( + { + "batch": _ct6_arg_data.shape[0], + "num_elements": _ct6_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_arg_m, + "precision": 32, + "degree_layout": (_ct6_arg_r, _ct6_arg_c), + }, + {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, + ) + ct6_arg.polynomial = _ct6_arg_data.reshape( + _ct6_arg_data.shape[0], + _ct6_arg_data.shape[1], + _ct6_arg_r, + _ct6_arg_c, + _ct6_arg_m_in, + )[..., :_ct6_arg_m].copy() + ct6_arg.batch = ct6_arg.polynomial.shape[0] + ct6_arg.num_elements = ct6_arg.polynomial.shape[1] + ct6_arg.num_moduli = _ct6_arg_m + ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) + ct6_arg.r = _ct6_arg_r + ct6_arg.c = _ct6_arg_c + ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] + ct6_arg.moduli_array = jnp.array( + ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) + ) + ct6_pt_ntt = ( + pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] + .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct6_ptct = v0.ptct_mul[v0.max_level] + ct6_ptct.set_plaintext(ct6_pt_ntt) + ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) + _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw + _ct6_m_in = _ct6_data.shape[-1] + _ct6_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_m_in + ) + _ct6_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_r) + ) + _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) + if isinstance(_ct6_moduli, (int, np.integer)): + _ct6_moduli = [int(_ct6_moduli)] + ct6 = Polynomial( + { + "batch": _ct6_data.shape[0], + "num_elements": _ct6_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_m, + "precision": 32, + "degree_layout": (_ct6_r, _ct6_c), + }, + {"moduli": list(_ct6_moduli)[:_ct6_m]}, + ) + ct6.polynomial = _ct6_data.reshape( + _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in + )[..., :_ct6_m].copy() + ct6.batch = ct6.polynomial.shape[0] + ct6.num_elements = ct6.polynomial.shape[1] + ct6.num_moduli = _ct6_m + ct6.degree_layout = (_ct6_r, _ct6_c) + ct6.r = _ct6_r + ct6.c = _ct6_c + ct6.moduli = list(_ct6_moduli)[:_ct6_m] + ct6.moduli_array = jnp.array( + ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) + ) + _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct7_arg_m_in = _ct7_arg_data.shape[-1] + _ct7_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_arg_m_in + ) + _ct7_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_arg_r) + ) + _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct7_arg_moduli, (int, np.integer)): + _ct7_arg_moduli = [int(_ct7_arg_moduli)] + ct7_arg = Polynomial( + { + "batch": _ct7_arg_data.shape[0], + "num_elements": _ct7_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_arg_m, + "precision": 32, + "degree_layout": (_ct7_arg_r, _ct7_arg_c), + }, + {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, + ) + ct7_arg.polynomial = _ct7_arg_data.reshape( + _ct7_arg_data.shape[0], + _ct7_arg_data.shape[1], + _ct7_arg_r, + _ct7_arg_c, + _ct7_arg_m_in, + )[..., :_ct7_arg_m].copy() + ct7_arg.batch = ct7_arg.polynomial.shape[0] + ct7_arg.num_elements = ct7_arg.polynomial.shape[1] + ct7_arg.num_moduli = _ct7_arg_m + ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) + ct7_arg.r = _ct7_arg_r + ct7_arg.c = _ct7_arg_c + ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] + ct7_arg.moduli_array = jnp.array( + ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) + ) + ct7_pt_ntt = ( + pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] + .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct7_ptct = v0.ptct_mul[v0.max_level] + ct7_ptct.set_plaintext(ct7_pt_ntt) + ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) + _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw + _ct7_m_in = _ct7_data.shape[-1] + _ct7_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_m_in + ) + _ct7_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_r) + ) + _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) + if isinstance(_ct7_moduli, (int, np.integer)): + _ct7_moduli = [int(_ct7_moduli)] + ct7 = Polynomial( + { + "batch": _ct7_data.shape[0], + "num_elements": _ct7_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_m, + "precision": 32, + "degree_layout": (_ct7_r, _ct7_c), + }, + {"moduli": list(_ct7_moduli)[:_ct7_m]}, + ) + ct7.polynomial = _ct7_data.reshape( + _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in + )[..., :_ct7_m].copy() + ct7.batch = ct7.polynomial.shape[0] + ct7.num_elements = ct7.polynomial.shape[1] + ct7.num_moduli = _ct7_m + ct7.degree_layout = (_ct7_r, _ct7_c) + ct7.r = _ct7_r + ct7.c = _ct7_c + ct7.moduli = list(_ct7_moduli)[:_ct7_m] + ct7.moduli_array = jnp.array( + ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) + ) + _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct8_arg_m_in = _ct8_arg_data.shape[-1] + _ct8_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_arg_m_in + ) + _ct8_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_arg_r) + ) + _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct8_arg_moduli, (int, np.integer)): + _ct8_arg_moduli = [int(_ct8_arg_moduli)] + ct8_arg = Polynomial( + { + "batch": _ct8_arg_data.shape[0], + "num_elements": _ct8_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_arg_m, + "precision": 32, + "degree_layout": (_ct8_arg_r, _ct8_arg_c), + }, + {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, + ) + ct8_arg.polynomial = _ct8_arg_data.reshape( + _ct8_arg_data.shape[0], + _ct8_arg_data.shape[1], + _ct8_arg_r, + _ct8_arg_c, + _ct8_arg_m_in, + )[..., :_ct8_arg_m].copy() + ct8_arg.batch = ct8_arg.polynomial.shape[0] + ct8_arg.num_elements = ct8_arg.polynomial.shape[1] + ct8_arg.num_moduli = _ct8_arg_m + ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) + ct8_arg.r = _ct8_arg_r + ct8_arg.c = _ct8_arg_c + ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] + ct8_arg.moduli_array = jnp.array( + ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) + ) + ct8_pt_ntt = ( + pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] + .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct8_ptct = v0.ptct_mul[v0.max_level] + ct8_ptct.set_plaintext(ct8_pt_ntt) + ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) + _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw + _ct8_m_in = _ct8_data.shape[-1] + _ct8_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_m_in + ) + _ct8_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_r) + ) + _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) + if isinstance(_ct8_moduli, (int, np.integer)): + _ct8_moduli = [int(_ct8_moduli)] + ct8 = Polynomial( + { + "batch": _ct8_data.shape[0], + "num_elements": _ct8_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_m, + "precision": 32, + "degree_layout": (_ct8_r, _ct8_c), + }, + {"moduli": list(_ct8_moduli)[:_ct8_m]}, + ) + ct8.polynomial = _ct8_data.reshape( + _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in + )[..., :_ct8_m].copy() + ct8.batch = ct8.polynomial.shape[0] + ct8.num_elements = ct8.polynomial.shape[1] + ct8.num_moduli = _ct8_m + ct8.degree_layout = (_ct8_r, _ct8_c) + ct8.r = _ct8_r + ct8.c = _ct8_c + ct8.moduli = list(_ct8_moduli)[:_ct8_m] + ct8.moduli_array = jnp.array( + ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) + ) + _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + _ct9_m_in = _ct9_data.shape[-1] + _ct9_m = _ct9_m_in + _ct9_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_r) + ) + _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) + if isinstance(_ct9_moduli, (int, np.integer)): + _ct9_moduli = [int(_ct9_moduli)] + ct9 = Polynomial( + { + "batch": _ct9_data.shape[0], + "num_elements": _ct9_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_m, + "precision": 32, + "degree_layout": (_ct9_r, _ct9_c), + }, + {"moduli": list(_ct9_moduli)[:_ct9_m]}, + ) + ct9.polynomial = _ct9_data.reshape( + _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in + )[..., :_ct9_m].copy() + ct9.batch = ct9.polynomial.shape[0] + ct9.num_elements = ct9.polynomial.shape[1] + ct9.num_moduli = _ct9_m + ct9.degree_layout = (_ct9_r, _ct9_c) + ct9.r = _ct9_r + ct9.c = _ct9_c + ct9.moduli = list(_ct9_moduli)[:_ct9_m] + ct9.moduli_array = jnp.array( + ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) + ) + _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 + _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] + _ct9_rhs_m = _ct9_rhs_m_in + _ct9_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_rhs_r) + ) + _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) + if isinstance(_ct9_rhs_moduli, (int, np.integer)): + _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] + ct9_rhs = Polynomial( + { + "batch": _ct9_rhs_data.shape[0], + "num_elements": _ct9_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_rhs_m, + "precision": 32, + "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), + }, + {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, + ) + ct9_rhs.polynomial = _ct9_rhs_data.reshape( + _ct9_rhs_data.shape[0], + _ct9_rhs_data.shape[1], + _ct9_rhs_r, + _ct9_rhs_c, + _ct9_rhs_m_in, + )[..., :_ct9_rhs_m].copy() + ct9_rhs.batch = ct9_rhs.polynomial.shape[0] + ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] + ct9_rhs.num_moduli = _ct9_rhs_m + ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) + ct9_rhs.r = _ct9_rhs_r + ct9_rhs.c = _ct9_rhs_c + ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] + ct9_rhs.moduli_array = jnp.array( + ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) + ) + ct9.add(ct9_rhs) + _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) + ct9.polynomial = jnp.where( + ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial + ) + _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + _ct10_m_in = _ct10_data.shape[-1] + _ct10_m = _ct10_m_in + _ct10_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_r) + ) + _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) + if isinstance(_ct10_moduli, (int, np.integer)): + _ct10_moduli = [int(_ct10_moduli)] + ct10 = Polynomial( + { + "batch": _ct10_data.shape[0], + "num_elements": _ct10_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_m, + "precision": 32, + "degree_layout": (_ct10_r, _ct10_c), + }, + {"moduli": list(_ct10_moduli)[:_ct10_m]}, + ) + ct10.polynomial = _ct10_data.reshape( + _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in + )[..., :_ct10_m].copy() + ct10.batch = ct10.polynomial.shape[0] + ct10.num_elements = ct10.polynomial.shape[1] + ct10.num_moduli = _ct10_m + ct10.degree_layout = (_ct10_r, _ct10_c) + ct10.r = _ct10_r + ct10.c = _ct10_c + ct10.moduli = list(_ct10_moduli)[:_ct10_m] + ct10.moduli_array = jnp.array( + ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) + ) + _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] + _ct10_rhs_m = _ct10_rhs_m_in + _ct10_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_rhs_r) + ) + _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) + if isinstance(_ct10_rhs_moduli, (int, np.integer)): + _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] + ct10_rhs = Polynomial( + { + "batch": _ct10_rhs_data.shape[0], + "num_elements": _ct10_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_rhs_m, + "precision": 32, + "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), + }, + {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, + ) + ct10_rhs.polynomial = _ct10_rhs_data.reshape( + _ct10_rhs_data.shape[0], + _ct10_rhs_data.shape[1], + _ct10_rhs_r, + _ct10_rhs_c, + _ct10_rhs_m_in, + )[..., :_ct10_rhs_m].copy() + ct10_rhs.batch = ct10_rhs.polynomial.shape[0] + ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] + ct10_rhs.num_moduli = _ct10_rhs_m + ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) + ct10_rhs.r = _ct10_rhs_r + ct10_rhs.c = _ct10_rhs_c + ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] + ct10_rhs.moduli_array = jnp.array( + ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) + ) + ct10.add(ct10_rhs) + _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) + ct10.polynomial = jnp.where( + ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial + ) + _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 + _ct11_arg_m_in = _ct11_arg_data.shape[-1] + _ct11_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_arg_m_in + ) + _ct11_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_arg_r) + ) + _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) + if isinstance(_ct11_arg_moduli, (int, np.integer)): + _ct11_arg_moduli = [int(_ct11_arg_moduli)] + ct11_arg = Polynomial( + { + "batch": _ct11_arg_data.shape[0], + "num_elements": _ct11_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_arg_m, + "precision": 32, + "degree_layout": (_ct11_arg_r, _ct11_arg_c), + }, + {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, + ) + ct11_arg.polynomial = _ct11_arg_data.reshape( + _ct11_arg_data.shape[0], + _ct11_arg_data.shape[1], + _ct11_arg_r, + _ct11_arg_c, + _ct11_arg_m_in, + )[..., :_ct11_arg_m].copy() + ct11_arg.batch = ct11_arg.polynomial.shape[0] + ct11_arg.num_elements = ct11_arg.polynomial.shape[1] + ct11_arg.num_moduli = _ct11_arg_m + ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) + ct11_arg.r = _ct11_arg_r + ct11_arg.c = _ct11_arg_c + ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] + ct11_arg.moduli_array = jnp.array( + ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) + ) + ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) + _ct11_data = ( + ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw + ) + _ct11_m_in = _ct11_data.shape[-1] + _ct11_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_m_in + ) + _ct11_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_r) + ) + _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) + if isinstance(_ct11_moduli, (int, np.integer)): + _ct11_moduli = [int(_ct11_moduli)] + ct11 = Polynomial( + { + "batch": _ct11_data.shape[0], + "num_elements": _ct11_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_m, + "precision": 32, + "degree_layout": (_ct11_r, _ct11_c), + }, + {"moduli": list(_ct11_moduli)[:_ct11_m]}, + ) + ct11.polynomial = _ct11_data.reshape( + _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in + )[..., :_ct11_m].copy() + ct11.batch = ct11.polynomial.shape[0] + ct11.num_elements = ct11.polynomial.shape[1] + ct11.num_moduli = _ct11_m + ct11.degree_layout = (_ct11_r, _ct11_c) + ct11.r = _ct11_r + ct11.c = _ct11_c + ct11.moduli = list(_ct11_moduli)[:_ct11_m] + ct11.moduli_array = jnp.array( + ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) + ) + _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct12_arg_m_in = _ct12_arg_data.shape[-1] + _ct12_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_arg_m_in + ) + _ct12_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_arg_r) + ) + _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct12_arg_moduli, (int, np.integer)): + _ct12_arg_moduli = [int(_ct12_arg_moduli)] + ct12_arg = Polynomial( + { + "batch": _ct12_arg_data.shape[0], + "num_elements": _ct12_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_arg_m, + "precision": 32, + "degree_layout": (_ct12_arg_r, _ct12_arg_c), + }, + {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, + ) + ct12_arg.polynomial = _ct12_arg_data.reshape( + _ct12_arg_data.shape[0], + _ct12_arg_data.shape[1], + _ct12_arg_r, + _ct12_arg_c, + _ct12_arg_m_in, + )[..., :_ct12_arg_m].copy() + ct12_arg.batch = ct12_arg.polynomial.shape[0] + ct12_arg.num_elements = ct12_arg.polynomial.shape[1] + ct12_arg.num_moduli = _ct12_arg_m + ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) + ct12_arg.r = _ct12_arg_r + ct12_arg.c = _ct12_arg_c + ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] + ct12_arg.moduli_array = jnp.array( + ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) + ) + ct12_pt_ntt = ( + pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] + .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct12_ptct = v0.ptct_mul[v0.max_level] + ct12_ptct.set_plaintext(ct12_pt_ntt) + ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) + _ct12_data = ( + ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw + ) + _ct12_m_in = _ct12_data.shape[-1] + _ct12_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_m_in + ) + _ct12_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_r) + ) + _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) + if isinstance(_ct12_moduli, (int, np.integer)): + _ct12_moduli = [int(_ct12_moduli)] + ct12 = Polynomial( + { + "batch": _ct12_data.shape[0], + "num_elements": _ct12_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_m, + "precision": 32, + "degree_layout": (_ct12_r, _ct12_c), + }, + {"moduli": list(_ct12_moduli)[:_ct12_m]}, + ) + ct12.polynomial = _ct12_data.reshape( + _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in + )[..., :_ct12_m].copy() + ct12.batch = ct12.polynomial.shape[0] + ct12.num_elements = ct12.polynomial.shape[1] + ct12.num_moduli = _ct12_m + ct12.degree_layout = (_ct12_r, _ct12_c) + ct12.r = _ct12_r + ct12.c = _ct12_c + ct12.moduli = list(_ct12_moduli)[:_ct12_m] + ct12.moduli_array = jnp.array( + ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) + ) + _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct13_arg_m_in = _ct13_arg_data.shape[-1] + _ct13_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_arg_m_in + ) + _ct13_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_arg_r) + ) + _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct13_arg_moduli, (int, np.integer)): + _ct13_arg_moduli = [int(_ct13_arg_moduli)] + ct13_arg = Polynomial( + { + "batch": _ct13_arg_data.shape[0], + "num_elements": _ct13_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_arg_m, + "precision": 32, + "degree_layout": (_ct13_arg_r, _ct13_arg_c), + }, + {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, + ) + ct13_arg.polynomial = _ct13_arg_data.reshape( + _ct13_arg_data.shape[0], + _ct13_arg_data.shape[1], + _ct13_arg_r, + _ct13_arg_c, + _ct13_arg_m_in, + )[..., :_ct13_arg_m].copy() + ct13_arg.batch = ct13_arg.polynomial.shape[0] + ct13_arg.num_elements = ct13_arg.polynomial.shape[1] + ct13_arg.num_moduli = _ct13_arg_m + ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) + ct13_arg.r = _ct13_arg_r + ct13_arg.c = _ct13_arg_c + ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] + ct13_arg.moduli_array = jnp.array( + ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) + ) + ct13_pt_ntt = ( + pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] + .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct13_ptct = v0.ptct_mul[v0.max_level] + ct13_ptct.set_plaintext(ct13_pt_ntt) + ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) + _ct13_data = ( + ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw + ) + _ct13_m_in = _ct13_data.shape[-1] + _ct13_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_m_in + ) + _ct13_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_r) + ) + _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) + if isinstance(_ct13_moduli, (int, np.integer)): + _ct13_moduli = [int(_ct13_moduli)] + ct13 = Polynomial( + { + "batch": _ct13_data.shape[0], + "num_elements": _ct13_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_m, + "precision": 32, + "degree_layout": (_ct13_r, _ct13_c), + }, + {"moduli": list(_ct13_moduli)[:_ct13_m]}, + ) + ct13.polynomial = _ct13_data.reshape( + _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in + )[..., :_ct13_m].copy() + ct13.batch = ct13.polynomial.shape[0] + ct13.num_elements = ct13.polynomial.shape[1] + ct13.num_moduli = _ct13_m + ct13.degree_layout = (_ct13_r, _ct13_c) + ct13.r = _ct13_r + ct13.c = _ct13_c + ct13.moduli = list(_ct13_moduli)[:_ct13_m] + ct13.moduli_array = jnp.array( + ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) + ) + _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + _ct14_m_in = _ct14_data.shape[-1] + _ct14_m = _ct14_m_in + _ct14_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_r) + ) + _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) + if isinstance(_ct14_moduli, (int, np.integer)): + _ct14_moduli = [int(_ct14_moduli)] + ct14 = Polynomial( + { + "batch": _ct14_data.shape[0], + "num_elements": _ct14_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_m, + "precision": 32, + "degree_layout": (_ct14_r, _ct14_c), + }, + {"moduli": list(_ct14_moduli)[:_ct14_m]}, + ) + ct14.polynomial = _ct14_data.reshape( + _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in + )[..., :_ct14_m].copy() + ct14.batch = ct14.polynomial.shape[0] + ct14.num_elements = ct14.polynomial.shape[1] + ct14.num_moduli = _ct14_m + ct14.degree_layout = (_ct14_r, _ct14_c) + ct14.r = _ct14_r + ct14.c = _ct14_c + ct14.moduli = list(_ct14_moduli)[:_ct14_m] + ct14.moduli_array = jnp.array( + ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) + ) + _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] + _ct14_rhs_m = _ct14_rhs_m_in + _ct14_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_rhs_r) + ) + _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) + if isinstance(_ct14_rhs_moduli, (int, np.integer)): + _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] + ct14_rhs = Polynomial( + { + "batch": _ct14_rhs_data.shape[0], + "num_elements": _ct14_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_rhs_m, + "precision": 32, + "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), + }, + {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, + ) + ct14_rhs.polynomial = _ct14_rhs_data.reshape( + _ct14_rhs_data.shape[0], + _ct14_rhs_data.shape[1], + _ct14_rhs_r, + _ct14_rhs_c, + _ct14_rhs_m_in, + )[..., :_ct14_rhs_m].copy() + ct14_rhs.batch = ct14_rhs.polynomial.shape[0] + ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] + ct14_rhs.num_moduli = _ct14_rhs_m + ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) + ct14_rhs.r = _ct14_rhs_r + ct14_rhs.c = _ct14_rhs_c + ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] + ct14_rhs.moduli_array = jnp.array( + ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) + ) + ct14.add(ct14_rhs) + _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) + ct14.polynomial = jnp.where( + ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial + ) + _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 + _ct15_arg_m_in = _ct15_arg_data.shape[-1] + _ct15_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_arg_m_in + ) + _ct15_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_arg_r) + ) + _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) + if isinstance(_ct15_arg_moduli, (int, np.integer)): + _ct15_arg_moduli = [int(_ct15_arg_moduli)] + ct15_arg = Polynomial( + { + "batch": _ct15_arg_data.shape[0], + "num_elements": _ct15_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_arg_m, + "precision": 32, + "degree_layout": (_ct15_arg_r, _ct15_arg_c), + }, + {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, + ) + ct15_arg.polynomial = _ct15_arg_data.reshape( + _ct15_arg_data.shape[0], + _ct15_arg_data.shape[1], + _ct15_arg_r, + _ct15_arg_c, + _ct15_arg_m_in, + )[..., :_ct15_arg_m].copy() + ct15_arg.batch = ct15_arg.polynomial.shape[0] + ct15_arg.num_elements = ct15_arg.polynomial.shape[1] + ct15_arg.num_moduli = _ct15_arg_m + ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) + ct15_arg.r = _ct15_arg_r + ct15_arg.c = _ct15_arg_c + ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] + ct15_arg.moduli_array = jnp.array( + ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) + ) + ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) + _ct15_data = ( + ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw + ) + _ct15_m_in = _ct15_data.shape[-1] + _ct15_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_m_in + ) + _ct15_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_r) + ) + _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) + if isinstance(_ct15_moduli, (int, np.integer)): + _ct15_moduli = [int(_ct15_moduli)] + ct15 = Polynomial( + { + "batch": _ct15_data.shape[0], + "num_elements": _ct15_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_m, + "precision": 32, + "degree_layout": (_ct15_r, _ct15_c), + }, + {"moduli": list(_ct15_moduli)[:_ct15_m]}, + ) + ct15.polynomial = _ct15_data.reshape( + _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in + )[..., :_ct15_m].copy() + ct15.batch = ct15.polynomial.shape[0] + ct15.num_elements = ct15.polynomial.shape[1] + ct15.num_moduli = _ct15_m + ct15.degree_layout = (_ct15_r, _ct15_c) + ct15.r = _ct15_r + ct15.c = _ct15_c + ct15.moduli = list(_ct15_moduli)[:_ct15_m] + ct15.moduli_array = jnp.array( + ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) + ) + _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + _ct16_m_in = _ct16_data.shape[-1] + _ct16_m = _ct16_m_in + _ct16_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_r) + ) + _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) + if isinstance(_ct16_moduli, (int, np.integer)): + _ct16_moduli = [int(_ct16_moduli)] + ct16 = Polynomial( + { + "batch": _ct16_data.shape[0], + "num_elements": _ct16_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_m, + "precision": 32, + "degree_layout": (_ct16_r, _ct16_c), + }, + {"moduli": list(_ct16_moduli)[:_ct16_m]}, + ) + ct16.polynomial = _ct16_data.reshape( + _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in + )[..., :_ct16_m].copy() + ct16.batch = ct16.polynomial.shape[0] + ct16.num_elements = ct16.polynomial.shape[1] + ct16.num_moduli = _ct16_m + ct16.degree_layout = (_ct16_r, _ct16_c) + ct16.r = _ct16_r + ct16.c = _ct16_c + ct16.moduli = list(_ct16_moduli)[:_ct16_m] + ct16.moduli_array = jnp.array( + ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) + ) + _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 + _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] + _ct16_rhs_m = _ct16_rhs_m_in + _ct16_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_rhs_r) + ) + _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) + if isinstance(_ct16_rhs_moduli, (int, np.integer)): + _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] + ct16_rhs = Polynomial( + { + "batch": _ct16_rhs_data.shape[0], + "num_elements": _ct16_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_rhs_m, + "precision": 32, + "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), + }, + {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, + ) + ct16_rhs.polynomial = _ct16_rhs_data.reshape( + _ct16_rhs_data.shape[0], + _ct16_rhs_data.shape[1], + _ct16_rhs_r, + _ct16_rhs_c, + _ct16_rhs_m_in, + )[..., :_ct16_rhs_m].copy() + ct16_rhs.batch = ct16_rhs.polynomial.shape[0] + ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] + ct16_rhs.num_moduli = _ct16_rhs_m + ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) + ct16_rhs.r = _ct16_rhs_r + ct16_rhs.c = _ct16_rhs_c + ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] + ct16_rhs.moduli_array = jnp.array( + ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) + ) + ct16.add(ct16_rhs) + _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) + ct16.polynomial = jnp.where( + ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial + ) + _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + _ct17_m_in = _ct17_data.shape[-1] + _ct17_m = _ct17_m_in + _ct17_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_r) + ) + _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) + if isinstance(_ct17_moduli, (int, np.integer)): + _ct17_moduli = [int(_ct17_moduli)] + ct17 = Polynomial( + { + "batch": _ct17_data.shape[0], + "num_elements": _ct17_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_m, + "precision": 32, + "degree_layout": (_ct17_r, _ct17_c), + }, + {"moduli": list(_ct17_moduli)[:_ct17_m]}, + ) + ct17.polynomial = _ct17_data.reshape( + _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in + )[..., :_ct17_m].copy() + ct17.batch = ct17.polynomial.shape[0] + ct17.num_elements = ct17.polynomial.shape[1] + ct17.num_moduli = _ct17_m + ct17.degree_layout = (_ct17_r, _ct17_c) + ct17.r = _ct17_r + ct17.c = _ct17_c + ct17.moduli = list(_ct17_moduli)[:_ct17_m] + ct17.moduli_array = jnp.array( + ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) + ) + _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] + _ct17_rhs_m = _ct17_rhs_m_in + _ct17_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_rhs_r) + ) + _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) + if isinstance(_ct17_rhs_moduli, (int, np.integer)): + _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] + ct17_rhs = Polynomial( + { + "batch": _ct17_rhs_data.shape[0], + "num_elements": _ct17_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_rhs_m, + "precision": 32, + "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), + }, + {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, + ) + ct17_rhs.polynomial = _ct17_rhs_data.reshape( + _ct17_rhs_data.shape[0], + _ct17_rhs_data.shape[1], + _ct17_rhs_r, + _ct17_rhs_c, + _ct17_rhs_m_in, + )[..., :_ct17_rhs_m].copy() + ct17_rhs.batch = ct17_rhs.polynomial.shape[0] + ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] + ct17_rhs.num_moduli = _ct17_rhs_m + ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) + ct17_rhs.r = _ct17_rhs_r + ct17_rhs.c = _ct17_rhs_c + ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] + ct17_rhs.moduli_array = jnp.array( + ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) + ) + ct17.add(ct17_rhs) + _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) + ct17.polynomial = jnp.where( + ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial + ) + _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 + _ct18_m_in = _ct18_data.shape[-1] + _ct18_m = _ct18_m_in + _ct18_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_r) + ) + _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) + if isinstance(_ct18_moduli, (int, np.integer)): + _ct18_moduli = [int(_ct18_moduli)] + ct18 = Polynomial( + { + "batch": _ct18_data.shape[0], + "num_elements": _ct18_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_m, + "precision": 32, + "degree_layout": (_ct18_r, _ct18_c), + }, + {"moduli": list(_ct18_moduli)[:_ct18_m]}, + ) + ct18.polynomial = _ct18_data.reshape( + _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in + )[..., :_ct18_m].copy() + ct18.batch = ct18.polynomial.shape[0] + ct18.num_elements = ct18.polynomial.shape[1] + ct18.num_moduli = _ct18_m + ct18.degree_layout = (_ct18_r, _ct18_c) + ct18.r = _ct18_r + ct18.c = _ct18_c + ct18.moduli = list(_ct18_moduli)[:_ct18_m] + ct18.moduli_array = jnp.array( + ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) + ) + _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 + _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] + _ct18_rhs_m = _ct18_rhs_m_in + _ct18_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_rhs_r) + ) + _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) + if isinstance(_ct18_rhs_moduli, (int, np.integer)): + _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] + ct18_rhs = Polynomial( + { + "batch": _ct18_rhs_data.shape[0], + "num_elements": _ct18_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_rhs_m, + "precision": 32, + "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), + }, + {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, + ) + ct18_rhs.polynomial = _ct18_rhs_data.reshape( + _ct18_rhs_data.shape[0], + _ct18_rhs_data.shape[1], + _ct18_rhs_r, + _ct18_rhs_c, + _ct18_rhs_m_in, + )[..., :_ct18_rhs_m].copy() + ct18_rhs.batch = ct18_rhs.polynomial.shape[0] + ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] + ct18_rhs.num_moduli = _ct18_rhs_m + ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) + ct18_rhs.r = _ct18_rhs_r + ct18_rhs.c = _ct18_rhs_c + ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] + ct18_rhs.moduli_array = jnp.array( + ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) + ) + ct18.add(ct18_rhs) + _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) + ct18.polynomial = jnp.where( + ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial + ) + _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 + _ct19_m_in = _ct19_data.shape[-1] + _ct19_m = _ct19_m_in + _ct19_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_r) + ) + _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) + if isinstance(_ct19_moduli, (int, np.integer)): + _ct19_moduli = [int(_ct19_moduli)] + ct19 = Polynomial( + { + "batch": _ct19_data.shape[0], + "num_elements": _ct19_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_m, + "precision": 32, + "degree_layout": (_ct19_r, _ct19_c), + }, + {"moduli": list(_ct19_moduli)[:_ct19_m]}, + ) + ct19.polynomial = _ct19_data.reshape( + _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in + )[..., :_ct19_m].copy() + ct19.batch = ct19.polynomial.shape[0] + ct19.num_elements = ct19.polynomial.shape[1] + ct19.num_moduli = _ct19_m + ct19.degree_layout = (_ct19_r, _ct19_c) + ct19.r = _ct19_r + ct19.c = _ct19_c + ct19.moduli = list(_ct19_moduli)[:_ct19_m] + ct19.moduli_array = jnp.array( + ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) + ) + _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 + _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] + _ct19_rhs_m = _ct19_rhs_m_in + _ct19_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_rhs_r) + ) + _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) + if isinstance(_ct19_rhs_moduli, (int, np.integer)): + _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] + ct19_rhs = Polynomial( + { + "batch": _ct19_rhs_data.shape[0], + "num_elements": _ct19_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_rhs_m, + "precision": 32, + "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), + }, + {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, + ) + ct19_rhs.polynomial = _ct19_rhs_data.reshape( + _ct19_rhs_data.shape[0], + _ct19_rhs_data.shape[1], + _ct19_rhs_r, + _ct19_rhs_c, + _ct19_rhs_m_in, + )[..., :_ct19_rhs_m].copy() + ct19_rhs.batch = ct19_rhs.polynomial.shape[0] + ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] + ct19_rhs.num_moduli = _ct19_rhs_m + ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) + ct19_rhs.r = _ct19_rhs_r + ct19_rhs.c = _ct19_rhs_c + ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] + ct19_rhs.moduli_array = jnp.array( + ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) + ) + ct19.add(ct19_rhs) + _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) + ct19.polynomial = jnp.where( + ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial + ) + v16 = [None] * 1 + _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 + _ct20_arg_m_in = _ct20_arg_data.shape[-1] + _ct20_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct20_arg_m_in + ) + _ct20_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_arg_r) + ) + _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) + if isinstance(_ct20_arg_moduli, (int, np.integer)): + _ct20_arg_moduli = [int(_ct20_arg_moduli)] + ct20_arg = Polynomial( + { + "batch": _ct20_arg_data.shape[0], + "num_elements": _ct20_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_arg_m, + "precision": 32, + "degree_layout": (_ct20_arg_r, _ct20_arg_c), + }, + {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, + ) + ct20_arg.polynomial = _ct20_arg_data.reshape( + _ct20_arg_data.shape[0], + _ct20_arg_data.shape[1], + _ct20_arg_r, + _ct20_arg_c, + _ct20_arg_m_in, + )[..., :_ct20_arg_m].copy() + ct20_arg.batch = ct20_arg.polynomial.shape[0] + ct20_arg.num_elements = ct20_arg.polynomial.shape[1] + ct20_arg.num_moduli = _ct20_arg_m + ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) + ct20_arg.r = _ct20_arg_r + ct20_arg.c = _ct20_arg_c + ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] + ct20_arg.moduli_array = jnp.array( + ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) + ) + ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) + _ct20_data = ( + ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw + ) + _ct20_m_in = _ct20_data.shape[-1] + _ct20_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct20_m_in + ) + _ct20_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_r) + ) + _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) + if isinstance(_ct20_moduli, (int, np.integer)): + _ct20_moduli = [int(_ct20_moduli)] + ct20 = Polynomial( + { + "batch": _ct20_data.shape[0], + "num_elements": _ct20_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_m, + "precision": 32, + "degree_layout": (_ct20_r, _ct20_c), + }, + {"moduli": list(_ct20_moduli)[:_ct20_m]}, + ) + ct20.polynomial = _ct20_data.reshape( + _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in + )[..., :_ct20_m].copy() + ct20.batch = ct20.polynomial.shape[0] + ct20.num_elements = ct20.polynomial.shape[1] + ct20.num_moduli = _ct20_m + ct20.degree_layout = (_ct20_r, _ct20_c) + ct20.r = _ct20_r + ct20.c = _ct20_c + ct20.moduli = list(_ct20_moduli)[:_ct20_m] + ct20.moduli_array = jnp.array( + ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) + ) + v16[0] = ct20 + v17 = v16 + return v17 + + +def matvec_shift( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_shift__preprocessing(v0, v1) + v11 = matvec_shift__preprocessed(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) + return v11 + + +def matvec_shift__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw + _ct_m_in = _ct_data.shape[-1] + _ct_m = _ct_m_in + _ct_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct_r) + ) + _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) + if isinstance(_ct_moduli, (int, np.integer)): + _ct_moduli = [int(_ct_moduli)] + ct = Polynomial( + { + "batch": _ct_data.shape[0], + "num_elements": _ct_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct_m, + "precision": 32, + "degree_layout": (_ct_r, _ct_c), + }, + {"moduli": list(_ct_moduli)[:_ct_m]}, + ) + ct.polynomial = _ct_data.reshape( + _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in + )[..., :_ct_m].copy() + ct.batch = ct.polynomial.shape[0] + ct.num_elements = ct.polynomial.shape[1] + ct.num_moduli = _ct_m + ct.degree_layout = (_ct_r, _ct_c) + ct.r = _ct_r + ct.c = _ct_c + ct.moduli = list(_ct_moduli)[:_ct_m] + ct.moduli_array = jnp.array( + ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) + ) + v16 = [ct] + return v16 + + +def matvec_shift__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 0 + v8 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + _num_moduli = ct.polynomial.shape[-1] + _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": ct.polynomial.shape[0], + "num_elements": ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + ct.polynomial.reshape( + ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli + ) + ) + pt = v0.decrypt(_ct_for_dec) + v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v10 = v8.copy() + for v11 in range(0, 8): + v13 = int(v11) + v14 = v9[0, v13] + v10[v13] = v14 + return v10 + + +def matvec_random__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> ( + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +): + v2 = np.array( + [ + 8.116263e-01, + 1.906357e00, + 1.490788e00, + 1.237451e00, + 3.964354e-01, + 3.963896e-01, + 2.103589e-01, + 1.745735e00, + 1.242118e00, + 1.445338e00, + 1.391105e-01, + 1.942829e00, + 1.681641e00, + 5.034443e-01, + 4.454674e-01, + 4.484686e-01, + 6.780602e-01, + 1.097037e00, + 9.206955e-01, + 6.533354e-01, + 1.262521e00, + 3.650383e-01, + 6.550748e-01, + 7.960875e-01, + 9.665329e-01, + 1.591834e00, + 4.793802e-01, + 1.077045e00, + 1.225588e00, + 1.882558e-01, + 1.254335e00, + 4.239958e-01, + 2.235980e-01, + 1.902883e00, + 1.934701e00, + 1.635955e00, + 6.787661e-01, + 2.855770e-01, + 1.400043e00, + 9.362897e-01, + 3.318726e-01, + 1.040836e00, + 1.653382e-01, + 1.827709e00, + 5.916820e-01, + 1.358792e00, + 6.922510e-01, + 1.088129e00, + 1.138749e00, + 4.512235e-01, + 1.942211e00, + 1.572752e00, + 1.885048e00, + 1.800172e00, + 1.236010e00, + 1.851561e00, + 2.681358e-01, + 4.723674e-01, + 1.859318e-01, + 7.181276e-01, + 8.384869e-01, + 6.155632e-01, + 1.674601e00, + 7.778313e-01, + ], + dtype=np.float32, + ).reshape(8, 8) + v3 = _assign_layout_15335824159471298539(v2) + v4 = v3[3 : 3 + 1, 0 : 0 + 5] + v5 = v3[3 : 3 + 1, 5 : 5 + 3] + v6 = np.zeros( + ( + 1, + 8, + ), + dtype=np.float32, + ) + v7 = v6.copy() + v7[0 : 0 + 1, 3 : 3 + 5] = v4 + v8 = v7.copy() + v8[0 : 0 + 1, 0 : 0 + 3] = v5 + v9 = v3[4 : 4 + 1, 0 : 0 + 5] + v10 = v3[4 : 4 + 1, 5 : 5 + 3] + v11 = v6.copy() + v11[0 : 0 + 1, 3 : 3 + 5] = v9 + v12 = v11.copy() + v12[0 : 0 + 1, 0 : 0 + 3] = v10 + v13 = v3[5 : 5 + 1, 0 : 0 + 5] + v14 = v3[5 : 5 + 1, 5 : 5 + 3] + v15 = v6.copy() + v15[0 : 0 + 1, 3 : 3 + 5] = v13 + v16 = v15.copy() + v16[0 : 0 + 1, 0 : 0 + 3] = v14 + v17 = v3[6 : 6 + 1, 0 : 0 + 2] + v18 = v3[6 : 6 + 1, 2 : 2 + 6] + v19 = v6.copy() + v19[0 : 0 + 1, 6 : 6 + 2] = v17 + v20 = v19.copy() + v20[0 : 0 + 1, 0 : 0 + 6] = v18 + v21 = v3[7 : 7 + 1, 0 : 0 + 2] + v22 = v3[7 : 7 + 1, 2 : 2 + 6] + v23 = v6.copy() + v23[0 : 0 + 1, 6 : 6 + 2] = v21 + v24 = v23.copy() + v24[0 : 0 + 1, 0 : 0 + 6] = v22 + v25 = v3[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v25) + v26 = v3[1 : 1 + 1, 0 : 0 + 8].reshape(8) + pt1 = v0.encode(v26) + v27 = v3[2 : 2 + 1, 0 : 0 + 8].reshape(8) + pt2 = v0.encode(v27) + v28 = v8[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt3 = v0.encode(v28) + v29 = v12[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt4 = v0.encode(v29) + v30 = v16[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt5 = v0.encode(v30) + v31 = v20[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt6 = v0.encode(v31) + v32 = v24[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt7 = v0.encode(v32) + v33 = [pt] + v34 = [pt1] + v35 = [pt2] + v36 = [pt3] + v37 = [pt4] + v38 = [pt5] + v39 = [pt6] + v40 = [pt7] + return (v33, v34, v35, v36, v37, v38, v39, v40) + + +def matvec_random__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, + v5: np.ndarray, + v6: np.ndarray, + v7: np.ndarray, + v8: np.ndarray, + v9: np.ndarray, + v10: np.ndarray, +) -> np.ndarray: + v11 = 1 + v12 = 2 + v13 = 3 + v14 = 6 + v15 = 0 + pt = v3[0] + pt1 = v4[0] + pt2 = v5[0] + pt3 = v6[0] + pt4 = v7[0] + pt5 = v8[0] + pt6 = v9[0] + pt7 = v10[0] + ct = v2[0] + _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct1_arg_m_in = _ct1_arg_data.shape[-1] + _ct1_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_arg_m_in + ) + _ct1_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_arg_r) + ) + _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct1_arg_moduli, (int, np.integer)): + _ct1_arg_moduli = [int(_ct1_arg_moduli)] + ct1_arg = Polynomial( + { + "batch": _ct1_arg_data.shape[0], + "num_elements": _ct1_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_arg_m, + "precision": 32, + "degree_layout": (_ct1_arg_r, _ct1_arg_c), + }, + {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, + ) + ct1_arg.polynomial = _ct1_arg_data.reshape( + _ct1_arg_data.shape[0], + _ct1_arg_data.shape[1], + _ct1_arg_r, + _ct1_arg_c, + _ct1_arg_m_in, + )[..., :_ct1_arg_m].copy() + ct1_arg.batch = ct1_arg.polynomial.shape[0] + ct1_arg.num_elements = ct1_arg.polynomial.shape[1] + ct1_arg.num_moduli = _ct1_arg_m + ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) + ct1_arg.r = _ct1_arg_r + ct1_arg.c = _ct1_arg_c + ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] + ct1_arg.moduli_array = jnp.array( + ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) + ) + ct1_pt_ntt = ( + pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] + .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct1_ptct = v0.ptct_mul[v0.max_level] + ct1_ptct.set_plaintext(ct1_pt_ntt) + ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) + _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw + _ct1_m_in = _ct1_data.shape[-1] + _ct1_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_m_in + ) + _ct1_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_r) + ) + _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) + if isinstance(_ct1_moduli, (int, np.integer)): + _ct1_moduli = [int(_ct1_moduli)] + ct1 = Polynomial( + { + "batch": _ct1_data.shape[0], + "num_elements": _ct1_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_m, + "precision": 32, + "degree_layout": (_ct1_r, _ct1_c), + }, + {"moduli": list(_ct1_moduli)[:_ct1_m]}, + ) + ct1.polynomial = _ct1_data.reshape( + _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in + )[..., :_ct1_m].copy() + ct1.batch = ct1.polynomial.shape[0] + ct1.num_elements = ct1.polynomial.shape[1] + ct1.num_moduli = _ct1_m + ct1.degree_layout = (_ct1_r, _ct1_c) + ct1.r = _ct1_r + ct1.c = _ct1_c + ct1.moduli = list(_ct1_moduli)[:_ct1_m] + ct1.moduli_array = jnp.array( + ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) + ) + _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct2_arg_m_in = _ct2_arg_data.shape[-1] + _ct2_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_arg_m_in + ) + _ct2_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_arg_r) + ) + _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct2_arg_moduli, (int, np.integer)): + _ct2_arg_moduli = [int(_ct2_arg_moduli)] + ct2_arg = Polynomial( + { + "batch": _ct2_arg_data.shape[0], + "num_elements": _ct2_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_arg_m, + "precision": 32, + "degree_layout": (_ct2_arg_r, _ct2_arg_c), + }, + {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, + ) + ct2_arg.polynomial = _ct2_arg_data.reshape( + _ct2_arg_data.shape[0], + _ct2_arg_data.shape[1], + _ct2_arg_r, + _ct2_arg_c, + _ct2_arg_m_in, + )[..., :_ct2_arg_m].copy() + ct2_arg.batch = ct2_arg.polynomial.shape[0] + ct2_arg.num_elements = ct2_arg.polynomial.shape[1] + ct2_arg.num_moduli = _ct2_arg_m + ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) + ct2_arg.r = _ct2_arg_r + ct2_arg.c = _ct2_arg_c + ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] + ct2_arg.moduli_array = jnp.array( + ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) + ) + ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) + _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw + _ct2_m_in = _ct2_data.shape[-1] + _ct2_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_m_in + ) + _ct2_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_r) + ) + _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) + if isinstance(_ct2_moduli, (int, np.integer)): + _ct2_moduli = [int(_ct2_moduli)] + ct2 = Polynomial( + { + "batch": _ct2_data.shape[0], + "num_elements": _ct2_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_m, + "precision": 32, + "degree_layout": (_ct2_r, _ct2_c), + }, + {"moduli": list(_ct2_moduli)[:_ct2_m]}, + ) + ct2.polynomial = _ct2_data.reshape( + _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in + )[..., :_ct2_m].copy() + ct2.batch = ct2.polynomial.shape[0] + ct2.num_elements = ct2.polynomial.shape[1] + ct2.num_moduli = _ct2_m + ct2.degree_layout = (_ct2_r, _ct2_c) + ct2.r = _ct2_r + ct2.c = _ct2_c + ct2.moduli = list(_ct2_moduli)[:_ct2_m] + ct2.moduli_array = jnp.array( + ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) + ) + _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct3_arg_m_in = _ct3_arg_data.shape[-1] + _ct3_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_arg_m_in + ) + _ct3_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_arg_r) + ) + _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct3_arg_moduli, (int, np.integer)): + _ct3_arg_moduli = [int(_ct3_arg_moduli)] + ct3_arg = Polynomial( + { + "batch": _ct3_arg_data.shape[0], + "num_elements": _ct3_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_arg_m, + "precision": 32, + "degree_layout": (_ct3_arg_r, _ct3_arg_c), + }, + {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, + ) + ct3_arg.polynomial = _ct3_arg_data.reshape( + _ct3_arg_data.shape[0], + _ct3_arg_data.shape[1], + _ct3_arg_r, + _ct3_arg_c, + _ct3_arg_m_in, + )[..., :_ct3_arg_m].copy() + ct3_arg.batch = ct3_arg.polynomial.shape[0] + ct3_arg.num_elements = ct3_arg.polynomial.shape[1] + ct3_arg.num_moduli = _ct3_arg_m + ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) + ct3_arg.r = _ct3_arg_r + ct3_arg.c = _ct3_arg_c + ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] + ct3_arg.moduli_array = jnp.array( + ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) + ) + ct3_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] + .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct3_ptct = v0.ptct_mul[v0.max_level] + ct3_ptct.set_plaintext(ct3_pt_ntt) + ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) + _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw + _ct3_m_in = _ct3_data.shape[-1] + _ct3_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_m_in + ) + _ct3_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_r) + ) + _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) + if isinstance(_ct3_moduli, (int, np.integer)): + _ct3_moduli = [int(_ct3_moduli)] + ct3 = Polynomial( + { + "batch": _ct3_data.shape[0], + "num_elements": _ct3_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_m, + "precision": 32, + "degree_layout": (_ct3_r, _ct3_c), + }, + {"moduli": list(_ct3_moduli)[:_ct3_m]}, + ) + ct3.polynomial = _ct3_data.reshape( + _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in + )[..., :_ct3_m].copy() + ct3.batch = ct3.polynomial.shape[0] + ct3.num_elements = ct3.polynomial.shape[1] + ct3.num_moduli = _ct3_m + ct3.degree_layout = (_ct3_r, _ct3_c) + ct3.r = _ct3_r + ct3.c = _ct3_c + ct3.moduli = list(_ct3_moduli)[:_ct3_m] + ct3.moduli_array = jnp.array( + ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) + ) + _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct4_arg_m_in = _ct4_arg_data.shape[-1] + _ct4_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_arg_m_in + ) + _ct4_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_arg_r) + ) + _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct4_arg_moduli, (int, np.integer)): + _ct4_arg_moduli = [int(_ct4_arg_moduli)] + ct4_arg = Polynomial( + { + "batch": _ct4_arg_data.shape[0], + "num_elements": _ct4_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_arg_m, + "precision": 32, + "degree_layout": (_ct4_arg_r, _ct4_arg_c), + }, + {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, + ) + ct4_arg.polynomial = _ct4_arg_data.reshape( + _ct4_arg_data.shape[0], + _ct4_arg_data.shape[1], + _ct4_arg_r, + _ct4_arg_c, + _ct4_arg_m_in, + )[..., :_ct4_arg_m].copy() + ct4_arg.batch = ct4_arg.polynomial.shape[0] + ct4_arg.num_elements = ct4_arg.polynomial.shape[1] + ct4_arg.num_moduli = _ct4_arg_m + ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) + ct4_arg.r = _ct4_arg_r + ct4_arg.c = _ct4_arg_c + ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] + ct4_arg.moduli_array = jnp.array( + ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) + ) + ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) + _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw + _ct4_m_in = _ct4_data.shape[-1] + _ct4_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_m_in + ) + _ct4_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_r) + ) + _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) + if isinstance(_ct4_moduli, (int, np.integer)): + _ct4_moduli = [int(_ct4_moduli)] + ct4 = Polynomial( + { + "batch": _ct4_data.shape[0], + "num_elements": _ct4_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_m, + "precision": 32, + "degree_layout": (_ct4_r, _ct4_c), + }, + {"moduli": list(_ct4_moduli)[:_ct4_m]}, + ) + ct4.polynomial = _ct4_data.reshape( + _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in + )[..., :_ct4_m].copy() + ct4.batch = ct4.polynomial.shape[0] + ct4.num_elements = ct4.polynomial.shape[1] + ct4.num_moduli = _ct4_m + ct4.degree_layout = (_ct4_r, _ct4_c) + ct4.r = _ct4_r + ct4.c = _ct4_c + ct4.moduli = list(_ct4_moduli)[:_ct4_m] + ct4.moduli_array = jnp.array( + ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) + ) + _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct5_arg_m_in = _ct5_arg_data.shape[-1] + _ct5_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_arg_m_in + ) + _ct5_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_arg_r) + ) + _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct5_arg_moduli, (int, np.integer)): + _ct5_arg_moduli = [int(_ct5_arg_moduli)] + ct5_arg = Polynomial( + { + "batch": _ct5_arg_data.shape[0], + "num_elements": _ct5_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_arg_m, + "precision": 32, + "degree_layout": (_ct5_arg_r, _ct5_arg_c), + }, + {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, + ) + ct5_arg.polynomial = _ct5_arg_data.reshape( + _ct5_arg_data.shape[0], + _ct5_arg_data.shape[1], + _ct5_arg_r, + _ct5_arg_c, + _ct5_arg_m_in, + )[..., :_ct5_arg_m].copy() + ct5_arg.batch = ct5_arg.polynomial.shape[0] + ct5_arg.num_elements = ct5_arg.polynomial.shape[1] + ct5_arg.num_moduli = _ct5_arg_m + ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) + ct5_arg.r = _ct5_arg_r + ct5_arg.c = _ct5_arg_c + ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] + ct5_arg.moduli_array = jnp.array( + ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) + ) + ct5_pt_ntt = ( + pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) + _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw + _ct5_m_in = _ct5_data.shape[-1] + _ct5_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_m_in + ) + _ct5_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_r) + ) + _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) + if isinstance(_ct5_moduli, (int, np.integer)): + _ct5_moduli = [int(_ct5_moduli)] + ct5 = Polynomial( + { + "batch": _ct5_data.shape[0], + "num_elements": _ct5_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_m, + "precision": 32, + "degree_layout": (_ct5_r, _ct5_c), + }, + {"moduli": list(_ct5_moduli)[:_ct5_m]}, + ) + ct5.polynomial = _ct5_data.reshape( + _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in + )[..., :_ct5_m].copy() + ct5.batch = ct5.polynomial.shape[0] + ct5.num_elements = ct5.polynomial.shape[1] + ct5.num_moduli = _ct5_m + ct5.degree_layout = (_ct5_r, _ct5_c) + ct5.r = _ct5_r + ct5.c = _ct5_c + ct5.moduli = list(_ct5_moduli)[:_ct5_m] + ct5.moduli_array = jnp.array( + ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) + ) + _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct6_arg_m_in = _ct6_arg_data.shape[-1] + _ct6_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_arg_m_in + ) + _ct6_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_arg_r) + ) + _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct6_arg_moduli, (int, np.integer)): + _ct6_arg_moduli = [int(_ct6_arg_moduli)] + ct6_arg = Polynomial( + { + "batch": _ct6_arg_data.shape[0], + "num_elements": _ct6_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_arg_m, + "precision": 32, + "degree_layout": (_ct6_arg_r, _ct6_arg_c), + }, + {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, + ) + ct6_arg.polynomial = _ct6_arg_data.reshape( + _ct6_arg_data.shape[0], + _ct6_arg_data.shape[1], + _ct6_arg_r, + _ct6_arg_c, + _ct6_arg_m_in, + )[..., :_ct6_arg_m].copy() + ct6_arg.batch = ct6_arg.polynomial.shape[0] + ct6_arg.num_elements = ct6_arg.polynomial.shape[1] + ct6_arg.num_moduli = _ct6_arg_m + ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) + ct6_arg.r = _ct6_arg_r + ct6_arg.c = _ct6_arg_c + ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] + ct6_arg.moduli_array = jnp.array( + ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) + ) + ct6_pt_ntt = ( + pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] + .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct6_ptct = v0.ptct_mul[v0.max_level] + ct6_ptct.set_plaintext(ct6_pt_ntt) + ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) + _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw + _ct6_m_in = _ct6_data.shape[-1] + _ct6_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_m_in + ) + _ct6_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_r) + ) + _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) + if isinstance(_ct6_moduli, (int, np.integer)): + _ct6_moduli = [int(_ct6_moduli)] + ct6 = Polynomial( + { + "batch": _ct6_data.shape[0], + "num_elements": _ct6_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_m, + "precision": 32, + "degree_layout": (_ct6_r, _ct6_c), + }, + {"moduli": list(_ct6_moduli)[:_ct6_m]}, + ) + ct6.polynomial = _ct6_data.reshape( + _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in + )[..., :_ct6_m].copy() + ct6.batch = ct6.polynomial.shape[0] + ct6.num_elements = ct6.polynomial.shape[1] + ct6.num_moduli = _ct6_m + ct6.degree_layout = (_ct6_r, _ct6_c) + ct6.r = _ct6_r + ct6.c = _ct6_c + ct6.moduli = list(_ct6_moduli)[:_ct6_m] + ct6.moduli_array = jnp.array( + ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) + ) + _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct7_arg_m_in = _ct7_arg_data.shape[-1] + _ct7_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_arg_m_in + ) + _ct7_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_arg_r) + ) + _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct7_arg_moduli, (int, np.integer)): + _ct7_arg_moduli = [int(_ct7_arg_moduli)] + ct7_arg = Polynomial( + { + "batch": _ct7_arg_data.shape[0], + "num_elements": _ct7_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_arg_m, + "precision": 32, + "degree_layout": (_ct7_arg_r, _ct7_arg_c), + }, + {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, + ) + ct7_arg.polynomial = _ct7_arg_data.reshape( + _ct7_arg_data.shape[0], + _ct7_arg_data.shape[1], + _ct7_arg_r, + _ct7_arg_c, + _ct7_arg_m_in, + )[..., :_ct7_arg_m].copy() + ct7_arg.batch = ct7_arg.polynomial.shape[0] + ct7_arg.num_elements = ct7_arg.polynomial.shape[1] + ct7_arg.num_moduli = _ct7_arg_m + ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) + ct7_arg.r = _ct7_arg_r + ct7_arg.c = _ct7_arg_c + ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] + ct7_arg.moduli_array = jnp.array( + ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) + ) + ct7_pt_ntt = ( + pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] + .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct7_ptct = v0.ptct_mul[v0.max_level] + ct7_ptct.set_plaintext(ct7_pt_ntt) + ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) + _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw + _ct7_m_in = _ct7_data.shape[-1] + _ct7_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_m_in + ) + _ct7_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_r) + ) + _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) + if isinstance(_ct7_moduli, (int, np.integer)): + _ct7_moduli = [int(_ct7_moduli)] + ct7 = Polynomial( + { + "batch": _ct7_data.shape[0], + "num_elements": _ct7_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_m, + "precision": 32, + "degree_layout": (_ct7_r, _ct7_c), + }, + {"moduli": list(_ct7_moduli)[:_ct7_m]}, + ) + ct7.polynomial = _ct7_data.reshape( + _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in + )[..., :_ct7_m].copy() + ct7.batch = ct7.polynomial.shape[0] + ct7.num_elements = ct7.polynomial.shape[1] + ct7.num_moduli = _ct7_m + ct7.degree_layout = (_ct7_r, _ct7_c) + ct7.r = _ct7_r + ct7.c = _ct7_c + ct7.moduli = list(_ct7_moduli)[:_ct7_m] + ct7.moduli_array = jnp.array( + ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) + ) + _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct8_arg_m_in = _ct8_arg_data.shape[-1] + _ct8_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_arg_m_in + ) + _ct8_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_arg_r) + ) + _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct8_arg_moduli, (int, np.integer)): + _ct8_arg_moduli = [int(_ct8_arg_moduli)] + ct8_arg = Polynomial( + { + "batch": _ct8_arg_data.shape[0], + "num_elements": _ct8_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_arg_m, + "precision": 32, + "degree_layout": (_ct8_arg_r, _ct8_arg_c), + }, + {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, + ) + ct8_arg.polynomial = _ct8_arg_data.reshape( + _ct8_arg_data.shape[0], + _ct8_arg_data.shape[1], + _ct8_arg_r, + _ct8_arg_c, + _ct8_arg_m_in, + )[..., :_ct8_arg_m].copy() + ct8_arg.batch = ct8_arg.polynomial.shape[0] + ct8_arg.num_elements = ct8_arg.polynomial.shape[1] + ct8_arg.num_moduli = _ct8_arg_m + ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) + ct8_arg.r = _ct8_arg_r + ct8_arg.c = _ct8_arg_c + ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] + ct8_arg.moduli_array = jnp.array( + ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) + ) + ct8_pt_ntt = ( + pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] + .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct8_ptct = v0.ptct_mul[v0.max_level] + ct8_ptct.set_plaintext(ct8_pt_ntt) + ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) + _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw + _ct8_m_in = _ct8_data.shape[-1] + _ct8_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_m_in + ) + _ct8_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_r) + ) + _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) + if isinstance(_ct8_moduli, (int, np.integer)): + _ct8_moduli = [int(_ct8_moduli)] + ct8 = Polynomial( + { + "batch": _ct8_data.shape[0], + "num_elements": _ct8_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_m, + "precision": 32, + "degree_layout": (_ct8_r, _ct8_c), + }, + {"moduli": list(_ct8_moduli)[:_ct8_m]}, + ) + ct8.polynomial = _ct8_data.reshape( + _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in + )[..., :_ct8_m].copy() + ct8.batch = ct8.polynomial.shape[0] + ct8.num_elements = ct8.polynomial.shape[1] + ct8.num_moduli = _ct8_m + ct8.degree_layout = (_ct8_r, _ct8_c) + ct8.r = _ct8_r + ct8.c = _ct8_c + ct8.moduli = list(_ct8_moduli)[:_ct8_m] + ct8.moduli_array = jnp.array( + ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) + ) + _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + _ct9_m_in = _ct9_data.shape[-1] + _ct9_m = _ct9_m_in + _ct9_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_r) + ) + _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) + if isinstance(_ct9_moduli, (int, np.integer)): + _ct9_moduli = [int(_ct9_moduli)] + ct9 = Polynomial( + { + "batch": _ct9_data.shape[0], + "num_elements": _ct9_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_m, + "precision": 32, + "degree_layout": (_ct9_r, _ct9_c), + }, + {"moduli": list(_ct9_moduli)[:_ct9_m]}, + ) + ct9.polynomial = _ct9_data.reshape( + _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in + )[..., :_ct9_m].copy() + ct9.batch = ct9.polynomial.shape[0] + ct9.num_elements = ct9.polynomial.shape[1] + ct9.num_moduli = _ct9_m + ct9.degree_layout = (_ct9_r, _ct9_c) + ct9.r = _ct9_r + ct9.c = _ct9_c + ct9.moduli = list(_ct9_moduli)[:_ct9_m] + ct9.moduli_array = jnp.array( + ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) + ) + _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 + _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] + _ct9_rhs_m = _ct9_rhs_m_in + _ct9_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_rhs_r) + ) + _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) + if isinstance(_ct9_rhs_moduli, (int, np.integer)): + _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] + ct9_rhs = Polynomial( + { + "batch": _ct9_rhs_data.shape[0], + "num_elements": _ct9_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_rhs_m, + "precision": 32, + "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), + }, + {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, + ) + ct9_rhs.polynomial = _ct9_rhs_data.reshape( + _ct9_rhs_data.shape[0], + _ct9_rhs_data.shape[1], + _ct9_rhs_r, + _ct9_rhs_c, + _ct9_rhs_m_in, + )[..., :_ct9_rhs_m].copy() + ct9_rhs.batch = ct9_rhs.polynomial.shape[0] + ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] + ct9_rhs.num_moduli = _ct9_rhs_m + ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) + ct9_rhs.r = _ct9_rhs_r + ct9_rhs.c = _ct9_rhs_c + ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] + ct9_rhs.moduli_array = jnp.array( + ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) + ) + ct9.add(ct9_rhs) + _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) + ct9.polynomial = jnp.where( + ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial + ) + _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + _ct10_m_in = _ct10_data.shape[-1] + _ct10_m = _ct10_m_in + _ct10_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_r) + ) + _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) + if isinstance(_ct10_moduli, (int, np.integer)): + _ct10_moduli = [int(_ct10_moduli)] + ct10 = Polynomial( + { + "batch": _ct10_data.shape[0], + "num_elements": _ct10_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_m, + "precision": 32, + "degree_layout": (_ct10_r, _ct10_c), + }, + {"moduli": list(_ct10_moduli)[:_ct10_m]}, + ) + ct10.polynomial = _ct10_data.reshape( + _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in + )[..., :_ct10_m].copy() + ct10.batch = ct10.polynomial.shape[0] + ct10.num_elements = ct10.polynomial.shape[1] + ct10.num_moduli = _ct10_m + ct10.degree_layout = (_ct10_r, _ct10_c) + ct10.r = _ct10_r + ct10.c = _ct10_c + ct10.moduli = list(_ct10_moduli)[:_ct10_m] + ct10.moduli_array = jnp.array( + ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) + ) + _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] + _ct10_rhs_m = _ct10_rhs_m_in + _ct10_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_rhs_r) + ) + _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) + if isinstance(_ct10_rhs_moduli, (int, np.integer)): + _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] + ct10_rhs = Polynomial( + { + "batch": _ct10_rhs_data.shape[0], + "num_elements": _ct10_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_rhs_m, + "precision": 32, + "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), + }, + {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, + ) + ct10_rhs.polynomial = _ct10_rhs_data.reshape( + _ct10_rhs_data.shape[0], + _ct10_rhs_data.shape[1], + _ct10_rhs_r, + _ct10_rhs_c, + _ct10_rhs_m_in, + )[..., :_ct10_rhs_m].copy() + ct10_rhs.batch = ct10_rhs.polynomial.shape[0] + ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] + ct10_rhs.num_moduli = _ct10_rhs_m + ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) + ct10_rhs.r = _ct10_rhs_r + ct10_rhs.c = _ct10_rhs_c + ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] + ct10_rhs.moduli_array = jnp.array( + ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) + ) + ct10.add(ct10_rhs) + _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) + ct10.polynomial = jnp.where( + ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial + ) + _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 + _ct11_arg_m_in = _ct11_arg_data.shape[-1] + _ct11_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_arg_m_in + ) + _ct11_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_arg_r) + ) + _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) + if isinstance(_ct11_arg_moduli, (int, np.integer)): + _ct11_arg_moduli = [int(_ct11_arg_moduli)] + ct11_arg = Polynomial( + { + "batch": _ct11_arg_data.shape[0], + "num_elements": _ct11_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_arg_m, + "precision": 32, + "degree_layout": (_ct11_arg_r, _ct11_arg_c), + }, + {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, + ) + ct11_arg.polynomial = _ct11_arg_data.reshape( + _ct11_arg_data.shape[0], + _ct11_arg_data.shape[1], + _ct11_arg_r, + _ct11_arg_c, + _ct11_arg_m_in, + )[..., :_ct11_arg_m].copy() + ct11_arg.batch = ct11_arg.polynomial.shape[0] + ct11_arg.num_elements = ct11_arg.polynomial.shape[1] + ct11_arg.num_moduli = _ct11_arg_m + ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) + ct11_arg.r = _ct11_arg_r + ct11_arg.c = _ct11_arg_c + ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] + ct11_arg.moduli_array = jnp.array( + ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) + ) + ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) + _ct11_data = ( + ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw + ) + _ct11_m_in = _ct11_data.shape[-1] + _ct11_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_m_in + ) + _ct11_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_r) + ) + _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) + if isinstance(_ct11_moduli, (int, np.integer)): + _ct11_moduli = [int(_ct11_moduli)] + ct11 = Polynomial( + { + "batch": _ct11_data.shape[0], + "num_elements": _ct11_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_m, + "precision": 32, + "degree_layout": (_ct11_r, _ct11_c), + }, + {"moduli": list(_ct11_moduli)[:_ct11_m]}, + ) + ct11.polynomial = _ct11_data.reshape( + _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in + )[..., :_ct11_m].copy() + ct11.batch = ct11.polynomial.shape[0] + ct11.num_elements = ct11.polynomial.shape[1] + ct11.num_moduli = _ct11_m + ct11.degree_layout = (_ct11_r, _ct11_c) + ct11.r = _ct11_r + ct11.c = _ct11_c + ct11.moduli = list(_ct11_moduli)[:_ct11_m] + ct11.moduli_array = jnp.array( + ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) + ) + _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct12_arg_m_in = _ct12_arg_data.shape[-1] + _ct12_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_arg_m_in + ) + _ct12_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_arg_r) + ) + _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct12_arg_moduli, (int, np.integer)): + _ct12_arg_moduli = [int(_ct12_arg_moduli)] + ct12_arg = Polynomial( + { + "batch": _ct12_arg_data.shape[0], + "num_elements": _ct12_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_arg_m, + "precision": 32, + "degree_layout": (_ct12_arg_r, _ct12_arg_c), + }, + {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, + ) + ct12_arg.polynomial = _ct12_arg_data.reshape( + _ct12_arg_data.shape[0], + _ct12_arg_data.shape[1], + _ct12_arg_r, + _ct12_arg_c, + _ct12_arg_m_in, + )[..., :_ct12_arg_m].copy() + ct12_arg.batch = ct12_arg.polynomial.shape[0] + ct12_arg.num_elements = ct12_arg.polynomial.shape[1] + ct12_arg.num_moduli = _ct12_arg_m + ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) + ct12_arg.r = _ct12_arg_r + ct12_arg.c = _ct12_arg_c + ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] + ct12_arg.moduli_array = jnp.array( + ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) + ) + ct12_pt_ntt = ( + pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] + .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct12_ptct = v0.ptct_mul[v0.max_level] + ct12_ptct.set_plaintext(ct12_pt_ntt) + ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) + _ct12_data = ( + ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw + ) + _ct12_m_in = _ct12_data.shape[-1] + _ct12_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_m_in + ) + _ct12_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_r) + ) + _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) + if isinstance(_ct12_moduli, (int, np.integer)): + _ct12_moduli = [int(_ct12_moduli)] + ct12 = Polynomial( + { + "batch": _ct12_data.shape[0], + "num_elements": _ct12_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_m, + "precision": 32, + "degree_layout": (_ct12_r, _ct12_c), + }, + {"moduli": list(_ct12_moduli)[:_ct12_m]}, + ) + ct12.polynomial = _ct12_data.reshape( + _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in + )[..., :_ct12_m].copy() + ct12.batch = ct12.polynomial.shape[0] + ct12.num_elements = ct12.polynomial.shape[1] + ct12.num_moduli = _ct12_m + ct12.degree_layout = (_ct12_r, _ct12_c) + ct12.r = _ct12_r + ct12.c = _ct12_c + ct12.moduli = list(_ct12_moduli)[:_ct12_m] + ct12.moduli_array = jnp.array( + ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) + ) + _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct13_arg_m_in = _ct13_arg_data.shape[-1] + _ct13_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_arg_m_in + ) + _ct13_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_arg_r) + ) + _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct13_arg_moduli, (int, np.integer)): + _ct13_arg_moduli = [int(_ct13_arg_moduli)] + ct13_arg = Polynomial( + { + "batch": _ct13_arg_data.shape[0], + "num_elements": _ct13_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_arg_m, + "precision": 32, + "degree_layout": (_ct13_arg_r, _ct13_arg_c), + }, + {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, + ) + ct13_arg.polynomial = _ct13_arg_data.reshape( + _ct13_arg_data.shape[0], + _ct13_arg_data.shape[1], + _ct13_arg_r, + _ct13_arg_c, + _ct13_arg_m_in, + )[..., :_ct13_arg_m].copy() + ct13_arg.batch = ct13_arg.polynomial.shape[0] + ct13_arg.num_elements = ct13_arg.polynomial.shape[1] + ct13_arg.num_moduli = _ct13_arg_m + ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) + ct13_arg.r = _ct13_arg_r + ct13_arg.c = _ct13_arg_c + ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] + ct13_arg.moduli_array = jnp.array( + ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) + ) + ct13_pt_ntt = ( + pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] + .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct13_ptct = v0.ptct_mul[v0.max_level] + ct13_ptct.set_plaintext(ct13_pt_ntt) + ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) + _ct13_data = ( + ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw + ) + _ct13_m_in = _ct13_data.shape[-1] + _ct13_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_m_in + ) + _ct13_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_r) + ) + _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) + if isinstance(_ct13_moduli, (int, np.integer)): + _ct13_moduli = [int(_ct13_moduli)] + ct13 = Polynomial( + { + "batch": _ct13_data.shape[0], + "num_elements": _ct13_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_m, + "precision": 32, + "degree_layout": (_ct13_r, _ct13_c), + }, + {"moduli": list(_ct13_moduli)[:_ct13_m]}, + ) + ct13.polynomial = _ct13_data.reshape( + _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in + )[..., :_ct13_m].copy() + ct13.batch = ct13.polynomial.shape[0] + ct13.num_elements = ct13.polynomial.shape[1] + ct13.num_moduli = _ct13_m + ct13.degree_layout = (_ct13_r, _ct13_c) + ct13.r = _ct13_r + ct13.c = _ct13_c + ct13.moduli = list(_ct13_moduli)[:_ct13_m] + ct13.moduli_array = jnp.array( + ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) + ) + _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + _ct14_m_in = _ct14_data.shape[-1] + _ct14_m = _ct14_m_in + _ct14_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_r) + ) + _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) + if isinstance(_ct14_moduli, (int, np.integer)): + _ct14_moduli = [int(_ct14_moduli)] + ct14 = Polynomial( + { + "batch": _ct14_data.shape[0], + "num_elements": _ct14_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_m, + "precision": 32, + "degree_layout": (_ct14_r, _ct14_c), + }, + {"moduli": list(_ct14_moduli)[:_ct14_m]}, + ) + ct14.polynomial = _ct14_data.reshape( + _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in + )[..., :_ct14_m].copy() + ct14.batch = ct14.polynomial.shape[0] + ct14.num_elements = ct14.polynomial.shape[1] + ct14.num_moduli = _ct14_m + ct14.degree_layout = (_ct14_r, _ct14_c) + ct14.r = _ct14_r + ct14.c = _ct14_c + ct14.moduli = list(_ct14_moduli)[:_ct14_m] + ct14.moduli_array = jnp.array( + ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) + ) + _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] + _ct14_rhs_m = _ct14_rhs_m_in + _ct14_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_rhs_r) + ) + _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) + if isinstance(_ct14_rhs_moduli, (int, np.integer)): + _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] + ct14_rhs = Polynomial( + { + "batch": _ct14_rhs_data.shape[0], + "num_elements": _ct14_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_rhs_m, + "precision": 32, + "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), + }, + {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, + ) + ct14_rhs.polynomial = _ct14_rhs_data.reshape( + _ct14_rhs_data.shape[0], + _ct14_rhs_data.shape[1], + _ct14_rhs_r, + _ct14_rhs_c, + _ct14_rhs_m_in, + )[..., :_ct14_rhs_m].copy() + ct14_rhs.batch = ct14_rhs.polynomial.shape[0] + ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] + ct14_rhs.num_moduli = _ct14_rhs_m + ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) + ct14_rhs.r = _ct14_rhs_r + ct14_rhs.c = _ct14_rhs_c + ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] + ct14_rhs.moduli_array = jnp.array( + ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) + ) + ct14.add(ct14_rhs) + _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) + ct14.polynomial = jnp.where( + ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial + ) + _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 + _ct15_arg_m_in = _ct15_arg_data.shape[-1] + _ct15_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_arg_m_in + ) + _ct15_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_arg_r) + ) + _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) + if isinstance(_ct15_arg_moduli, (int, np.integer)): + _ct15_arg_moduli = [int(_ct15_arg_moduli)] + ct15_arg = Polynomial( + { + "batch": _ct15_arg_data.shape[0], + "num_elements": _ct15_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_arg_m, + "precision": 32, + "degree_layout": (_ct15_arg_r, _ct15_arg_c), + }, + {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, + ) + ct15_arg.polynomial = _ct15_arg_data.reshape( + _ct15_arg_data.shape[0], + _ct15_arg_data.shape[1], + _ct15_arg_r, + _ct15_arg_c, + _ct15_arg_m_in, + )[..., :_ct15_arg_m].copy() + ct15_arg.batch = ct15_arg.polynomial.shape[0] + ct15_arg.num_elements = ct15_arg.polynomial.shape[1] + ct15_arg.num_moduli = _ct15_arg_m + ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) + ct15_arg.r = _ct15_arg_r + ct15_arg.c = _ct15_arg_c + ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] + ct15_arg.moduli_array = jnp.array( + ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) + ) + ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) + _ct15_data = ( + ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw + ) + _ct15_m_in = _ct15_data.shape[-1] + _ct15_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_m_in + ) + _ct15_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_r) + ) + _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) + if isinstance(_ct15_moduli, (int, np.integer)): + _ct15_moduli = [int(_ct15_moduli)] + ct15 = Polynomial( + { + "batch": _ct15_data.shape[0], + "num_elements": _ct15_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_m, + "precision": 32, + "degree_layout": (_ct15_r, _ct15_c), + }, + {"moduli": list(_ct15_moduli)[:_ct15_m]}, + ) + ct15.polynomial = _ct15_data.reshape( + _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in + )[..., :_ct15_m].copy() + ct15.batch = ct15.polynomial.shape[0] + ct15.num_elements = ct15.polynomial.shape[1] + ct15.num_moduli = _ct15_m + ct15.degree_layout = (_ct15_r, _ct15_c) + ct15.r = _ct15_r + ct15.c = _ct15_c + ct15.moduli = list(_ct15_moduli)[:_ct15_m] + ct15.moduli_array = jnp.array( + ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) + ) + _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + _ct16_m_in = _ct16_data.shape[-1] + _ct16_m = _ct16_m_in + _ct16_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_r) + ) + _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) + if isinstance(_ct16_moduli, (int, np.integer)): + _ct16_moduli = [int(_ct16_moduli)] + ct16 = Polynomial( + { + "batch": _ct16_data.shape[0], + "num_elements": _ct16_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_m, + "precision": 32, + "degree_layout": (_ct16_r, _ct16_c), + }, + {"moduli": list(_ct16_moduli)[:_ct16_m]}, + ) + ct16.polynomial = _ct16_data.reshape( + _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in + )[..., :_ct16_m].copy() + ct16.batch = ct16.polynomial.shape[0] + ct16.num_elements = ct16.polynomial.shape[1] + ct16.num_moduli = _ct16_m + ct16.degree_layout = (_ct16_r, _ct16_c) + ct16.r = _ct16_r + ct16.c = _ct16_c + ct16.moduli = list(_ct16_moduli)[:_ct16_m] + ct16.moduli_array = jnp.array( + ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) + ) + _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 + _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] + _ct16_rhs_m = _ct16_rhs_m_in + _ct16_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_rhs_r) + ) + _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) + if isinstance(_ct16_rhs_moduli, (int, np.integer)): + _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] + ct16_rhs = Polynomial( + { + "batch": _ct16_rhs_data.shape[0], + "num_elements": _ct16_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_rhs_m, + "precision": 32, + "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), + }, + {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, + ) + ct16_rhs.polynomial = _ct16_rhs_data.reshape( + _ct16_rhs_data.shape[0], + _ct16_rhs_data.shape[1], + _ct16_rhs_r, + _ct16_rhs_c, + _ct16_rhs_m_in, + )[..., :_ct16_rhs_m].copy() + ct16_rhs.batch = ct16_rhs.polynomial.shape[0] + ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] + ct16_rhs.num_moduli = _ct16_rhs_m + ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) + ct16_rhs.r = _ct16_rhs_r + ct16_rhs.c = _ct16_rhs_c + ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] + ct16_rhs.moduli_array = jnp.array( + ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) + ) + ct16.add(ct16_rhs) + _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) + ct16.polynomial = jnp.where( + ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial + ) + _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + _ct17_m_in = _ct17_data.shape[-1] + _ct17_m = _ct17_m_in + _ct17_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_r) + ) + _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) + if isinstance(_ct17_moduli, (int, np.integer)): + _ct17_moduli = [int(_ct17_moduli)] + ct17 = Polynomial( + { + "batch": _ct17_data.shape[0], + "num_elements": _ct17_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_m, + "precision": 32, + "degree_layout": (_ct17_r, _ct17_c), + }, + {"moduli": list(_ct17_moduli)[:_ct17_m]}, + ) + ct17.polynomial = _ct17_data.reshape( + _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in + )[..., :_ct17_m].copy() + ct17.batch = ct17.polynomial.shape[0] + ct17.num_elements = ct17.polynomial.shape[1] + ct17.num_moduli = _ct17_m + ct17.degree_layout = (_ct17_r, _ct17_c) + ct17.r = _ct17_r + ct17.c = _ct17_c + ct17.moduli = list(_ct17_moduli)[:_ct17_m] + ct17.moduli_array = jnp.array( + ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) + ) + _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] + _ct17_rhs_m = _ct17_rhs_m_in + _ct17_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_rhs_r) + ) + _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) + if isinstance(_ct17_rhs_moduli, (int, np.integer)): + _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] + ct17_rhs = Polynomial( + { + "batch": _ct17_rhs_data.shape[0], + "num_elements": _ct17_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_rhs_m, + "precision": 32, + "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), + }, + {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, + ) + ct17_rhs.polynomial = _ct17_rhs_data.reshape( + _ct17_rhs_data.shape[0], + _ct17_rhs_data.shape[1], + _ct17_rhs_r, + _ct17_rhs_c, + _ct17_rhs_m_in, + )[..., :_ct17_rhs_m].copy() + ct17_rhs.batch = ct17_rhs.polynomial.shape[0] + ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] + ct17_rhs.num_moduli = _ct17_rhs_m + ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) + ct17_rhs.r = _ct17_rhs_r + ct17_rhs.c = _ct17_rhs_c + ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] + ct17_rhs.moduli_array = jnp.array( + ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) + ) + ct17.add(ct17_rhs) + _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) + ct17.polynomial = jnp.where( + ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial + ) + _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 + _ct18_m_in = _ct18_data.shape[-1] + _ct18_m = _ct18_m_in + _ct18_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_r) + ) + _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) + if isinstance(_ct18_moduli, (int, np.integer)): + _ct18_moduli = [int(_ct18_moduli)] + ct18 = Polynomial( + { + "batch": _ct18_data.shape[0], + "num_elements": _ct18_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_m, + "precision": 32, + "degree_layout": (_ct18_r, _ct18_c), + }, + {"moduli": list(_ct18_moduli)[:_ct18_m]}, + ) + ct18.polynomial = _ct18_data.reshape( + _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in + )[..., :_ct18_m].copy() + ct18.batch = ct18.polynomial.shape[0] + ct18.num_elements = ct18.polynomial.shape[1] + ct18.num_moduli = _ct18_m + ct18.degree_layout = (_ct18_r, _ct18_c) + ct18.r = _ct18_r + ct18.c = _ct18_c + ct18.moduli = list(_ct18_moduli)[:_ct18_m] + ct18.moduli_array = jnp.array( + ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) + ) + _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 + _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] + _ct18_rhs_m = _ct18_rhs_m_in + _ct18_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_rhs_r) + ) + _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) + if isinstance(_ct18_rhs_moduli, (int, np.integer)): + _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] + ct18_rhs = Polynomial( + { + "batch": _ct18_rhs_data.shape[0], + "num_elements": _ct18_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_rhs_m, + "precision": 32, + "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), + }, + {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, + ) + ct18_rhs.polynomial = _ct18_rhs_data.reshape( + _ct18_rhs_data.shape[0], + _ct18_rhs_data.shape[1], + _ct18_rhs_r, + _ct18_rhs_c, + _ct18_rhs_m_in, + )[..., :_ct18_rhs_m].copy() + ct18_rhs.batch = ct18_rhs.polynomial.shape[0] + ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] + ct18_rhs.num_moduli = _ct18_rhs_m + ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) + ct18_rhs.r = _ct18_rhs_r + ct18_rhs.c = _ct18_rhs_c + ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] + ct18_rhs.moduli_array = jnp.array( + ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) + ) + ct18.add(ct18_rhs) + _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) + ct18.polynomial = jnp.where( + ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial + ) + _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 + _ct19_m_in = _ct19_data.shape[-1] + _ct19_m = _ct19_m_in + _ct19_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_r) + ) + _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) + if isinstance(_ct19_moduli, (int, np.integer)): + _ct19_moduli = [int(_ct19_moduli)] + ct19 = Polynomial( + { + "batch": _ct19_data.shape[0], + "num_elements": _ct19_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_m, + "precision": 32, + "degree_layout": (_ct19_r, _ct19_c), + }, + {"moduli": list(_ct19_moduli)[:_ct19_m]}, + ) + ct19.polynomial = _ct19_data.reshape( + _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in + )[..., :_ct19_m].copy() + ct19.batch = ct19.polynomial.shape[0] + ct19.num_elements = ct19.polynomial.shape[1] + ct19.num_moduli = _ct19_m + ct19.degree_layout = (_ct19_r, _ct19_c) + ct19.r = _ct19_r + ct19.c = _ct19_c + ct19.moduli = list(_ct19_moduli)[:_ct19_m] + ct19.moduli_array = jnp.array( + ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) + ) + _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 + _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] + _ct19_rhs_m = _ct19_rhs_m_in + _ct19_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_rhs_r) + ) + _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) + if isinstance(_ct19_rhs_moduli, (int, np.integer)): + _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] + ct19_rhs = Polynomial( + { + "batch": _ct19_rhs_data.shape[0], + "num_elements": _ct19_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_rhs_m, + "precision": 32, + "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), + }, + {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, + ) + ct19_rhs.polynomial = _ct19_rhs_data.reshape( + _ct19_rhs_data.shape[0], + _ct19_rhs_data.shape[1], + _ct19_rhs_r, + _ct19_rhs_c, + _ct19_rhs_m_in, + )[..., :_ct19_rhs_m].copy() + ct19_rhs.batch = ct19_rhs.polynomial.shape[0] + ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] + ct19_rhs.num_moduli = _ct19_rhs_m + ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) + ct19_rhs.r = _ct19_rhs_r + ct19_rhs.c = _ct19_rhs_c + ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] + ct19_rhs.moduli_array = jnp.array( + ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) + ) + ct19.add(ct19_rhs) + _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) + ct19.polynomial = jnp.where( + ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial + ) + v16 = [None] * 1 + _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 + _ct20_arg_m_in = _ct20_arg_data.shape[-1] + _ct20_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct20_arg_m_in + ) + _ct20_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_arg_r) + ) + _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) + if isinstance(_ct20_arg_moduli, (int, np.integer)): + _ct20_arg_moduli = [int(_ct20_arg_moduli)] + ct20_arg = Polynomial( + { + "batch": _ct20_arg_data.shape[0], + "num_elements": _ct20_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_arg_m, + "precision": 32, + "degree_layout": (_ct20_arg_r, _ct20_arg_c), + }, + {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, + ) + ct20_arg.polynomial = _ct20_arg_data.reshape( + _ct20_arg_data.shape[0], + _ct20_arg_data.shape[1], + _ct20_arg_r, + _ct20_arg_c, + _ct20_arg_m_in, + )[..., :_ct20_arg_m].copy() + ct20_arg.batch = ct20_arg.polynomial.shape[0] + ct20_arg.num_elements = ct20_arg.polynomial.shape[1] + ct20_arg.num_moduli = _ct20_arg_m + ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) + ct20_arg.r = _ct20_arg_r + ct20_arg.c = _ct20_arg_c + ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] + ct20_arg.moduli_array = jnp.array( + ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) + ) + ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) + _ct20_data = ( + ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw + ) + _ct20_m_in = _ct20_data.shape[-1] + _ct20_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct20_m_in + ) + _ct20_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_r) + ) + _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) + if isinstance(_ct20_moduli, (int, np.integer)): + _ct20_moduli = [int(_ct20_moduli)] + ct20 = Polynomial( + { + "batch": _ct20_data.shape[0], + "num_elements": _ct20_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_m, + "precision": 32, + "degree_layout": (_ct20_r, _ct20_c), + }, + {"moduli": list(_ct20_moduli)[:_ct20_m]}, + ) + ct20.polynomial = _ct20_data.reshape( + _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in + )[..., :_ct20_m].copy() + ct20.batch = ct20.polynomial.shape[0] + ct20.num_elements = ct20.polynomial.shape[1] + ct20.num_moduli = _ct20_m + ct20.degree_layout = (_ct20_r, _ct20_c) + ct20.r = _ct20_r + ct20.c = _ct20_c + ct20.moduli = list(_ct20_moduli)[:_ct20_m] + ct20.moduli_array = jnp.array( + ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) + ) + v16[0] = ct20 + v17 = v16 + return v17 + + +def matvec_random( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_random__preprocessing(v0, v1) + v11 = matvec_random__preprocessed(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) + return v11 + + +def matvec_random__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw + _ct_m_in = _ct_data.shape[-1] + _ct_m = _ct_m_in + _ct_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct_r) + ) + _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) + if isinstance(_ct_moduli, (int, np.integer)): + _ct_moduli = [int(_ct_moduli)] + ct = Polynomial( + { + "batch": _ct_data.shape[0], + "num_elements": _ct_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct_m, + "precision": 32, + "degree_layout": (_ct_r, _ct_c), + }, + {"moduli": list(_ct_moduli)[:_ct_m]}, + ) + ct.polynomial = _ct_data.reshape( + _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in + )[..., :_ct_m].copy() + ct.batch = ct.polynomial.shape[0] + ct.num_elements = ct.polynomial.shape[1] + ct.num_moduli = _ct_m + ct.degree_layout = (_ct_r, _ct_c) + ct.r = _ct_r + ct.c = _ct_c + ct.moduli = list(_ct_moduli)[:_ct_m] + ct.moduli_array = jnp.array( + ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) + ) + v16 = [ct] + return v16 + + +def matvec_random__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 0 + v8 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + _num_moduli = ct.polynomial.shape[-1] + _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": ct.polynomial.shape[0], + "num_elements": ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + ct.polynomial.reshape( + ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli + ) + ) + pt = v0.decrypt(_ct_for_dec) + v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v10 = v8.copy() + for v11 in range(0, 8): + v13 = int(v11) + v14 = v9[0, v13] + v10[v13] = v14 + return v10 + + +def matvec_chain__preprocessing( + v0: ckks.CKKSContext, + v1: dict, +) -> ( + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, +): + v2 = np.array( + [ + 1.340000e00, + 5.800000e-01, + 1.260000e00, + 7.400000e-01, + 6.900000e-01, + 1.070000e00, + 6.000000e-01, + 1.390000e00, + 1.130000e00, + 1.220000e00, + 5.200000e-01, + 1.090000e00, + 1.060000e00, + 6.600000e-01, + 6.500000e-01, + 1.200000e00, + 8.200000e-01, + 1.190000e00, + 1.050000e00, + 8.900000e-01, + 1.430000e00, + 1.340000e00, + 8.600000e-01, + 5.400000e-01, + 8.000000e-01, + 9.000000e-01, + 1.200000e00, + 1.500000e00, + 8.600000e-01, + 1.260000e00, + 1.090000e00, + 1.190000e00, + 6.500000e-01, + 9.000000e-01, + 7.400000e-01, + 8.400000e-01, + 1.010000e00, + 1.170000e00, + 6.100000e-01, + 6.300000e-01, + 8.200000e-01, + 1.160000e00, + 1.350000e00, + 1.050000e00, + 1.350000e00, + 8.800000e-01, + 8.200000e-01, + 8.500000e-01, + 6.700000e-01, + 1.330000e00, + 8.400000e-01, + 1.050000e00, + 1.080000e00, + 1.020000e00, + 5.000000e-01, + 1.490000e00, + 1.410000e00, + 7.100000e-01, + 7.900000e-01, + 1.020000e00, + 1.400000e00, + 1.480000e00, + 7.600000e-01, + 1.060000e00, + ], + dtype=np.float32, + ).reshape(8, 8) + v3 = np.array( + [ + 1.200000e00, + 7.900000e-01, + 7.300000e-01, + 1.050000e00, + 1.220000e00, + 9.200000e-01, + 1.480000e00, + 1.180000e00, + 9.800000e-01, + 8.900000e-01, + 8.400000e-01, + 1.230000e00, + 9.400000e-01, + 5.600000e-01, + 9.000000e-01, + 1.240000e00, + 6.800000e-01, + 6.800000e-01, + 1.030000e00, + 1.030000e00, + 1.130000e00, + 1.350000e00, + 1.220000e00, + 1.110000e00, + 1.220000e00, + 8.200000e-01, + 8.600000e-01, + 7.300000e-01, + 7.900000e-01, + 1.130000e00, + 5.900000e-01, + 9.300000e-01, + 9.300000e-01, + 9.900000e-01, + 9.300000e-01, + 8.100000e-01, + 9.300000e-01, + 1.390000e00, + 1.440000e00, + 1.000000e00, + 1.120000e00, + 6.200000e-01, + 8.200000e-01, + 9.100000e-01, + 1.370000e00, + 7.500000e-01, + 9.800000e-01, + 1.490000e00, + 1.020000e00, + 1.110000e00, + 6.200000e-01, + 1.330000e00, + 1.100000e00, + 1.050000e00, + 8.400000e-01, + 8.000000e-01, + 9.200000e-01, + 1.180000e00, + 1.380000e00, + 1.010000e00, + 1.170000e00, + 1.090000e00, + 1.120000e00, + 1.170000e00, + ], + dtype=np.float32, + ).reshape(8, 8) + v4 = _assign_layout_15335824159471298539(v2) + v5 = _assign_layout_15335824159471298539(v3) + v6 = v5[3 : 3 + 1, 0 : 0 + 5] + v7 = v5[3 : 3 + 1, 5 : 5 + 3] + v8 = np.zeros( + ( + 1, + 8, + ), + dtype=np.float32, + ) + v9 = v8.copy() + v9[0 : 0 + 1, 3 : 3 + 5] = v6 + v10 = v9.copy() + v10[0 : 0 + 1, 0 : 0 + 3] = v7 + v11 = v5[4 : 4 + 1, 0 : 0 + 5] + v12 = v5[4 : 4 + 1, 5 : 5 + 3] + v13 = v8.copy() + v13[0 : 0 + 1, 3 : 3 + 5] = v11 + v14 = v13.copy() + v14[0 : 0 + 1, 0 : 0 + 3] = v12 + v15 = v5[5 : 5 + 1, 0 : 0 + 5] + v16 = v5[5 : 5 + 1, 5 : 5 + 3] + v17 = v8.copy() + v17[0 : 0 + 1, 3 : 3 + 5] = v15 + v18 = v17.copy() + v18[0 : 0 + 1, 0 : 0 + 3] = v16 + v19 = v5[6 : 6 + 1, 0 : 0 + 2] + v20 = v5[6 : 6 + 1, 2 : 2 + 6] + v21 = v8.copy() + v21[0 : 0 + 1, 6 : 6 + 2] = v19 + v22 = v21.copy() + v22[0 : 0 + 1, 0 : 0 + 6] = v20 + v23 = v5[7 : 7 + 1, 0 : 0 + 2] + v24 = v5[7 : 7 + 1, 2 : 2 + 6] + v25 = v8.copy() + v25[0 : 0 + 1, 6 : 6 + 2] = v23 + v26 = v25.copy() + v26[0 : 0 + 1, 0 : 0 + 6] = v24 + v27 = v4[3 : 3 + 1, 0 : 0 + 5] + v28 = v4[3 : 3 + 1, 5 : 5 + 3] + v29 = v8.copy() + v29[0 : 0 + 1, 3 : 3 + 5] = v27 + v30 = v29.copy() + v30[0 : 0 + 1, 0 : 0 + 3] = v28 + v31 = v4[4 : 4 + 1, 0 : 0 + 5] + v32 = v4[4 : 4 + 1, 5 : 5 + 3] + v33 = v8.copy() + v33[0 : 0 + 1, 3 : 3 + 5] = v31 + v34 = v33.copy() + v34[0 : 0 + 1, 0 : 0 + 3] = v32 + v35 = v4[5 : 5 + 1, 0 : 0 + 5] + v36 = v4[5 : 5 + 1, 5 : 5 + 3] + v37 = v8.copy() + v37[0 : 0 + 1, 3 : 3 + 5] = v35 + v38 = v37.copy() + v38[0 : 0 + 1, 0 : 0 + 3] = v36 + v39 = v4[6 : 6 + 1, 0 : 0 + 2] + v40 = v4[6 : 6 + 1, 2 : 2 + 6] + v41 = v8.copy() + v41[0 : 0 + 1, 6 : 6 + 2] = v39 + v42 = v41.copy() + v42[0 : 0 + 1, 0 : 0 + 6] = v40 + v43 = v4[7 : 7 + 1, 0 : 0 + 2] + v44 = v4[7 : 7 + 1, 2 : 2 + 6] + v45 = v8.copy() + v45[0 : 0 + 1, 6 : 6 + 2] = v43 + v46 = v45.copy() + v46[0 : 0 + 1, 0 : 0 + 6] = v44 + v47 = v4[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v47) + v48 = v4[1 : 1 + 1, 0 : 0 + 8].reshape(8) + pt1 = v0.encode(v48) + v49 = v4[2 : 2 + 1, 0 : 0 + 8].reshape(8) + pt2 = v0.encode(v49) + v50 = v30[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt3 = v0.encode(v50) + v51 = v34[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt4 = v0.encode(v51) + v52 = v38[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt5 = v0.encode(v52) + v53 = v42[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt6 = v0.encode(v53) + v54 = v46[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt7 = v0.encode(v54) + v55 = v5[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt8 = v0.encode(v55) + v56 = v5[1 : 1 + 1, 0 : 0 + 8].reshape(8) + pt9 = v0.encode(v56) + v57 = v5[2 : 2 + 1, 0 : 0 + 8].reshape(8) + pt10 = v0.encode(v57) + v58 = v10[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt11 = v0.encode(v58) + v59 = v14[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt12 = v0.encode(v59) + v60 = v18[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt13 = v0.encode(v60) + v61 = v22[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt14 = v0.encode(v61) + v62 = v26[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt15 = v0.encode(v62) + v63 = [pt] + v64 = [pt1] + v65 = [pt2] + v66 = [pt3] + v67 = [pt4] + v68 = [pt5] + v69 = [pt6] + v70 = [pt7] + v71 = [pt8, pt9] + v72 = [pt10, pt11] + v73 = [pt12, pt13] + v74 = [pt14, pt15] + return (v63, v64, v65, v66, v67, v68, v69, v70, v71, v72, v73, v74) + + +def matvec_chain__preprocessed( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, + v4: np.ndarray, + v5: np.ndarray, + v6: np.ndarray, + v7: np.ndarray, + v8: np.ndarray, + v9: np.ndarray, + v10: np.ndarray, + v11: np.ndarray, + v12: np.ndarray, + v13: np.ndarray, + v14: np.ndarray, +) -> np.ndarray: + v15 = 1 + v16 = 2 + v17 = 3 + v18 = 6 + v19 = 0 + pt = v3[0] + pt1 = v4[0] + pt2 = v5[0] + pt3 = v6[0] + pt4 = v7[0] + pt5 = v8[0] + pt6 = v9[0] + pt7 = v10[0] + pt8 = v11[0] + pt9 = v11[1] + pt10 = v12[0] + pt11 = v12[1] + pt12 = v13[0] + pt13 = v13[1] + pt14 = v14[0] + pt15 = v14[1] + ct = v2[0] + _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct1_arg_m_in = _ct1_arg_data.shape[-1] + _ct1_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_arg_m_in + ) + _ct1_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_arg_r) + ) + _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct1_arg_moduli, (int, np.integer)): + _ct1_arg_moduli = [int(_ct1_arg_moduli)] + ct1_arg = Polynomial( + { + "batch": _ct1_arg_data.shape[0], + "num_elements": _ct1_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_arg_m, + "precision": 32, + "degree_layout": (_ct1_arg_r, _ct1_arg_c), + }, + {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, + ) + ct1_arg.polynomial = _ct1_arg_data.reshape( + _ct1_arg_data.shape[0], + _ct1_arg_data.shape[1], + _ct1_arg_r, + _ct1_arg_c, + _ct1_arg_m_in, + )[..., :_ct1_arg_m].copy() + ct1_arg.batch = ct1_arg.polynomial.shape[0] + ct1_arg.num_elements = ct1_arg.polynomial.shape[1] + ct1_arg.num_moduli = _ct1_arg_m + ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) + ct1_arg.r = _ct1_arg_r + ct1_arg.c = _ct1_arg_c + ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] + ct1_arg.moduli_array = jnp.array( + ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) + ) + ct1_pt_ntt = ( + pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] + .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct1_ptct = v0.ptct_mul[v0.max_level] + ct1_ptct.set_plaintext(ct1_pt_ntt) + ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) + _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw + _ct1_m_in = _ct1_data.shape[-1] + _ct1_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct1_m_in + ) + _ct1_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct1_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct1_r) + ) + _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) + if isinstance(_ct1_moduli, (int, np.integer)): + _ct1_moduli = [int(_ct1_moduli)] + ct1 = Polynomial( + { + "batch": _ct1_data.shape[0], + "num_elements": _ct1_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct1_m, + "precision": 32, + "degree_layout": (_ct1_r, _ct1_c), + }, + {"moduli": list(_ct1_moduli)[:_ct1_m]}, + ) + ct1.polynomial = _ct1_data.reshape( + _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in + )[..., :_ct1_m].copy() + ct1.batch = ct1.polynomial.shape[0] + ct1.num_elements = ct1.polynomial.shape[1] + ct1.num_moduli = _ct1_m + ct1.degree_layout = (_ct1_r, _ct1_c) + ct1.r = _ct1_r + ct1.c = _ct1_c + ct1.moduli = list(_ct1_moduli)[:_ct1_m] + ct1.moduli_array = jnp.array( + ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) + ) + _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct2_arg_m_in = _ct2_arg_data.shape[-1] + _ct2_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_arg_m_in + ) + _ct2_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_arg_r) + ) + _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct2_arg_moduli, (int, np.integer)): + _ct2_arg_moduli = [int(_ct2_arg_moduli)] + ct2_arg = Polynomial( + { + "batch": _ct2_arg_data.shape[0], + "num_elements": _ct2_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_arg_m, + "precision": 32, + "degree_layout": (_ct2_arg_r, _ct2_arg_c), + }, + {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, + ) + ct2_arg.polynomial = _ct2_arg_data.reshape( + _ct2_arg_data.shape[0], + _ct2_arg_data.shape[1], + _ct2_arg_r, + _ct2_arg_c, + _ct2_arg_m_in, + )[..., :_ct2_arg_m].copy() + ct2_arg.batch = ct2_arg.polynomial.shape[0] + ct2_arg.num_elements = ct2_arg.polynomial.shape[1] + ct2_arg.num_moduli = _ct2_arg_m + ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) + ct2_arg.r = _ct2_arg_r + ct2_arg.c = _ct2_arg_c + ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] + ct2_arg.moduli_array = jnp.array( + ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) + ) + ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) + _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw + _ct2_m_in = _ct2_data.shape[-1] + _ct2_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct2_m_in + ) + _ct2_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct2_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct2_r) + ) + _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) + if isinstance(_ct2_moduli, (int, np.integer)): + _ct2_moduli = [int(_ct2_moduli)] + ct2 = Polynomial( + { + "batch": _ct2_data.shape[0], + "num_elements": _ct2_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct2_m, + "precision": 32, + "degree_layout": (_ct2_r, _ct2_c), + }, + {"moduli": list(_ct2_moduli)[:_ct2_m]}, + ) + ct2.polynomial = _ct2_data.reshape( + _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in + )[..., :_ct2_m].copy() + ct2.batch = ct2.polynomial.shape[0] + ct2.num_elements = ct2.polynomial.shape[1] + ct2.num_moduli = _ct2_m + ct2.degree_layout = (_ct2_r, _ct2_c) + ct2.r = _ct2_r + ct2.c = _ct2_c + ct2.moduli = list(_ct2_moduli)[:_ct2_m] + ct2.moduli_array = jnp.array( + ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) + ) + _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct3_arg_m_in = _ct3_arg_data.shape[-1] + _ct3_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_arg_m_in + ) + _ct3_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_arg_r) + ) + _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct3_arg_moduli, (int, np.integer)): + _ct3_arg_moduli = [int(_ct3_arg_moduli)] + ct3_arg = Polynomial( + { + "batch": _ct3_arg_data.shape[0], + "num_elements": _ct3_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_arg_m, + "precision": 32, + "degree_layout": (_ct3_arg_r, _ct3_arg_c), + }, + {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, + ) + ct3_arg.polynomial = _ct3_arg_data.reshape( + _ct3_arg_data.shape[0], + _ct3_arg_data.shape[1], + _ct3_arg_r, + _ct3_arg_c, + _ct3_arg_m_in, + )[..., :_ct3_arg_m].copy() + ct3_arg.batch = ct3_arg.polynomial.shape[0] + ct3_arg.num_elements = ct3_arg.polynomial.shape[1] + ct3_arg.num_moduli = _ct3_arg_m + ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) + ct3_arg.r = _ct3_arg_r + ct3_arg.c = _ct3_arg_c + ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] + ct3_arg.moduli_array = jnp.array( + ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) + ) + ct3_pt_ntt = ( + pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] + .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct3_ptct = v0.ptct_mul[v0.max_level] + ct3_ptct.set_plaintext(ct3_pt_ntt) + ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) + _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw + _ct3_m_in = _ct3_data.shape[-1] + _ct3_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct3_m_in + ) + _ct3_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct3_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct3_r) + ) + _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) + if isinstance(_ct3_moduli, (int, np.integer)): + _ct3_moduli = [int(_ct3_moduli)] + ct3 = Polynomial( + { + "batch": _ct3_data.shape[0], + "num_elements": _ct3_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct3_m, + "precision": 32, + "degree_layout": (_ct3_r, _ct3_c), + }, + {"moduli": list(_ct3_moduli)[:_ct3_m]}, + ) + ct3.polynomial = _ct3_data.reshape( + _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in + )[..., :_ct3_m].copy() + ct3.batch = ct3.polynomial.shape[0] + ct3.num_elements = ct3.polynomial.shape[1] + ct3.num_moduli = _ct3_m + ct3.degree_layout = (_ct3_r, _ct3_c) + ct3.r = _ct3_r + ct3.c = _ct3_c + ct3.moduli = list(_ct3_moduli)[:_ct3_m] + ct3.moduli_array = jnp.array( + ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) + ) + _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct4_arg_m_in = _ct4_arg_data.shape[-1] + _ct4_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_arg_m_in + ) + _ct4_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_arg_r) + ) + _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct4_arg_moduli, (int, np.integer)): + _ct4_arg_moduli = [int(_ct4_arg_moduli)] + ct4_arg = Polynomial( + { + "batch": _ct4_arg_data.shape[0], + "num_elements": _ct4_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_arg_m, + "precision": 32, + "degree_layout": (_ct4_arg_r, _ct4_arg_c), + }, + {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, + ) + ct4_arg.polynomial = _ct4_arg_data.reshape( + _ct4_arg_data.shape[0], + _ct4_arg_data.shape[1], + _ct4_arg_r, + _ct4_arg_c, + _ct4_arg_m_in, + )[..., :_ct4_arg_m].copy() + ct4_arg.batch = ct4_arg.polynomial.shape[0] + ct4_arg.num_elements = ct4_arg.polynomial.shape[1] + ct4_arg.num_moduli = _ct4_arg_m + ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) + ct4_arg.r = _ct4_arg_r + ct4_arg.c = _ct4_arg_c + ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] + ct4_arg.moduli_array = jnp.array( + ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) + ) + ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) + _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw + _ct4_m_in = _ct4_data.shape[-1] + _ct4_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct4_m_in + ) + _ct4_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct4_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct4_r) + ) + _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) + if isinstance(_ct4_moduli, (int, np.integer)): + _ct4_moduli = [int(_ct4_moduli)] + ct4 = Polynomial( + { + "batch": _ct4_data.shape[0], + "num_elements": _ct4_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct4_m, + "precision": 32, + "degree_layout": (_ct4_r, _ct4_c), + }, + {"moduli": list(_ct4_moduli)[:_ct4_m]}, + ) + ct4.polynomial = _ct4_data.reshape( + _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in + )[..., :_ct4_m].copy() + ct4.batch = ct4.polynomial.shape[0] + ct4.num_elements = ct4.polynomial.shape[1] + ct4.num_moduli = _ct4_m + ct4.degree_layout = (_ct4_r, _ct4_c) + ct4.r = _ct4_r + ct4.c = _ct4_c + ct4.moduli = list(_ct4_moduli)[:_ct4_m] + ct4.moduli_array = jnp.array( + ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) + ) + _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct5_arg_m_in = _ct5_arg_data.shape[-1] + _ct5_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_arg_m_in + ) + _ct5_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_arg_r) + ) + _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct5_arg_moduli, (int, np.integer)): + _ct5_arg_moduli = [int(_ct5_arg_moduli)] + ct5_arg = Polynomial( + { + "batch": _ct5_arg_data.shape[0], + "num_elements": _ct5_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_arg_m, + "precision": 32, + "degree_layout": (_ct5_arg_r, _ct5_arg_c), + }, + {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, + ) + ct5_arg.polynomial = _ct5_arg_data.reshape( + _ct5_arg_data.shape[0], + _ct5_arg_data.shape[1], + _ct5_arg_r, + _ct5_arg_c, + _ct5_arg_m_in, + )[..., :_ct5_arg_m].copy() + ct5_arg.batch = ct5_arg.polynomial.shape[0] + ct5_arg.num_elements = ct5_arg.polynomial.shape[1] + ct5_arg.num_moduli = _ct5_arg_m + ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) + ct5_arg.r = _ct5_arg_r + ct5_arg.c = _ct5_arg_c + ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] + ct5_arg.moduli_array = jnp.array( + ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) + ) + ct5_pt_ntt = ( + pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] + .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct5_ptct = v0.ptct_mul[v0.max_level] + ct5_ptct.set_plaintext(ct5_pt_ntt) + ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) + _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw + _ct5_m_in = _ct5_data.shape[-1] + _ct5_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct5_m_in + ) + _ct5_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct5_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct5_r) + ) + _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) + if isinstance(_ct5_moduli, (int, np.integer)): + _ct5_moduli = [int(_ct5_moduli)] + ct5 = Polynomial( + { + "batch": _ct5_data.shape[0], + "num_elements": _ct5_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct5_m, + "precision": 32, + "degree_layout": (_ct5_r, _ct5_c), + }, + {"moduli": list(_ct5_moduli)[:_ct5_m]}, + ) + ct5.polynomial = _ct5_data.reshape( + _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in + )[..., :_ct5_m].copy() + ct5.batch = ct5.polynomial.shape[0] + ct5.num_elements = ct5.polynomial.shape[1] + ct5.num_moduli = _ct5_m + ct5.degree_layout = (_ct5_r, _ct5_c) + ct5.r = _ct5_r + ct5.c = _ct5_c + ct5.moduli = list(_ct5_moduli)[:_ct5_m] + ct5.moduli_array = jnp.array( + ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) + ) + _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct6_arg_m_in = _ct6_arg_data.shape[-1] + _ct6_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_arg_m_in + ) + _ct6_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_arg_r) + ) + _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct6_arg_moduli, (int, np.integer)): + _ct6_arg_moduli = [int(_ct6_arg_moduli)] + ct6_arg = Polynomial( + { + "batch": _ct6_arg_data.shape[0], + "num_elements": _ct6_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_arg_m, + "precision": 32, + "degree_layout": (_ct6_arg_r, _ct6_arg_c), + }, + {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, + ) + ct6_arg.polynomial = _ct6_arg_data.reshape( + _ct6_arg_data.shape[0], + _ct6_arg_data.shape[1], + _ct6_arg_r, + _ct6_arg_c, + _ct6_arg_m_in, + )[..., :_ct6_arg_m].copy() + ct6_arg.batch = ct6_arg.polynomial.shape[0] + ct6_arg.num_elements = ct6_arg.polynomial.shape[1] + ct6_arg.num_moduli = _ct6_arg_m + ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) + ct6_arg.r = _ct6_arg_r + ct6_arg.c = _ct6_arg_c + ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] + ct6_arg.moduli_array = jnp.array( + ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) + ) + ct6_pt_ntt = ( + pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] + .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct6_ptct = v0.ptct_mul[v0.max_level] + ct6_ptct.set_plaintext(ct6_pt_ntt) + ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) + _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw + _ct6_m_in = _ct6_data.shape[-1] + _ct6_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct6_m_in + ) + _ct6_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct6_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct6_r) + ) + _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) + if isinstance(_ct6_moduli, (int, np.integer)): + _ct6_moduli = [int(_ct6_moduli)] + ct6 = Polynomial( + { + "batch": _ct6_data.shape[0], + "num_elements": _ct6_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct6_m, + "precision": 32, + "degree_layout": (_ct6_r, _ct6_c), + }, + {"moduli": list(_ct6_moduli)[:_ct6_m]}, + ) + ct6.polynomial = _ct6_data.reshape( + _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in + )[..., :_ct6_m].copy() + ct6.batch = ct6.polynomial.shape[0] + ct6.num_elements = ct6.polynomial.shape[1] + ct6.num_moduli = _ct6_m + ct6.degree_layout = (_ct6_r, _ct6_c) + ct6.r = _ct6_r + ct6.c = _ct6_c + ct6.moduli = list(_ct6_moduli)[:_ct6_m] + ct6.moduli_array = jnp.array( + ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) + ) + _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct7_arg_m_in = _ct7_arg_data.shape[-1] + _ct7_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_arg_m_in + ) + _ct7_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_arg_r) + ) + _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct7_arg_moduli, (int, np.integer)): + _ct7_arg_moduli = [int(_ct7_arg_moduli)] + ct7_arg = Polynomial( + { + "batch": _ct7_arg_data.shape[0], + "num_elements": _ct7_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_arg_m, + "precision": 32, + "degree_layout": (_ct7_arg_r, _ct7_arg_c), + }, + {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, + ) + ct7_arg.polynomial = _ct7_arg_data.reshape( + _ct7_arg_data.shape[0], + _ct7_arg_data.shape[1], + _ct7_arg_r, + _ct7_arg_c, + _ct7_arg_m_in, + )[..., :_ct7_arg_m].copy() + ct7_arg.batch = ct7_arg.polynomial.shape[0] + ct7_arg.num_elements = ct7_arg.polynomial.shape[1] + ct7_arg.num_moduli = _ct7_arg_m + ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) + ct7_arg.r = _ct7_arg_r + ct7_arg.c = _ct7_arg_c + ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] + ct7_arg.moduli_array = jnp.array( + ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) + ) + ct7_pt_ntt = ( + pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] + .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct7_ptct = v0.ptct_mul[v0.max_level] + ct7_ptct.set_plaintext(ct7_pt_ntt) + ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) + _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw + _ct7_m_in = _ct7_data.shape[-1] + _ct7_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct7_m_in + ) + _ct7_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct7_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct7_r) + ) + _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) + if isinstance(_ct7_moduli, (int, np.integer)): + _ct7_moduli = [int(_ct7_moduli)] + ct7 = Polynomial( + { + "batch": _ct7_data.shape[0], + "num_elements": _ct7_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct7_m, + "precision": 32, + "degree_layout": (_ct7_r, _ct7_c), + }, + {"moduli": list(_ct7_moduli)[:_ct7_m]}, + ) + ct7.polynomial = _ct7_data.reshape( + _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in + )[..., :_ct7_m].copy() + ct7.batch = ct7.polynomial.shape[0] + ct7.num_elements = ct7.polynomial.shape[1] + ct7.num_moduli = _ct7_m + ct7.degree_layout = (_ct7_r, _ct7_c) + ct7.r = _ct7_r + ct7.c = _ct7_c + ct7.moduli = list(_ct7_moduli)[:_ct7_m] + ct7.moduli_array = jnp.array( + ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) + ) + _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 + _ct8_arg_m_in = _ct8_arg_data.shape[-1] + _ct8_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_arg_m_in + ) + _ct8_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_arg_r) + ) + _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) + if isinstance(_ct8_arg_moduli, (int, np.integer)): + _ct8_arg_moduli = [int(_ct8_arg_moduli)] + ct8_arg = Polynomial( + { + "batch": _ct8_arg_data.shape[0], + "num_elements": _ct8_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_arg_m, + "precision": 32, + "degree_layout": (_ct8_arg_r, _ct8_arg_c), + }, + {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, + ) + ct8_arg.polynomial = _ct8_arg_data.reshape( + _ct8_arg_data.shape[0], + _ct8_arg_data.shape[1], + _ct8_arg_r, + _ct8_arg_c, + _ct8_arg_m_in, + )[..., :_ct8_arg_m].copy() + ct8_arg.batch = ct8_arg.polynomial.shape[0] + ct8_arg.num_elements = ct8_arg.polynomial.shape[1] + ct8_arg.num_moduli = _ct8_arg_m + ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) + ct8_arg.r = _ct8_arg_r + ct8_arg.c = _ct8_arg_c + ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] + ct8_arg.moduli_array = jnp.array( + ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) + ) + ct8_pt_ntt = ( + pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] + .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct8_ptct = v0.ptct_mul[v0.max_level] + ct8_ptct.set_plaintext(ct8_pt_ntt) + ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) + _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw + _ct8_m_in = _ct8_data.shape[-1] + _ct8_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct8_m_in + ) + _ct8_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct8_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct8_r) + ) + _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) + if isinstance(_ct8_moduli, (int, np.integer)): + _ct8_moduli = [int(_ct8_moduli)] + ct8 = Polynomial( + { + "batch": _ct8_data.shape[0], + "num_elements": _ct8_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct8_m, + "precision": 32, + "degree_layout": (_ct8_r, _ct8_c), + }, + {"moduli": list(_ct8_moduli)[:_ct8_m]}, + ) + ct8.polynomial = _ct8_data.reshape( + _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in + )[..., :_ct8_m].copy() + ct8.batch = ct8.polynomial.shape[0] + ct8.num_elements = ct8.polynomial.shape[1] + ct8.num_moduli = _ct8_m + ct8.degree_layout = (_ct8_r, _ct8_c) + ct8.r = _ct8_r + ct8.c = _ct8_c + ct8.moduli = list(_ct8_moduli)[:_ct8_m] + ct8.moduli_array = jnp.array( + ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) + ) + _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 + _ct9_m_in = _ct9_data.shape[-1] + _ct9_m = _ct9_m_in + _ct9_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_r) + ) + _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) + if isinstance(_ct9_moduli, (int, np.integer)): + _ct9_moduli = [int(_ct9_moduli)] + ct9 = Polynomial( + { + "batch": _ct9_data.shape[0], + "num_elements": _ct9_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_m, + "precision": 32, + "degree_layout": (_ct9_r, _ct9_c), + }, + {"moduli": list(_ct9_moduli)[:_ct9_m]}, + ) + ct9.polynomial = _ct9_data.reshape( + _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in + )[..., :_ct9_m].copy() + ct9.batch = ct9.polynomial.shape[0] + ct9.num_elements = ct9.polynomial.shape[1] + ct9.num_moduli = _ct9_m + ct9.degree_layout = (_ct9_r, _ct9_c) + ct9.r = _ct9_r + ct9.c = _ct9_c + ct9.moduli = list(_ct9_moduli)[:_ct9_m] + ct9.moduli_array = jnp.array( + ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) + ) + _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 + _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] + _ct9_rhs_m = _ct9_rhs_m_in + _ct9_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct9_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct9_rhs_r) + ) + _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) + if isinstance(_ct9_rhs_moduli, (int, np.integer)): + _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] + ct9_rhs = Polynomial( + { + "batch": _ct9_rhs_data.shape[0], + "num_elements": _ct9_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct9_rhs_m, + "precision": 32, + "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), + }, + {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, + ) + ct9_rhs.polynomial = _ct9_rhs_data.reshape( + _ct9_rhs_data.shape[0], + _ct9_rhs_data.shape[1], + _ct9_rhs_r, + _ct9_rhs_c, + _ct9_rhs_m_in, + )[..., :_ct9_rhs_m].copy() + ct9_rhs.batch = ct9_rhs.polynomial.shape[0] + ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] + ct9_rhs.num_moduli = _ct9_rhs_m + ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) + ct9_rhs.r = _ct9_rhs_r + ct9_rhs.c = _ct9_rhs_c + ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] + ct9_rhs.moduli_array = jnp.array( + ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) + ) + ct9.add(ct9_rhs) + _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) + ct9.polynomial = jnp.where( + ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial + ) + _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 + _ct10_m_in = _ct10_data.shape[-1] + _ct10_m = _ct10_m_in + _ct10_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_r) + ) + _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) + if isinstance(_ct10_moduli, (int, np.integer)): + _ct10_moduli = [int(_ct10_moduli)] + ct10 = Polynomial( + { + "batch": _ct10_data.shape[0], + "num_elements": _ct10_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_m, + "precision": 32, + "degree_layout": (_ct10_r, _ct10_c), + }, + {"moduli": list(_ct10_moduli)[:_ct10_m]}, + ) + ct10.polynomial = _ct10_data.reshape( + _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in + )[..., :_ct10_m].copy() + ct10.batch = ct10.polynomial.shape[0] + ct10.num_elements = ct10.polynomial.shape[1] + ct10.num_moduli = _ct10_m + ct10.degree_layout = (_ct10_r, _ct10_c) + ct10.r = _ct10_r + ct10.c = _ct10_c + ct10.moduli = list(_ct10_moduli)[:_ct10_m] + ct10.moduli_array = jnp.array( + ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) + ) + _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 + _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] + _ct10_rhs_m = _ct10_rhs_m_in + _ct10_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct10_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct10_rhs_r) + ) + _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) + if isinstance(_ct10_rhs_moduli, (int, np.integer)): + _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] + ct10_rhs = Polynomial( + { + "batch": _ct10_rhs_data.shape[0], + "num_elements": _ct10_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct10_rhs_m, + "precision": 32, + "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), + }, + {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, + ) + ct10_rhs.polynomial = _ct10_rhs_data.reshape( + _ct10_rhs_data.shape[0], + _ct10_rhs_data.shape[1], + _ct10_rhs_r, + _ct10_rhs_c, + _ct10_rhs_m_in, + )[..., :_ct10_rhs_m].copy() + ct10_rhs.batch = ct10_rhs.polynomial.shape[0] + ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] + ct10_rhs.num_moduli = _ct10_rhs_m + ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) + ct10_rhs.r = _ct10_rhs_r + ct10_rhs.c = _ct10_rhs_c + ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] + ct10_rhs.moduli_array = jnp.array( + ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) + ) + ct10.add(ct10_rhs) + _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) + ct10.polynomial = jnp.where( + ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial + ) + _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 + _ct11_arg_m_in = _ct11_arg_data.shape[-1] + _ct11_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_arg_m_in + ) + _ct11_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_arg_r) + ) + _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) + if isinstance(_ct11_arg_moduli, (int, np.integer)): + _ct11_arg_moduli = [int(_ct11_arg_moduli)] + ct11_arg = Polynomial( + { + "batch": _ct11_arg_data.shape[0], + "num_elements": _ct11_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_arg_m, + "precision": 32, + "degree_layout": (_ct11_arg_r, _ct11_arg_c), + }, + {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, + ) + ct11_arg.polynomial = _ct11_arg_data.reshape( + _ct11_arg_data.shape[0], + _ct11_arg_data.shape[1], + _ct11_arg_r, + _ct11_arg_c, + _ct11_arg_m_in, + )[..., :_ct11_arg_m].copy() + ct11_arg.batch = ct11_arg.polynomial.shape[0] + ct11_arg.num_elements = ct11_arg.polynomial.shape[1] + ct11_arg.num_moduli = _ct11_arg_m + ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) + ct11_arg.r = _ct11_arg_r + ct11_arg.c = _ct11_arg_c + ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] + ct11_arg.moduli_array = jnp.array( + ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) + ) + ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) + _ct11_data = ( + ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw + ) + _ct11_m_in = _ct11_data.shape[-1] + _ct11_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct11_m_in + ) + _ct11_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct11_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct11_r) + ) + _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) + if isinstance(_ct11_moduli, (int, np.integer)): + _ct11_moduli = [int(_ct11_moduli)] + ct11 = Polynomial( + { + "batch": _ct11_data.shape[0], + "num_elements": _ct11_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct11_m, + "precision": 32, + "degree_layout": (_ct11_r, _ct11_c), + }, + {"moduli": list(_ct11_moduli)[:_ct11_m]}, + ) + ct11.polynomial = _ct11_data.reshape( + _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in + )[..., :_ct11_m].copy() + ct11.batch = ct11.polynomial.shape[0] + ct11.num_elements = ct11.polynomial.shape[1] + ct11.num_moduli = _ct11_m + ct11.degree_layout = (_ct11_r, _ct11_c) + ct11.r = _ct11_r + ct11.c = _ct11_c + ct11.moduli = list(_ct11_moduli)[:_ct11_m] + ct11.moduli_array = jnp.array( + ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) + ) + _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct + _ct12_arg_m_in = _ct12_arg_data.shape[-1] + _ct12_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_arg_m_in + ) + _ct12_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_arg_r) + ) + _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) + if isinstance(_ct12_arg_moduli, (int, np.integer)): + _ct12_arg_moduli = [int(_ct12_arg_moduli)] + ct12_arg = Polynomial( + { + "batch": _ct12_arg_data.shape[0], + "num_elements": _ct12_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_arg_m, + "precision": 32, + "degree_layout": (_ct12_arg_r, _ct12_arg_c), + }, + {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, + ) + ct12_arg.polynomial = _ct12_arg_data.reshape( + _ct12_arg_data.shape[0], + _ct12_arg_data.shape[1], + _ct12_arg_r, + _ct12_arg_c, + _ct12_arg_m_in, + )[..., :_ct12_arg_m].copy() + ct12_arg.batch = ct12_arg.polynomial.shape[0] + ct12_arg.num_elements = ct12_arg.polynomial.shape[1] + ct12_arg.num_moduli = _ct12_arg_m + ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) + ct12_arg.r = _ct12_arg_r + ct12_arg.c = _ct12_arg_c + ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] + ct12_arg.moduli_array = jnp.array( + ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) + ) + ct12_pt_ntt = ( + pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] + .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct12_ptct = v0.ptct_mul[v0.max_level] + ct12_ptct.set_plaintext(ct12_pt_ntt) + ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) + _ct12_data = ( + ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw + ) + _ct12_m_in = _ct12_data.shape[-1] + _ct12_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct12_m_in + ) + _ct12_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct12_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct12_r) + ) + _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) + if isinstance(_ct12_moduli, (int, np.integer)): + _ct12_moduli = [int(_ct12_moduli)] + ct12 = Polynomial( + { + "batch": _ct12_data.shape[0], + "num_elements": _ct12_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct12_m, + "precision": 32, + "degree_layout": (_ct12_r, _ct12_c), + }, + {"moduli": list(_ct12_moduli)[:_ct12_m]}, + ) + ct12.polynomial = _ct12_data.reshape( + _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in + )[..., :_ct12_m].copy() + ct12.batch = ct12.polynomial.shape[0] + ct12.num_elements = ct12.polynomial.shape[1] + ct12.num_moduli = _ct12_m + ct12.degree_layout = (_ct12_r, _ct12_c) + ct12.r = _ct12_r + ct12.c = _ct12_c + ct12.moduli = list(_ct12_moduli)[:_ct12_m] + ct12.moduli_array = jnp.array( + ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) + ) + _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 + _ct13_arg_m_in = _ct13_arg_data.shape[-1] + _ct13_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_arg_m_in + ) + _ct13_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_arg_r) + ) + _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) + if isinstance(_ct13_arg_moduli, (int, np.integer)): + _ct13_arg_moduli = [int(_ct13_arg_moduli)] + ct13_arg = Polynomial( + { + "batch": _ct13_arg_data.shape[0], + "num_elements": _ct13_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_arg_m, + "precision": 32, + "degree_layout": (_ct13_arg_r, _ct13_arg_c), + }, + {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, + ) + ct13_arg.polynomial = _ct13_arg_data.reshape( + _ct13_arg_data.shape[0], + _ct13_arg_data.shape[1], + _ct13_arg_r, + _ct13_arg_c, + _ct13_arg_m_in, + )[..., :_ct13_arg_m].copy() + ct13_arg.batch = ct13_arg.polynomial.shape[0] + ct13_arg.num_elements = ct13_arg.polynomial.shape[1] + ct13_arg.num_moduli = _ct13_arg_m + ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) + ct13_arg.r = _ct13_arg_r + ct13_arg.c = _ct13_arg_c + ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] + ct13_arg.moduli_array = jnp.array( + ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) + ) + ct13_pt_ntt = ( + pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] + .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct13_ptct = v0.ptct_mul[v0.max_level] + ct13_ptct.set_plaintext(ct13_pt_ntt) + ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) + _ct13_data = ( + ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw + ) + _ct13_m_in = _ct13_data.shape[-1] + _ct13_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct13_m_in + ) + _ct13_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct13_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct13_r) + ) + _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) + if isinstance(_ct13_moduli, (int, np.integer)): + _ct13_moduli = [int(_ct13_moduli)] + ct13 = Polynomial( + { + "batch": _ct13_data.shape[0], + "num_elements": _ct13_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct13_m, + "precision": 32, + "degree_layout": (_ct13_r, _ct13_c), + }, + {"moduli": list(_ct13_moduli)[:_ct13_m]}, + ) + ct13.polynomial = _ct13_data.reshape( + _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in + )[..., :_ct13_m].copy() + ct13.batch = ct13.polynomial.shape[0] + ct13.num_elements = ct13.polynomial.shape[1] + ct13.num_moduli = _ct13_m + ct13.degree_layout = (_ct13_r, _ct13_c) + ct13.r = _ct13_r + ct13.c = _ct13_c + ct13.moduli = list(_ct13_moduli)[:_ct13_m] + ct13.moduli_array = jnp.array( + ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) + ) + _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 + _ct14_m_in = _ct14_data.shape[-1] + _ct14_m = _ct14_m_in + _ct14_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_r) + ) + _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) + if isinstance(_ct14_moduli, (int, np.integer)): + _ct14_moduli = [int(_ct14_moduli)] + ct14 = Polynomial( + { + "batch": _ct14_data.shape[0], + "num_elements": _ct14_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_m, + "precision": 32, + "degree_layout": (_ct14_r, _ct14_c), + }, + {"moduli": list(_ct14_moduli)[:_ct14_m]}, + ) + ct14.polynomial = _ct14_data.reshape( + _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in + )[..., :_ct14_m].copy() + ct14.batch = ct14.polynomial.shape[0] + ct14.num_elements = ct14.polynomial.shape[1] + ct14.num_moduli = _ct14_m + ct14.degree_layout = (_ct14_r, _ct14_c) + ct14.r = _ct14_r + ct14.c = _ct14_c + ct14.moduli = list(_ct14_moduli)[:_ct14_m] + ct14.moduli_array = jnp.array( + ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) + ) + _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 + _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] + _ct14_rhs_m = _ct14_rhs_m_in + _ct14_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct14_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct14_rhs_r) + ) + _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) + if isinstance(_ct14_rhs_moduli, (int, np.integer)): + _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] + ct14_rhs = Polynomial( + { + "batch": _ct14_rhs_data.shape[0], + "num_elements": _ct14_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct14_rhs_m, + "precision": 32, + "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), + }, + {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, + ) + ct14_rhs.polynomial = _ct14_rhs_data.reshape( + _ct14_rhs_data.shape[0], + _ct14_rhs_data.shape[1], + _ct14_rhs_r, + _ct14_rhs_c, + _ct14_rhs_m_in, + )[..., :_ct14_rhs_m].copy() + ct14_rhs.batch = ct14_rhs.polynomial.shape[0] + ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] + ct14_rhs.num_moduli = _ct14_rhs_m + ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) + ct14_rhs.r = _ct14_rhs_r + ct14_rhs.c = _ct14_rhs_c + ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] + ct14_rhs.moduli_array = jnp.array( + ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) + ) + ct14.add(ct14_rhs) + _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) + ct14.polynomial = jnp.where( + ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial + ) + _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 + _ct15_arg_m_in = _ct15_arg_data.shape[-1] + _ct15_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_arg_m_in + ) + _ct15_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_arg_r) + ) + _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) + if isinstance(_ct15_arg_moduli, (int, np.integer)): + _ct15_arg_moduli = [int(_ct15_arg_moduli)] + ct15_arg = Polynomial( + { + "batch": _ct15_arg_data.shape[0], + "num_elements": _ct15_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_arg_m, + "precision": 32, + "degree_layout": (_ct15_arg_r, _ct15_arg_c), + }, + {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, + ) + ct15_arg.polynomial = _ct15_arg_data.reshape( + _ct15_arg_data.shape[0], + _ct15_arg_data.shape[1], + _ct15_arg_r, + _ct15_arg_c, + _ct15_arg_m_in, + )[..., :_ct15_arg_m].copy() + ct15_arg.batch = ct15_arg.polynomial.shape[0] + ct15_arg.num_elements = ct15_arg.polynomial.shape[1] + ct15_arg.num_moduli = _ct15_arg_m + ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) + ct15_arg.r = _ct15_arg_r + ct15_arg.c = _ct15_arg_c + ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] + ct15_arg.moduli_array = jnp.array( + ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) + ) + ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) + _ct15_data = ( + ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw + ) + _ct15_m_in = _ct15_data.shape[-1] + _ct15_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct15_m_in + ) + _ct15_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct15_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct15_r) + ) + _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) + if isinstance(_ct15_moduli, (int, np.integer)): + _ct15_moduli = [int(_ct15_moduli)] + ct15 = Polynomial( + { + "batch": _ct15_data.shape[0], + "num_elements": _ct15_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct15_m, + "precision": 32, + "degree_layout": (_ct15_r, _ct15_c), + }, + {"moduli": list(_ct15_moduli)[:_ct15_m]}, + ) + ct15.polynomial = _ct15_data.reshape( + _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in + )[..., :_ct15_m].copy() + ct15.batch = ct15.polynomial.shape[0] + ct15.num_elements = ct15.polynomial.shape[1] + ct15.num_moduli = _ct15_m + ct15.degree_layout = (_ct15_r, _ct15_c) + ct15.r = _ct15_r + ct15.c = _ct15_c + ct15.moduli = list(_ct15_moduli)[:_ct15_m] + ct15.moduli_array = jnp.array( + ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) + ) + _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 + _ct16_m_in = _ct16_data.shape[-1] + _ct16_m = _ct16_m_in + _ct16_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_r) + ) + _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) + if isinstance(_ct16_moduli, (int, np.integer)): + _ct16_moduli = [int(_ct16_moduli)] + ct16 = Polynomial( + { + "batch": _ct16_data.shape[0], + "num_elements": _ct16_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_m, + "precision": 32, + "degree_layout": (_ct16_r, _ct16_c), + }, + {"moduli": list(_ct16_moduli)[:_ct16_m]}, + ) + ct16.polynomial = _ct16_data.reshape( + _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in + )[..., :_ct16_m].copy() + ct16.batch = ct16.polynomial.shape[0] + ct16.num_elements = ct16.polynomial.shape[1] + ct16.num_moduli = _ct16_m + ct16.degree_layout = (_ct16_r, _ct16_c) + ct16.r = _ct16_r + ct16.c = _ct16_c + ct16.moduli = list(_ct16_moduli)[:_ct16_m] + ct16.moduli_array = jnp.array( + ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) + ) + _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 + _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] + _ct16_rhs_m = _ct16_rhs_m_in + _ct16_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct16_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct16_rhs_r) + ) + _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) + if isinstance(_ct16_rhs_moduli, (int, np.integer)): + _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] + ct16_rhs = Polynomial( + { + "batch": _ct16_rhs_data.shape[0], + "num_elements": _ct16_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct16_rhs_m, + "precision": 32, + "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), + }, + {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, + ) + ct16_rhs.polynomial = _ct16_rhs_data.reshape( + _ct16_rhs_data.shape[0], + _ct16_rhs_data.shape[1], + _ct16_rhs_r, + _ct16_rhs_c, + _ct16_rhs_m_in, + )[..., :_ct16_rhs_m].copy() + ct16_rhs.batch = ct16_rhs.polynomial.shape[0] + ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] + ct16_rhs.num_moduli = _ct16_rhs_m + ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) + ct16_rhs.r = _ct16_rhs_r + ct16_rhs.c = _ct16_rhs_c + ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] + ct16_rhs.moduli_array = jnp.array( + ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) + ) + ct16.add(ct16_rhs) + _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) + ct16.polynomial = jnp.where( + ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial + ) + _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 + _ct17_m_in = _ct17_data.shape[-1] + _ct17_m = _ct17_m_in + _ct17_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_r) + ) + _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) + if isinstance(_ct17_moduli, (int, np.integer)): + _ct17_moduli = [int(_ct17_moduli)] + ct17 = Polynomial( + { + "batch": _ct17_data.shape[0], + "num_elements": _ct17_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_m, + "precision": 32, + "degree_layout": (_ct17_r, _ct17_c), + }, + {"moduli": list(_ct17_moduli)[:_ct17_m]}, + ) + ct17.polynomial = _ct17_data.reshape( + _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in + )[..., :_ct17_m].copy() + ct17.batch = ct17.polynomial.shape[0] + ct17.num_elements = ct17.polynomial.shape[1] + ct17.num_moduli = _ct17_m + ct17.degree_layout = (_ct17_r, _ct17_c) + ct17.r = _ct17_r + ct17.c = _ct17_c + ct17.moduli = list(_ct17_moduli)[:_ct17_m] + ct17.moduli_array = jnp.array( + ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) + ) + _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 + _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] + _ct17_rhs_m = _ct17_rhs_m_in + _ct17_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct17_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct17_rhs_r) + ) + _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) + if isinstance(_ct17_rhs_moduli, (int, np.integer)): + _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] + ct17_rhs = Polynomial( + { + "batch": _ct17_rhs_data.shape[0], + "num_elements": _ct17_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct17_rhs_m, + "precision": 32, + "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), + }, + {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, + ) + ct17_rhs.polynomial = _ct17_rhs_data.reshape( + _ct17_rhs_data.shape[0], + _ct17_rhs_data.shape[1], + _ct17_rhs_r, + _ct17_rhs_c, + _ct17_rhs_m_in, + )[..., :_ct17_rhs_m].copy() + ct17_rhs.batch = ct17_rhs.polynomial.shape[0] + ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] + ct17_rhs.num_moduli = _ct17_rhs_m + ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) + ct17_rhs.r = _ct17_rhs_r + ct17_rhs.c = _ct17_rhs_c + ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] + ct17_rhs.moduli_array = jnp.array( + ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) + ) + ct17.add(ct17_rhs) + _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) + ct17.polynomial = jnp.where( + ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial + ) + _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 + _ct18_m_in = _ct18_data.shape[-1] + _ct18_m = _ct18_m_in + _ct18_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_r) + ) + _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) + if isinstance(_ct18_moduli, (int, np.integer)): + _ct18_moduli = [int(_ct18_moduli)] + ct18 = Polynomial( + { + "batch": _ct18_data.shape[0], + "num_elements": _ct18_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_m, + "precision": 32, + "degree_layout": (_ct18_r, _ct18_c), + }, + {"moduli": list(_ct18_moduli)[:_ct18_m]}, + ) + ct18.polynomial = _ct18_data.reshape( + _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in + )[..., :_ct18_m].copy() + ct18.batch = ct18.polynomial.shape[0] + ct18.num_elements = ct18.polynomial.shape[1] + ct18.num_moduli = _ct18_m + ct18.degree_layout = (_ct18_r, _ct18_c) + ct18.r = _ct18_r + ct18.c = _ct18_c + ct18.moduli = list(_ct18_moduli)[:_ct18_m] + ct18.moduli_array = jnp.array( + ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) + ) + _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 + _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] + _ct18_rhs_m = _ct18_rhs_m_in + _ct18_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct18_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct18_rhs_r) + ) + _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) + if isinstance(_ct18_rhs_moduli, (int, np.integer)): + _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] + ct18_rhs = Polynomial( + { + "batch": _ct18_rhs_data.shape[0], + "num_elements": _ct18_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct18_rhs_m, + "precision": 32, + "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), + }, + {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, + ) + ct18_rhs.polynomial = _ct18_rhs_data.reshape( + _ct18_rhs_data.shape[0], + _ct18_rhs_data.shape[1], + _ct18_rhs_r, + _ct18_rhs_c, + _ct18_rhs_m_in, + )[..., :_ct18_rhs_m].copy() + ct18_rhs.batch = ct18_rhs.polynomial.shape[0] + ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] + ct18_rhs.num_moduli = _ct18_rhs_m + ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) + ct18_rhs.r = _ct18_rhs_r + ct18_rhs.c = _ct18_rhs_c + ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] + ct18_rhs.moduli_array = jnp.array( + ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) + ) + ct18.add(ct18_rhs) + _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) + ct18.polynomial = jnp.where( + ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial + ) + _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 + _ct19_m_in = _ct19_data.shape[-1] + _ct19_m = _ct19_m_in + _ct19_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_r) + ) + _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) + if isinstance(_ct19_moduli, (int, np.integer)): + _ct19_moduli = [int(_ct19_moduli)] + ct19 = Polynomial( + { + "batch": _ct19_data.shape[0], + "num_elements": _ct19_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_m, + "precision": 32, + "degree_layout": (_ct19_r, _ct19_c), + }, + {"moduli": list(_ct19_moduli)[:_ct19_m]}, + ) + ct19.polynomial = _ct19_data.reshape( + _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in + )[..., :_ct19_m].copy() + ct19.batch = ct19.polynomial.shape[0] + ct19.num_elements = ct19.polynomial.shape[1] + ct19.num_moduli = _ct19_m + ct19.degree_layout = (_ct19_r, _ct19_c) + ct19.r = _ct19_r + ct19.c = _ct19_c + ct19.moduli = list(_ct19_moduli)[:_ct19_m] + ct19.moduli_array = jnp.array( + ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) + ) + _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 + _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] + _ct19_rhs_m = _ct19_rhs_m_in + _ct19_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct19_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct19_rhs_r) + ) + _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) + if isinstance(_ct19_rhs_moduli, (int, np.integer)): + _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] + ct19_rhs = Polynomial( + { + "batch": _ct19_rhs_data.shape[0], + "num_elements": _ct19_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct19_rhs_m, + "precision": 32, + "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), + }, + {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, + ) + ct19_rhs.polynomial = _ct19_rhs_data.reshape( + _ct19_rhs_data.shape[0], + _ct19_rhs_data.shape[1], + _ct19_rhs_r, + _ct19_rhs_c, + _ct19_rhs_m_in, + )[..., :_ct19_rhs_m].copy() + ct19_rhs.batch = ct19_rhs.polynomial.shape[0] + ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] + ct19_rhs.num_moduli = _ct19_rhs_m + ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) + ct19_rhs.r = _ct19_rhs_r + ct19_rhs.c = _ct19_rhs_c + ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] + ct19_rhs.moduli_array = jnp.array( + ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) + ) + ct19.add(ct19_rhs) + _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) + ct19.polynomial = jnp.where( + ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial + ) + _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 + _ct20_arg_m_in = _ct20_arg_data.shape[-1] + _ct20_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct20_arg_m_in + ) + _ct20_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_arg_r) + ) + _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) + if isinstance(_ct20_arg_moduli, (int, np.integer)): + _ct20_arg_moduli = [int(_ct20_arg_moduli)] + ct20_arg = Polynomial( + { + "batch": _ct20_arg_data.shape[0], + "num_elements": _ct20_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_arg_m, + "precision": 32, + "degree_layout": (_ct20_arg_r, _ct20_arg_c), + }, + {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, + ) + ct20_arg.polynomial = _ct20_arg_data.reshape( + _ct20_arg_data.shape[0], + _ct20_arg_data.shape[1], + _ct20_arg_r, + _ct20_arg_c, + _ct20_arg_m_in, + )[..., :_ct20_arg_m].copy() + ct20_arg.batch = ct20_arg.polynomial.shape[0] + ct20_arg.num_elements = ct20_arg.polynomial.shape[1] + ct20_arg.num_moduli = _ct20_arg_m + ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) + ct20_arg.r = _ct20_arg_r + ct20_arg.c = _ct20_arg_c + ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] + ct20_arg.moduli_array = jnp.array( + ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) + ) + ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) + _ct20_data = ( + ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw + ) + _ct20_m_in = _ct20_data.shape[-1] + _ct20_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct20_m_in + ) + _ct20_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct20_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct20_r) + ) + _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) + if isinstance(_ct20_moduli, (int, np.integer)): + _ct20_moduli = [int(_ct20_moduli)] + ct20 = Polynomial( + { + "batch": _ct20_data.shape[0], + "num_elements": _ct20_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct20_m, + "precision": 32, + "degree_layout": (_ct20_r, _ct20_c), + }, + {"moduli": list(_ct20_moduli)[:_ct20_m]}, + ) + ct20.polynomial = _ct20_data.reshape( + _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in + )[..., :_ct20_m].copy() + ct20.batch = ct20.polynomial.shape[0] + ct20.num_elements = ct20.polynomial.shape[1] + ct20.num_moduli = _ct20_m + ct20.degree_layout = (_ct20_r, _ct20_c) + ct20.r = _ct20_r + ct20.c = _ct20_c + ct20.moduli = list(_ct20_moduli)[:_ct20_m] + ct20.moduli_array = jnp.array( + ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) + ) + _ct21_arg_data = ct20.polynomial if hasattr(ct20, "polynomial") else ct20 + _ct21_arg_m_in = _ct21_arg_data.shape[-1] + _ct21_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct21_arg_m_in + ) + _ct21_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct21_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct21_arg_r) + ) + _ct21_arg_moduli = getattr(ct20, "moduli", v0.q_towers) + if isinstance(_ct21_arg_moduli, (int, np.integer)): + _ct21_arg_moduli = [int(_ct21_arg_moduli)] + ct21_arg = Polynomial( + { + "batch": _ct21_arg_data.shape[0], + "num_elements": _ct21_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct21_arg_m, + "precision": 32, + "degree_layout": (_ct21_arg_r, _ct21_arg_c), + }, + {"moduli": list(_ct21_arg_moduli)[:_ct21_arg_m]}, + ) + ct21_arg.polynomial = _ct21_arg_data.reshape( + _ct21_arg_data.shape[0], + _ct21_arg_data.shape[1], + _ct21_arg_r, + _ct21_arg_c, + _ct21_arg_m_in, + )[..., :_ct21_arg_m].copy() + ct21_arg.batch = ct21_arg.polynomial.shape[0] + ct21_arg.num_elements = ct21_arg.polynomial.shape[1] + ct21_arg.num_moduli = _ct21_arg_m + ct21_arg.degree_layout = (_ct21_arg_r, _ct21_arg_c) + ct21_arg.r = _ct21_arg_r + ct21_arg.c = _ct21_arg_c + ct21_arg.moduli = list(_ct21_arg_moduli)[:_ct21_arg_m] + ct21_arg.moduli_array = jnp.array( + ct21_arg.moduli, dtype=getattr(ct21_arg, "modulus_dtype", jnp.uint32) + ) + ct21_pt_ntt = ( + pt8.polynomial[0, 0, :, : ct21_arg.polynomial.shape[-1]] + .reshape(ct21_arg.r, ct21_arg.c, ct21_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct21_ptct = v0.ptct_mul[v0.max_level - 1] + ct21_ptct.set_plaintext(ct21_pt_ntt) + ct21_raw = ct21_ptct.mul(ct21_arg, use_bat=False) + _ct21_data = ( + ct21_raw.polynomial if hasattr(ct21_raw, "polynomial") else ct21_raw + ) + _ct21_m_in = _ct21_data.shape[-1] + _ct21_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct21_m_in + ) + _ct21_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct21_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct21_r) + ) + _ct21_moduli = getattr(ct21_raw, "moduli", v0.q_towers) + if isinstance(_ct21_moduli, (int, np.integer)): + _ct21_moduli = [int(_ct21_moduli)] + ct21 = Polynomial( + { + "batch": _ct21_data.shape[0], + "num_elements": _ct21_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct21_m, + "precision": 32, + "degree_layout": (_ct21_r, _ct21_c), + }, + {"moduli": list(_ct21_moduli)[:_ct21_m]}, + ) + ct21.polynomial = _ct21_data.reshape( + _ct21_data.shape[0], _ct21_data.shape[1], _ct21_r, _ct21_c, _ct21_m_in + )[..., :_ct21_m].copy() + ct21.batch = ct21.polynomial.shape[0] + ct21.num_elements = ct21.polynomial.shape[1] + ct21.num_moduli = _ct21_m + ct21.degree_layout = (_ct21_r, _ct21_c) + ct21.r = _ct21_r + ct21.c = _ct21_c + ct21.moduli = list(_ct21_moduli)[:_ct21_m] + ct21.moduli_array = jnp.array( + ct21.moduli, dtype=getattr(ct21, "modulus_dtype", jnp.uint32) + ) + _ct22_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 + _ct22_arg_m_in = _ct22_arg_data.shape[-1] + _ct22_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct22_arg_m_in + ) + _ct22_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct22_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct22_arg_r) + ) + _ct22_arg_moduli = getattr(ct19, "moduli", v0.q_towers) + if isinstance(_ct22_arg_moduli, (int, np.integer)): + _ct22_arg_moduli = [int(_ct22_arg_moduli)] + ct22_arg = Polynomial( + { + "batch": _ct22_arg_data.shape[0], + "num_elements": _ct22_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct22_arg_m, + "precision": 32, + "degree_layout": (_ct22_arg_r, _ct22_arg_c), + }, + {"moduli": list(_ct22_arg_moduli)[:_ct22_arg_m]}, + ) + ct22_arg.polynomial = _ct22_arg_data.reshape( + _ct22_arg_data.shape[0], + _ct22_arg_data.shape[1], + _ct22_arg_r, + _ct22_arg_c, + _ct22_arg_m_in, + )[..., :_ct22_arg_m].copy() + ct22_arg.batch = ct22_arg.polynomial.shape[0] + ct22_arg.num_elements = ct22_arg.polynomial.shape[1] + ct22_arg.num_moduli = _ct22_arg_m + ct22_arg.degree_layout = (_ct22_arg_r, _ct22_arg_c) + ct22_arg.r = _ct22_arg_r + ct22_arg.c = _ct22_arg_c + ct22_arg.moduli = list(_ct22_arg_moduli)[:_ct22_arg_m] + ct22_arg.moduli_array = jnp.array( + ct22_arg.moduli, dtype=getattr(ct22_arg, "modulus_dtype", jnp.uint32) + ) + ct22_raw = v0.he_rot[v0.max_level, 1].rotate(ct22_arg) + _ct22_data = ( + ct22_raw.polynomial if hasattr(ct22_raw, "polynomial") else ct22_raw + ) + _ct22_m_in = _ct22_data.shape[-1] + _ct22_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct22_m_in + ) + _ct22_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct22_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct22_r) + ) + _ct22_moduli = getattr(ct22_raw, "moduli", v0.q_towers) + if isinstance(_ct22_moduli, (int, np.integer)): + _ct22_moduli = [int(_ct22_moduli)] + ct22 = Polynomial( + { + "batch": _ct22_data.shape[0], + "num_elements": _ct22_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct22_m, + "precision": 32, + "degree_layout": (_ct22_r, _ct22_c), + }, + {"moduli": list(_ct22_moduli)[:_ct22_m]}, + ) + ct22.polynomial = _ct22_data.reshape( + _ct22_data.shape[0], _ct22_data.shape[1], _ct22_r, _ct22_c, _ct22_m_in + )[..., :_ct22_m].copy() + ct22.batch = ct22.polynomial.shape[0] + ct22.num_elements = ct22.polynomial.shape[1] + ct22.num_moduli = _ct22_m + ct22.degree_layout = (_ct22_r, _ct22_c) + ct22.r = _ct22_r + ct22.c = _ct22_c + ct22.moduli = list(_ct22_moduli)[:_ct22_m] + ct22.moduli_array = jnp.array( + ct22.moduli, dtype=getattr(ct22, "modulus_dtype", jnp.uint32) + ) + _ct23_arg_data = ct22.polynomial if hasattr(ct22, "polynomial") else ct22 + _ct23_arg_m_in = _ct23_arg_data.shape[-1] + _ct23_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct23_arg_m_in + ) + _ct23_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct23_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct23_arg_r) + ) + _ct23_arg_moduli = getattr(ct22, "moduli", v0.q_towers) + if isinstance(_ct23_arg_moduli, (int, np.integer)): + _ct23_arg_moduli = [int(_ct23_arg_moduli)] + ct23_arg = Polynomial( + { + "batch": _ct23_arg_data.shape[0], + "num_elements": _ct23_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct23_arg_m, + "precision": 32, + "degree_layout": (_ct23_arg_r, _ct23_arg_c), + }, + {"moduli": list(_ct23_arg_moduli)[:_ct23_arg_m]}, + ) + ct23_arg.polynomial = _ct23_arg_data.reshape( + _ct23_arg_data.shape[0], + _ct23_arg_data.shape[1], + _ct23_arg_r, + _ct23_arg_c, + _ct23_arg_m_in, + )[..., :_ct23_arg_m].copy() + ct23_arg.batch = ct23_arg.polynomial.shape[0] + ct23_arg.num_elements = ct23_arg.polynomial.shape[1] + ct23_arg.num_moduli = _ct23_arg_m + ct23_arg.degree_layout = (_ct23_arg_r, _ct23_arg_c) + ct23_arg.r = _ct23_arg_r + ct23_arg.c = _ct23_arg_c + ct23_arg.moduli = list(_ct23_arg_moduli)[:_ct23_arg_m] + ct23_arg.moduli_array = jnp.array( + ct23_arg.moduli, dtype=getattr(ct23_arg, "modulus_dtype", jnp.uint32) + ) + ct23_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct23_arg) + _ct23_data = ( + ct23_raw.polynomial if hasattr(ct23_raw, "polynomial") else ct23_raw + ) + _ct23_m_in = _ct23_data.shape[-1] + _ct23_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct23_m_in + ) + _ct23_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct23_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct23_r) + ) + _ct23_moduli = getattr(ct23_raw, "moduli", v0.q_towers) + if isinstance(_ct23_moduli, (int, np.integer)): + _ct23_moduli = [int(_ct23_moduli)] + ct23 = Polynomial( + { + "batch": _ct23_data.shape[0], + "num_elements": _ct23_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct23_m, + "precision": 32, + "degree_layout": (_ct23_r, _ct23_c), + }, + {"moduli": list(_ct23_moduli)[:_ct23_m]}, + ) + ct23.polynomial = _ct23_data.reshape( + _ct23_data.shape[0], _ct23_data.shape[1], _ct23_r, _ct23_c, _ct23_m_in + )[..., :_ct23_m].copy() + ct23.batch = ct23.polynomial.shape[0] + ct23.num_elements = ct23.polynomial.shape[1] + ct23.num_moduli = _ct23_m + ct23.degree_layout = (_ct23_r, _ct23_c) + ct23.r = _ct23_r + ct23.c = _ct23_c + ct23.moduli = list(_ct23_moduli)[:_ct23_m] + ct23.moduli_array = jnp.array( + ct23.moduli, dtype=getattr(ct23, "modulus_dtype", jnp.uint32) + ) + _ct24_arg_data = ct23.polynomial if hasattr(ct23, "polynomial") else ct23 + _ct24_arg_m_in = _ct24_arg_data.shape[-1] + _ct24_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct24_arg_m_in + ) + _ct24_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct24_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct24_arg_r) + ) + _ct24_arg_moduli = getattr(ct23, "moduli", v0.q_towers) + if isinstance(_ct24_arg_moduli, (int, np.integer)): + _ct24_arg_moduli = [int(_ct24_arg_moduli)] + ct24_arg = Polynomial( + { + "batch": _ct24_arg_data.shape[0], + "num_elements": _ct24_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct24_arg_m, + "precision": 32, + "degree_layout": (_ct24_arg_r, _ct24_arg_c), + }, + {"moduli": list(_ct24_arg_moduli)[:_ct24_arg_m]}, + ) + ct24_arg.polynomial = _ct24_arg_data.reshape( + _ct24_arg_data.shape[0], + _ct24_arg_data.shape[1], + _ct24_arg_r, + _ct24_arg_c, + _ct24_arg_m_in, + )[..., :_ct24_arg_m].copy() + ct24_arg.batch = ct24_arg.polynomial.shape[0] + ct24_arg.num_elements = ct24_arg.polynomial.shape[1] + ct24_arg.num_moduli = _ct24_arg_m + ct24_arg.degree_layout = (_ct24_arg_r, _ct24_arg_c) + ct24_arg.r = _ct24_arg_r + ct24_arg.c = _ct24_arg_c + ct24_arg.moduli = list(_ct24_arg_moduli)[:_ct24_arg_m] + ct24_arg.moduli_array = jnp.array( + ct24_arg.moduli, dtype=getattr(ct24_arg, "modulus_dtype", jnp.uint32) + ) + ct24_pt_ntt = ( + pt9.polynomial[0, 0, :, : ct24_arg.polynomial.shape[-1]] + .reshape(ct24_arg.r, ct24_arg.c, ct24_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct24_ptct = v0.ptct_mul[v0.max_level - 1] + ct24_ptct.set_plaintext(ct24_pt_ntt) + ct24_raw = ct24_ptct.mul(ct24_arg, use_bat=False) + _ct24_data = ( + ct24_raw.polynomial if hasattr(ct24_raw, "polynomial") else ct24_raw + ) + _ct24_m_in = _ct24_data.shape[-1] + _ct24_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct24_m_in + ) + _ct24_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct24_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct24_r) + ) + _ct24_moduli = getattr(ct24_raw, "moduli", v0.q_towers) + if isinstance(_ct24_moduli, (int, np.integer)): + _ct24_moduli = [int(_ct24_moduli)] + ct24 = Polynomial( + { + "batch": _ct24_data.shape[0], + "num_elements": _ct24_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct24_m, + "precision": 32, + "degree_layout": (_ct24_r, _ct24_c), + }, + {"moduli": list(_ct24_moduli)[:_ct24_m]}, + ) + ct24.polynomial = _ct24_data.reshape( + _ct24_data.shape[0], _ct24_data.shape[1], _ct24_r, _ct24_c, _ct24_m_in + )[..., :_ct24_m].copy() + ct24.batch = ct24.polynomial.shape[0] + ct24.num_elements = ct24.polynomial.shape[1] + ct24.num_moduli = _ct24_m + ct24.degree_layout = (_ct24_r, _ct24_c) + ct24.r = _ct24_r + ct24.c = _ct24_c + ct24.moduli = list(_ct24_moduli)[:_ct24_m] + ct24.moduli_array = jnp.array( + ct24.moduli, dtype=getattr(ct24, "modulus_dtype", jnp.uint32) + ) + _ct25_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 + _ct25_arg_m_in = _ct25_arg_data.shape[-1] + _ct25_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct25_arg_m_in + ) + _ct25_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct25_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct25_arg_r) + ) + _ct25_arg_moduli = getattr(ct19, "moduli", v0.q_towers) + if isinstance(_ct25_arg_moduli, (int, np.integer)): + _ct25_arg_moduli = [int(_ct25_arg_moduli)] + ct25_arg = Polynomial( + { + "batch": _ct25_arg_data.shape[0], + "num_elements": _ct25_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct25_arg_m, + "precision": 32, + "degree_layout": (_ct25_arg_r, _ct25_arg_c), + }, + {"moduli": list(_ct25_arg_moduli)[:_ct25_arg_m]}, + ) + ct25_arg.polynomial = _ct25_arg_data.reshape( + _ct25_arg_data.shape[0], + _ct25_arg_data.shape[1], + _ct25_arg_r, + _ct25_arg_c, + _ct25_arg_m_in, + )[..., :_ct25_arg_m].copy() + ct25_arg.batch = ct25_arg.polynomial.shape[0] + ct25_arg.num_elements = ct25_arg.polynomial.shape[1] + ct25_arg.num_moduli = _ct25_arg_m + ct25_arg.degree_layout = (_ct25_arg_r, _ct25_arg_c) + ct25_arg.r = _ct25_arg_r + ct25_arg.c = _ct25_arg_c + ct25_arg.moduli = list(_ct25_arg_moduli)[:_ct25_arg_m] + ct25_arg.moduli_array = jnp.array( + ct25_arg.moduli, dtype=getattr(ct25_arg, "modulus_dtype", jnp.uint32) + ) + ct25_raw = v0.he_rot[v0.max_level, 2].rotate(ct25_arg) + _ct25_data = ( + ct25_raw.polynomial if hasattr(ct25_raw, "polynomial") else ct25_raw + ) + _ct25_m_in = _ct25_data.shape[-1] + _ct25_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct25_m_in + ) + _ct25_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct25_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct25_r) + ) + _ct25_moduli = getattr(ct25_raw, "moduli", v0.q_towers) + if isinstance(_ct25_moduli, (int, np.integer)): + _ct25_moduli = [int(_ct25_moduli)] + ct25 = Polynomial( + { + "batch": _ct25_data.shape[0], + "num_elements": _ct25_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct25_m, + "precision": 32, + "degree_layout": (_ct25_r, _ct25_c), + }, + {"moduli": list(_ct25_moduli)[:_ct25_m]}, + ) + ct25.polynomial = _ct25_data.reshape( + _ct25_data.shape[0], _ct25_data.shape[1], _ct25_r, _ct25_c, _ct25_m_in + )[..., :_ct25_m].copy() + ct25.batch = ct25.polynomial.shape[0] + ct25.num_elements = ct25.polynomial.shape[1] + ct25.num_moduli = _ct25_m + ct25.degree_layout = (_ct25_r, _ct25_c) + ct25.r = _ct25_r + ct25.c = _ct25_c + ct25.moduli = list(_ct25_moduli)[:_ct25_m] + ct25.moduli_array = jnp.array( + ct25.moduli, dtype=getattr(ct25, "modulus_dtype", jnp.uint32) + ) + _ct26_arg_data = ct25.polynomial if hasattr(ct25, "polynomial") else ct25 + _ct26_arg_m_in = _ct26_arg_data.shape[-1] + _ct26_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level) + if hasattr(v0, "_param_cache") + else _ct26_arg_m_in + ) + _ct26_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct26_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct26_arg_r) + ) + _ct26_arg_moduli = getattr(ct25, "moduli", v0.q_towers) + if isinstance(_ct26_arg_moduli, (int, np.integer)): + _ct26_arg_moduli = [int(_ct26_arg_moduli)] + ct26_arg = Polynomial( + { + "batch": _ct26_arg_data.shape[0], + "num_elements": _ct26_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct26_arg_m, + "precision": 32, + "degree_layout": (_ct26_arg_r, _ct26_arg_c), + }, + {"moduli": list(_ct26_arg_moduli)[:_ct26_arg_m]}, + ) + ct26_arg.polynomial = _ct26_arg_data.reshape( + _ct26_arg_data.shape[0], + _ct26_arg_data.shape[1], + _ct26_arg_r, + _ct26_arg_c, + _ct26_arg_m_in, + )[..., :_ct26_arg_m].copy() + ct26_arg.batch = ct26_arg.polynomial.shape[0] + ct26_arg.num_elements = ct26_arg.polynomial.shape[1] + ct26_arg.num_moduli = _ct26_arg_m + ct26_arg.degree_layout = (_ct26_arg_r, _ct26_arg_c) + ct26_arg.r = _ct26_arg_r + ct26_arg.c = _ct26_arg_c + ct26_arg.moduli = list(_ct26_arg_moduli)[:_ct26_arg_m] + ct26_arg.moduli_array = jnp.array( + ct26_arg.moduli, dtype=getattr(ct26_arg, "modulus_dtype", jnp.uint32) + ) + ct26_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct26_arg) + _ct26_data = ( + ct26_raw.polynomial if hasattr(ct26_raw, "polynomial") else ct26_raw + ) + _ct26_m_in = _ct26_data.shape[-1] + _ct26_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct26_m_in + ) + _ct26_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct26_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct26_r) + ) + _ct26_moduli = getattr(ct26_raw, "moduli", v0.q_towers) + if isinstance(_ct26_moduli, (int, np.integer)): + _ct26_moduli = [int(_ct26_moduli)] + ct26 = Polynomial( + { + "batch": _ct26_data.shape[0], + "num_elements": _ct26_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct26_m, + "precision": 32, + "degree_layout": (_ct26_r, _ct26_c), + }, + {"moduli": list(_ct26_moduli)[:_ct26_m]}, + ) + ct26.polynomial = _ct26_data.reshape( + _ct26_data.shape[0], _ct26_data.shape[1], _ct26_r, _ct26_c, _ct26_m_in + )[..., :_ct26_m].copy() + ct26.batch = ct26.polynomial.shape[0] + ct26.num_elements = ct26.polynomial.shape[1] + ct26.num_moduli = _ct26_m + ct26.degree_layout = (_ct26_r, _ct26_c) + ct26.r = _ct26_r + ct26.c = _ct26_c + ct26.moduli = list(_ct26_moduli)[:_ct26_m] + ct26.moduli_array = jnp.array( + ct26.moduli, dtype=getattr(ct26, "modulus_dtype", jnp.uint32) + ) + _ct27_arg_data = ct26.polynomial if hasattr(ct26, "polynomial") else ct26 + _ct27_arg_m_in = _ct27_arg_data.shape[-1] + _ct27_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct27_arg_m_in + ) + _ct27_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct27_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct27_arg_r) + ) + _ct27_arg_moduli = getattr(ct26, "moduli", v0.q_towers) + if isinstance(_ct27_arg_moduli, (int, np.integer)): + _ct27_arg_moduli = [int(_ct27_arg_moduli)] + ct27_arg = Polynomial( + { + "batch": _ct27_arg_data.shape[0], + "num_elements": _ct27_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct27_arg_m, + "precision": 32, + "degree_layout": (_ct27_arg_r, _ct27_arg_c), + }, + {"moduli": list(_ct27_arg_moduli)[:_ct27_arg_m]}, + ) + ct27_arg.polynomial = _ct27_arg_data.reshape( + _ct27_arg_data.shape[0], + _ct27_arg_data.shape[1], + _ct27_arg_r, + _ct27_arg_c, + _ct27_arg_m_in, + )[..., :_ct27_arg_m].copy() + ct27_arg.batch = ct27_arg.polynomial.shape[0] + ct27_arg.num_elements = ct27_arg.polynomial.shape[1] + ct27_arg.num_moduli = _ct27_arg_m + ct27_arg.degree_layout = (_ct27_arg_r, _ct27_arg_c) + ct27_arg.r = _ct27_arg_r + ct27_arg.c = _ct27_arg_c + ct27_arg.moduli = list(_ct27_arg_moduli)[:_ct27_arg_m] + ct27_arg.moduli_array = jnp.array( + ct27_arg.moduli, dtype=getattr(ct27_arg, "modulus_dtype", jnp.uint32) + ) + ct27_pt_ntt = ( + pt10.polynomial[0, 0, :, : ct27_arg.polynomial.shape[-1]] + .reshape(ct27_arg.r, ct27_arg.c, ct27_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct27_ptct = v0.ptct_mul[v0.max_level - 1] + ct27_ptct.set_plaintext(ct27_pt_ntt) + ct27_raw = ct27_ptct.mul(ct27_arg, use_bat=False) + _ct27_data = ( + ct27_raw.polynomial if hasattr(ct27_raw, "polynomial") else ct27_raw + ) + _ct27_m_in = _ct27_data.shape[-1] + _ct27_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct27_m_in + ) + _ct27_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct27_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct27_r) + ) + _ct27_moduli = getattr(ct27_raw, "moduli", v0.q_towers) + if isinstance(_ct27_moduli, (int, np.integer)): + _ct27_moduli = [int(_ct27_moduli)] + ct27 = Polynomial( + { + "batch": _ct27_data.shape[0], + "num_elements": _ct27_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct27_m, + "precision": 32, + "degree_layout": (_ct27_r, _ct27_c), + }, + {"moduli": list(_ct27_moduli)[:_ct27_m]}, + ) + ct27.polynomial = _ct27_data.reshape( + _ct27_data.shape[0], _ct27_data.shape[1], _ct27_r, _ct27_c, _ct27_m_in + )[..., :_ct27_m].copy() + ct27.batch = ct27.polynomial.shape[0] + ct27.num_elements = ct27.polynomial.shape[1] + ct27.num_moduli = _ct27_m + ct27.degree_layout = (_ct27_r, _ct27_c) + ct27.r = _ct27_r + ct27.c = _ct27_c + ct27.moduli = list(_ct27_moduli)[:_ct27_m] + ct27.moduli_array = jnp.array( + ct27.moduli, dtype=getattr(ct27, "modulus_dtype", jnp.uint32) + ) + _ct28_arg_data = ct20.polynomial if hasattr(ct20, "polynomial") else ct20 + _ct28_arg_m_in = _ct28_arg_data.shape[-1] + _ct28_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct28_arg_m_in + ) + _ct28_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct28_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct28_arg_r) + ) + _ct28_arg_moduli = getattr(ct20, "moduli", v0.q_towers) + if isinstance(_ct28_arg_moduli, (int, np.integer)): + _ct28_arg_moduli = [int(_ct28_arg_moduli)] + ct28_arg = Polynomial( + { + "batch": _ct28_arg_data.shape[0], + "num_elements": _ct28_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct28_arg_m, + "precision": 32, + "degree_layout": (_ct28_arg_r, _ct28_arg_c), + }, + {"moduli": list(_ct28_arg_moduli)[:_ct28_arg_m]}, + ) + ct28_arg.polynomial = _ct28_arg_data.reshape( + _ct28_arg_data.shape[0], + _ct28_arg_data.shape[1], + _ct28_arg_r, + _ct28_arg_c, + _ct28_arg_m_in, + )[..., :_ct28_arg_m].copy() + ct28_arg.batch = ct28_arg.polynomial.shape[0] + ct28_arg.num_elements = ct28_arg.polynomial.shape[1] + ct28_arg.num_moduli = _ct28_arg_m + ct28_arg.degree_layout = (_ct28_arg_r, _ct28_arg_c) + ct28_arg.r = _ct28_arg_r + ct28_arg.c = _ct28_arg_c + ct28_arg.moduli = list(_ct28_arg_moduli)[:_ct28_arg_m] + ct28_arg.moduli_array = jnp.array( + ct28_arg.moduli, dtype=getattr(ct28_arg, "modulus_dtype", jnp.uint32) + ) + ct28_pt_ntt = ( + pt11.polynomial[0, 0, :, : ct28_arg.polynomial.shape[-1]] + .reshape(ct28_arg.r, ct28_arg.c, ct28_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct28_ptct = v0.ptct_mul[v0.max_level - 1] + ct28_ptct.set_plaintext(ct28_pt_ntt) + ct28_raw = ct28_ptct.mul(ct28_arg, use_bat=False) + _ct28_data = ( + ct28_raw.polynomial if hasattr(ct28_raw, "polynomial") else ct28_raw + ) + _ct28_m_in = _ct28_data.shape[-1] + _ct28_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct28_m_in + ) + _ct28_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct28_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct28_r) + ) + _ct28_moduli = getattr(ct28_raw, "moduli", v0.q_towers) + if isinstance(_ct28_moduli, (int, np.integer)): + _ct28_moduli = [int(_ct28_moduli)] + ct28 = Polynomial( + { + "batch": _ct28_data.shape[0], + "num_elements": _ct28_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct28_m, + "precision": 32, + "degree_layout": (_ct28_r, _ct28_c), + }, + {"moduli": list(_ct28_moduli)[:_ct28_m]}, + ) + ct28.polynomial = _ct28_data.reshape( + _ct28_data.shape[0], _ct28_data.shape[1], _ct28_r, _ct28_c, _ct28_m_in + )[..., :_ct28_m].copy() + ct28.batch = ct28.polynomial.shape[0] + ct28.num_elements = ct28.polynomial.shape[1] + ct28.num_moduli = _ct28_m + ct28.degree_layout = (_ct28_r, _ct28_c) + ct28.r = _ct28_r + ct28.c = _ct28_c + ct28.moduli = list(_ct28_moduli)[:_ct28_m] + ct28.moduli_array = jnp.array( + ct28.moduli, dtype=getattr(ct28, "modulus_dtype", jnp.uint32) + ) + _ct29_arg_data = ct23.polynomial if hasattr(ct23, "polynomial") else ct23 + _ct29_arg_m_in = _ct29_arg_data.shape[-1] + _ct29_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct29_arg_m_in + ) + _ct29_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct29_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct29_arg_r) + ) + _ct29_arg_moduli = getattr(ct23, "moduli", v0.q_towers) + if isinstance(_ct29_arg_moduli, (int, np.integer)): + _ct29_arg_moduli = [int(_ct29_arg_moduli)] + ct29_arg = Polynomial( + { + "batch": _ct29_arg_data.shape[0], + "num_elements": _ct29_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct29_arg_m, + "precision": 32, + "degree_layout": (_ct29_arg_r, _ct29_arg_c), + }, + {"moduli": list(_ct29_arg_moduli)[:_ct29_arg_m]}, + ) + ct29_arg.polynomial = _ct29_arg_data.reshape( + _ct29_arg_data.shape[0], + _ct29_arg_data.shape[1], + _ct29_arg_r, + _ct29_arg_c, + _ct29_arg_m_in, + )[..., :_ct29_arg_m].copy() + ct29_arg.batch = ct29_arg.polynomial.shape[0] + ct29_arg.num_elements = ct29_arg.polynomial.shape[1] + ct29_arg.num_moduli = _ct29_arg_m + ct29_arg.degree_layout = (_ct29_arg_r, _ct29_arg_c) + ct29_arg.r = _ct29_arg_r + ct29_arg.c = _ct29_arg_c + ct29_arg.moduli = list(_ct29_arg_moduli)[:_ct29_arg_m] + ct29_arg.moduli_array = jnp.array( + ct29_arg.moduli, dtype=getattr(ct29_arg, "modulus_dtype", jnp.uint32) + ) + ct29_pt_ntt = ( + pt12.polynomial[0, 0, :, : ct29_arg.polynomial.shape[-1]] + .reshape(ct29_arg.r, ct29_arg.c, ct29_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct29_ptct = v0.ptct_mul[v0.max_level - 1] + ct29_ptct.set_plaintext(ct29_pt_ntt) + ct29_raw = ct29_ptct.mul(ct29_arg, use_bat=False) + _ct29_data = ( + ct29_raw.polynomial if hasattr(ct29_raw, "polynomial") else ct29_raw + ) + _ct29_m_in = _ct29_data.shape[-1] + _ct29_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct29_m_in + ) + _ct29_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct29_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct29_r) + ) + _ct29_moduli = getattr(ct29_raw, "moduli", v0.q_towers) + if isinstance(_ct29_moduli, (int, np.integer)): + _ct29_moduli = [int(_ct29_moduli)] + ct29 = Polynomial( + { + "batch": _ct29_data.shape[0], + "num_elements": _ct29_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct29_m, + "precision": 32, + "degree_layout": (_ct29_r, _ct29_c), + }, + {"moduli": list(_ct29_moduli)[:_ct29_m]}, + ) + ct29.polynomial = _ct29_data.reshape( + _ct29_data.shape[0], _ct29_data.shape[1], _ct29_r, _ct29_c, _ct29_m_in + )[..., :_ct29_m].copy() + ct29.batch = ct29.polynomial.shape[0] + ct29.num_elements = ct29.polynomial.shape[1] + ct29.num_moduli = _ct29_m + ct29.degree_layout = (_ct29_r, _ct29_c) + ct29.r = _ct29_r + ct29.c = _ct29_c + ct29.moduli = list(_ct29_moduli)[:_ct29_m] + ct29.moduli_array = jnp.array( + ct29.moduli, dtype=getattr(ct29, "modulus_dtype", jnp.uint32) + ) + _ct30_arg_data = ct26.polynomial if hasattr(ct26, "polynomial") else ct26 + _ct30_arg_m_in = _ct30_arg_data.shape[-1] + _ct30_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct30_arg_m_in + ) + _ct30_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct30_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct30_arg_r) + ) + _ct30_arg_moduli = getattr(ct26, "moduli", v0.q_towers) + if isinstance(_ct30_arg_moduli, (int, np.integer)): + _ct30_arg_moduli = [int(_ct30_arg_moduli)] + ct30_arg = Polynomial( + { + "batch": _ct30_arg_data.shape[0], + "num_elements": _ct30_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct30_arg_m, + "precision": 32, + "degree_layout": (_ct30_arg_r, _ct30_arg_c), + }, + {"moduli": list(_ct30_arg_moduli)[:_ct30_arg_m]}, + ) + ct30_arg.polynomial = _ct30_arg_data.reshape( + _ct30_arg_data.shape[0], + _ct30_arg_data.shape[1], + _ct30_arg_r, + _ct30_arg_c, + _ct30_arg_m_in, + )[..., :_ct30_arg_m].copy() + ct30_arg.batch = ct30_arg.polynomial.shape[0] + ct30_arg.num_elements = ct30_arg.polynomial.shape[1] + ct30_arg.num_moduli = _ct30_arg_m + ct30_arg.degree_layout = (_ct30_arg_r, _ct30_arg_c) + ct30_arg.r = _ct30_arg_r + ct30_arg.c = _ct30_arg_c + ct30_arg.moduli = list(_ct30_arg_moduli)[:_ct30_arg_m] + ct30_arg.moduli_array = jnp.array( + ct30_arg.moduli, dtype=getattr(ct30_arg, "modulus_dtype", jnp.uint32) + ) + ct30_pt_ntt = ( + pt13.polynomial[0, 0, :, : ct30_arg.polynomial.shape[-1]] + .reshape(ct30_arg.r, ct30_arg.c, ct30_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct30_ptct = v0.ptct_mul[v0.max_level - 1] + ct30_ptct.set_plaintext(ct30_pt_ntt) + ct30_raw = ct30_ptct.mul(ct30_arg, use_bat=False) + _ct30_data = ( + ct30_raw.polynomial if hasattr(ct30_raw, "polynomial") else ct30_raw + ) + _ct30_m_in = _ct30_data.shape[-1] + _ct30_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct30_m_in + ) + _ct30_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct30_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct30_r) + ) + _ct30_moduli = getattr(ct30_raw, "moduli", v0.q_towers) + if isinstance(_ct30_moduli, (int, np.integer)): + _ct30_moduli = [int(_ct30_moduli)] + ct30 = Polynomial( + { + "batch": _ct30_data.shape[0], + "num_elements": _ct30_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct30_m, + "precision": 32, + "degree_layout": (_ct30_r, _ct30_c), + }, + {"moduli": list(_ct30_moduli)[:_ct30_m]}, + ) + ct30.polynomial = _ct30_data.reshape( + _ct30_data.shape[0], _ct30_data.shape[1], _ct30_r, _ct30_c, _ct30_m_in + )[..., :_ct30_m].copy() + ct30.batch = ct30.polynomial.shape[0] + ct30.num_elements = ct30.polynomial.shape[1] + ct30.num_moduli = _ct30_m + ct30.degree_layout = (_ct30_r, _ct30_c) + ct30.r = _ct30_r + ct30.c = _ct30_c + ct30.moduli = list(_ct30_moduli)[:_ct30_m] + ct30.moduli_array = jnp.array( + ct30.moduli, dtype=getattr(ct30, "modulus_dtype", jnp.uint32) + ) + _ct31_data = ct28.polynomial if hasattr(ct28, "polynomial") else ct28 + _ct31_m_in = _ct31_data.shape[-1] + _ct31_m = _ct31_m_in + _ct31_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct31_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct31_r) + ) + _ct31_moduli = getattr(ct28, "moduli", v0.q_towers) + if isinstance(_ct31_moduli, (int, np.integer)): + _ct31_moduli = [int(_ct31_moduli)] + ct31 = Polynomial( + { + "batch": _ct31_data.shape[0], + "num_elements": _ct31_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct31_m, + "precision": 32, + "degree_layout": (_ct31_r, _ct31_c), + }, + {"moduli": list(_ct31_moduli)[:_ct31_m]}, + ) + ct31.polynomial = _ct31_data.reshape( + _ct31_data.shape[0], _ct31_data.shape[1], _ct31_r, _ct31_c, _ct31_m_in + )[..., :_ct31_m].copy() + ct31.batch = ct31.polynomial.shape[0] + ct31.num_elements = ct31.polynomial.shape[1] + ct31.num_moduli = _ct31_m + ct31.degree_layout = (_ct31_r, _ct31_c) + ct31.r = _ct31_r + ct31.c = _ct31_c + ct31.moduli = list(_ct31_moduli)[:_ct31_m] + ct31.moduli_array = jnp.array( + ct31.moduli, dtype=getattr(ct31, "modulus_dtype", jnp.uint32) + ) + _ct31_rhs_data = ct29.polynomial if hasattr(ct29, "polynomial") else ct29 + _ct31_rhs_m_in = _ct31_rhs_data.shape[-1] + _ct31_rhs_m = _ct31_rhs_m_in + _ct31_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct31_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct31_rhs_r) + ) + _ct31_rhs_moduli = getattr(ct29, "moduli", v0.q_towers) + if isinstance(_ct31_rhs_moduli, (int, np.integer)): + _ct31_rhs_moduli = [int(_ct31_rhs_moduli)] + ct31_rhs = Polynomial( + { + "batch": _ct31_rhs_data.shape[0], + "num_elements": _ct31_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct31_rhs_m, + "precision": 32, + "degree_layout": (_ct31_rhs_r, _ct31_rhs_c), + }, + {"moduli": list(_ct31_rhs_moduli)[:_ct31_rhs_m]}, + ) + ct31_rhs.polynomial = _ct31_rhs_data.reshape( + _ct31_rhs_data.shape[0], + _ct31_rhs_data.shape[1], + _ct31_rhs_r, + _ct31_rhs_c, + _ct31_rhs_m_in, + )[..., :_ct31_rhs_m].copy() + ct31_rhs.batch = ct31_rhs.polynomial.shape[0] + ct31_rhs.num_elements = ct31_rhs.polynomial.shape[1] + ct31_rhs.num_moduli = _ct31_rhs_m + ct31_rhs.degree_layout = (_ct31_rhs_r, _ct31_rhs_c) + ct31_rhs.r = _ct31_rhs_r + ct31_rhs.c = _ct31_rhs_c + ct31_rhs.moduli = list(_ct31_rhs_moduli)[:_ct31_rhs_m] + ct31_rhs.moduli_array = jnp.array( + ct31_rhs.moduli, dtype=getattr(ct31_rhs, "modulus_dtype", jnp.uint32) + ) + ct31.add(ct31_rhs) + _moduli = jnp.array(ct31.moduli, dtype=jnp.uint32) + ct31.polynomial = jnp.where( + ct31.polynomial >= _moduli, ct31.polynomial - _moduli, ct31.polynomial + ) + _ct32_data = ct31.polynomial if hasattr(ct31, "polynomial") else ct31 + _ct32_m_in = _ct32_data.shape[-1] + _ct32_m = _ct32_m_in + _ct32_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct32_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct32_r) + ) + _ct32_moduli = getattr(ct31, "moduli", v0.q_towers) + if isinstance(_ct32_moduli, (int, np.integer)): + _ct32_moduli = [int(_ct32_moduli)] + ct32 = Polynomial( + { + "batch": _ct32_data.shape[0], + "num_elements": _ct32_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct32_m, + "precision": 32, + "degree_layout": (_ct32_r, _ct32_c), + }, + {"moduli": list(_ct32_moduli)[:_ct32_m]}, + ) + ct32.polynomial = _ct32_data.reshape( + _ct32_data.shape[0], _ct32_data.shape[1], _ct32_r, _ct32_c, _ct32_m_in + )[..., :_ct32_m].copy() + ct32.batch = ct32.polynomial.shape[0] + ct32.num_elements = ct32.polynomial.shape[1] + ct32.num_moduli = _ct32_m + ct32.degree_layout = (_ct32_r, _ct32_c) + ct32.r = _ct32_r + ct32.c = _ct32_c + ct32.moduli = list(_ct32_moduli)[:_ct32_m] + ct32.moduli_array = jnp.array( + ct32.moduli, dtype=getattr(ct32, "modulus_dtype", jnp.uint32) + ) + _ct32_rhs_data = ct30.polynomial if hasattr(ct30, "polynomial") else ct30 + _ct32_rhs_m_in = _ct32_rhs_data.shape[-1] + _ct32_rhs_m = _ct32_rhs_m_in + _ct32_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct32_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct32_rhs_r) + ) + _ct32_rhs_moduli = getattr(ct30, "moduli", v0.q_towers) + if isinstance(_ct32_rhs_moduli, (int, np.integer)): + _ct32_rhs_moduli = [int(_ct32_rhs_moduli)] + ct32_rhs = Polynomial( + { + "batch": _ct32_rhs_data.shape[0], + "num_elements": _ct32_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct32_rhs_m, + "precision": 32, + "degree_layout": (_ct32_rhs_r, _ct32_rhs_c), + }, + {"moduli": list(_ct32_rhs_moduli)[:_ct32_rhs_m]}, + ) + ct32_rhs.polynomial = _ct32_rhs_data.reshape( + _ct32_rhs_data.shape[0], + _ct32_rhs_data.shape[1], + _ct32_rhs_r, + _ct32_rhs_c, + _ct32_rhs_m_in, + )[..., :_ct32_rhs_m].copy() + ct32_rhs.batch = ct32_rhs.polynomial.shape[0] + ct32_rhs.num_elements = ct32_rhs.polynomial.shape[1] + ct32_rhs.num_moduli = _ct32_rhs_m + ct32_rhs.degree_layout = (_ct32_rhs_r, _ct32_rhs_c) + ct32_rhs.r = _ct32_rhs_r + ct32_rhs.c = _ct32_rhs_c + ct32_rhs.moduli = list(_ct32_rhs_moduli)[:_ct32_rhs_m] + ct32_rhs.moduli_array = jnp.array( + ct32_rhs.moduli, dtype=getattr(ct32_rhs, "modulus_dtype", jnp.uint32) + ) + ct32.add(ct32_rhs) + _moduli = jnp.array(ct32.moduli, dtype=jnp.uint32) + ct32.polynomial = jnp.where( + ct32.polynomial >= _moduli, ct32.polynomial - _moduli, ct32.polynomial + ) + _ct33_arg_data = ct32.polynomial if hasattr(ct32, "polynomial") else ct32 + _ct33_arg_m_in = _ct33_arg_data.shape[-1] + _ct33_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct33_arg_m_in + ) + _ct33_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct33_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct33_arg_r) + ) + _ct33_arg_moduli = getattr(ct32, "moduli", v0.q_towers) + if isinstance(_ct33_arg_moduli, (int, np.integer)): + _ct33_arg_moduli = [int(_ct33_arg_moduli)] + ct33_arg = Polynomial( + { + "batch": _ct33_arg_data.shape[0], + "num_elements": _ct33_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct33_arg_m, + "precision": 32, + "degree_layout": (_ct33_arg_r, _ct33_arg_c), + }, + {"moduli": list(_ct33_arg_moduli)[:_ct33_arg_m]}, + ) + ct33_arg.polynomial = _ct33_arg_data.reshape( + _ct33_arg_data.shape[0], + _ct33_arg_data.shape[1], + _ct33_arg_r, + _ct33_arg_c, + _ct33_arg_m_in, + )[..., :_ct33_arg_m].copy() + ct33_arg.batch = ct33_arg.polynomial.shape[0] + ct33_arg.num_elements = ct33_arg.polynomial.shape[1] + ct33_arg.num_moduli = _ct33_arg_m + ct33_arg.degree_layout = (_ct33_arg_r, _ct33_arg_c) + ct33_arg.r = _ct33_arg_r + ct33_arg.c = _ct33_arg_c + ct33_arg.moduli = list(_ct33_arg_moduli)[:_ct33_arg_m] + ct33_arg.moduli_array = jnp.array( + ct33_arg.moduli, dtype=getattr(ct33_arg, "modulus_dtype", jnp.uint32) + ) + ct33_raw = v0.he_rot[v0.max_level - 1, 3].rotate(ct33_arg) + _ct33_data = ( + ct33_raw.polynomial if hasattr(ct33_raw, "polynomial") else ct33_raw + ) + _ct33_m_in = _ct33_data.shape[-1] + _ct33_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct33_m_in + ) + _ct33_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct33_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct33_r) + ) + _ct33_moduli = getattr(ct33_raw, "moduli", v0.q_towers) + if isinstance(_ct33_moduli, (int, np.integer)): + _ct33_moduli = [int(_ct33_moduli)] + ct33 = Polynomial( + { + "batch": _ct33_data.shape[0], + "num_elements": _ct33_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct33_m, + "precision": 32, + "degree_layout": (_ct33_r, _ct33_c), + }, + {"moduli": list(_ct33_moduli)[:_ct33_m]}, + ) + ct33.polynomial = _ct33_data.reshape( + _ct33_data.shape[0], _ct33_data.shape[1], _ct33_r, _ct33_c, _ct33_m_in + )[..., :_ct33_m].copy() + ct33.batch = ct33.polynomial.shape[0] + ct33.num_elements = ct33.polynomial.shape[1] + ct33.num_moduli = _ct33_m + ct33.degree_layout = (_ct33_r, _ct33_c) + ct33.r = _ct33_r + ct33.c = _ct33_c + ct33.moduli = list(_ct33_moduli)[:_ct33_m] + ct33.moduli_array = jnp.array( + ct33.moduli, dtype=getattr(ct33, "modulus_dtype", jnp.uint32) + ) + _ct34_arg_data = ct20.polynomial if hasattr(ct20, "polynomial") else ct20 + _ct34_arg_m_in = _ct34_arg_data.shape[-1] + _ct34_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct34_arg_m_in + ) + _ct34_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct34_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct34_arg_r) + ) + _ct34_arg_moduli = getattr(ct20, "moduli", v0.q_towers) + if isinstance(_ct34_arg_moduli, (int, np.integer)): + _ct34_arg_moduli = [int(_ct34_arg_moduli)] + ct34_arg = Polynomial( + { + "batch": _ct34_arg_data.shape[0], + "num_elements": _ct34_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct34_arg_m, + "precision": 32, + "degree_layout": (_ct34_arg_r, _ct34_arg_c), + }, + {"moduli": list(_ct34_arg_moduli)[:_ct34_arg_m]}, + ) + ct34_arg.polynomial = _ct34_arg_data.reshape( + _ct34_arg_data.shape[0], + _ct34_arg_data.shape[1], + _ct34_arg_r, + _ct34_arg_c, + _ct34_arg_m_in, + )[..., :_ct34_arg_m].copy() + ct34_arg.batch = ct34_arg.polynomial.shape[0] + ct34_arg.num_elements = ct34_arg.polynomial.shape[1] + ct34_arg.num_moduli = _ct34_arg_m + ct34_arg.degree_layout = (_ct34_arg_r, _ct34_arg_c) + ct34_arg.r = _ct34_arg_r + ct34_arg.c = _ct34_arg_c + ct34_arg.moduli = list(_ct34_arg_moduli)[:_ct34_arg_m] + ct34_arg.moduli_array = jnp.array( + ct34_arg.moduli, dtype=getattr(ct34_arg, "modulus_dtype", jnp.uint32) + ) + ct34_pt_ntt = ( + pt14.polynomial[0, 0, :, : ct34_arg.polynomial.shape[-1]] + .reshape(ct34_arg.r, ct34_arg.c, ct34_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct34_ptct = v0.ptct_mul[v0.max_level - 1] + ct34_ptct.set_plaintext(ct34_pt_ntt) + ct34_raw = ct34_ptct.mul(ct34_arg, use_bat=False) + _ct34_data = ( + ct34_raw.polynomial if hasattr(ct34_raw, "polynomial") else ct34_raw + ) + _ct34_m_in = _ct34_data.shape[-1] + _ct34_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct34_m_in + ) + _ct34_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct34_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct34_r) + ) + _ct34_moduli = getattr(ct34_raw, "moduli", v0.q_towers) + if isinstance(_ct34_moduli, (int, np.integer)): + _ct34_moduli = [int(_ct34_moduli)] + ct34 = Polynomial( + { + "batch": _ct34_data.shape[0], + "num_elements": _ct34_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct34_m, + "precision": 32, + "degree_layout": (_ct34_r, _ct34_c), + }, + {"moduli": list(_ct34_moduli)[:_ct34_m]}, + ) + ct34.polynomial = _ct34_data.reshape( + _ct34_data.shape[0], _ct34_data.shape[1], _ct34_r, _ct34_c, _ct34_m_in + )[..., :_ct34_m].copy() + ct34.batch = ct34.polynomial.shape[0] + ct34.num_elements = ct34.polynomial.shape[1] + ct34.num_moduli = _ct34_m + ct34.degree_layout = (_ct34_r, _ct34_c) + ct34.r = _ct34_r + ct34.c = _ct34_c + ct34.moduli = list(_ct34_moduli)[:_ct34_m] + ct34.moduli_array = jnp.array( + ct34.moduli, dtype=getattr(ct34, "modulus_dtype", jnp.uint32) + ) + _ct35_arg_data = ct23.polynomial if hasattr(ct23, "polynomial") else ct23 + _ct35_arg_m_in = _ct35_arg_data.shape[-1] + _ct35_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct35_arg_m_in + ) + _ct35_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct35_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct35_arg_r) + ) + _ct35_arg_moduli = getattr(ct23, "moduli", v0.q_towers) + if isinstance(_ct35_arg_moduli, (int, np.integer)): + _ct35_arg_moduli = [int(_ct35_arg_moduli)] + ct35_arg = Polynomial( + { + "batch": _ct35_arg_data.shape[0], + "num_elements": _ct35_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct35_arg_m, + "precision": 32, + "degree_layout": (_ct35_arg_r, _ct35_arg_c), + }, + {"moduli": list(_ct35_arg_moduli)[:_ct35_arg_m]}, + ) + ct35_arg.polynomial = _ct35_arg_data.reshape( + _ct35_arg_data.shape[0], + _ct35_arg_data.shape[1], + _ct35_arg_r, + _ct35_arg_c, + _ct35_arg_m_in, + )[..., :_ct35_arg_m].copy() + ct35_arg.batch = ct35_arg.polynomial.shape[0] + ct35_arg.num_elements = ct35_arg.polynomial.shape[1] + ct35_arg.num_moduli = _ct35_arg_m + ct35_arg.degree_layout = (_ct35_arg_r, _ct35_arg_c) + ct35_arg.r = _ct35_arg_r + ct35_arg.c = _ct35_arg_c + ct35_arg.moduli = list(_ct35_arg_moduli)[:_ct35_arg_m] + ct35_arg.moduli_array = jnp.array( + ct35_arg.moduli, dtype=getattr(ct35_arg, "modulus_dtype", jnp.uint32) + ) + ct35_pt_ntt = ( + pt15.polynomial[0, 0, :, : ct35_arg.polynomial.shape[-1]] + .reshape(ct35_arg.r, ct35_arg.c, ct35_arg.polynomial.shape[-1]) + .astype(jnp.uint32) + ) + ct35_ptct = v0.ptct_mul[v0.max_level - 1] + ct35_ptct.set_plaintext(ct35_pt_ntt) + ct35_raw = ct35_ptct.mul(ct35_arg, use_bat=False) + _ct35_data = ( + ct35_raw.polynomial if hasattr(ct35_raw, "polynomial") else ct35_raw + ) + _ct35_m_in = _ct35_data.shape[-1] + _ct35_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct35_m_in + ) + _ct35_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct35_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct35_r) + ) + _ct35_moduli = getattr(ct35_raw, "moduli", v0.q_towers) + if isinstance(_ct35_moduli, (int, np.integer)): + _ct35_moduli = [int(_ct35_moduli)] + ct35 = Polynomial( + { + "batch": _ct35_data.shape[0], + "num_elements": _ct35_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct35_m, + "precision": 32, + "degree_layout": (_ct35_r, _ct35_c), + }, + {"moduli": list(_ct35_moduli)[:_ct35_m]}, + ) + ct35.polynomial = _ct35_data.reshape( + _ct35_data.shape[0], _ct35_data.shape[1], _ct35_r, _ct35_c, _ct35_m_in + )[..., :_ct35_m].copy() + ct35.batch = ct35.polynomial.shape[0] + ct35.num_elements = ct35.polynomial.shape[1] + ct35.num_moduli = _ct35_m + ct35.degree_layout = (_ct35_r, _ct35_c) + ct35.r = _ct35_r + ct35.c = _ct35_c + ct35.moduli = list(_ct35_moduli)[:_ct35_m] + ct35.moduli_array = jnp.array( + ct35.moduli, dtype=getattr(ct35, "modulus_dtype", jnp.uint32) + ) + _ct36_data = ct34.polynomial if hasattr(ct34, "polynomial") else ct34 + _ct36_m_in = _ct36_data.shape[-1] + _ct36_m = _ct36_m_in + _ct36_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct36_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct36_r) + ) + _ct36_moduli = getattr(ct34, "moduli", v0.q_towers) + if isinstance(_ct36_moduli, (int, np.integer)): + _ct36_moduli = [int(_ct36_moduli)] + ct36 = Polynomial( + { + "batch": _ct36_data.shape[0], + "num_elements": _ct36_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct36_m, + "precision": 32, + "degree_layout": (_ct36_r, _ct36_c), + }, + {"moduli": list(_ct36_moduli)[:_ct36_m]}, + ) + ct36.polynomial = _ct36_data.reshape( + _ct36_data.shape[0], _ct36_data.shape[1], _ct36_r, _ct36_c, _ct36_m_in + )[..., :_ct36_m].copy() + ct36.batch = ct36.polynomial.shape[0] + ct36.num_elements = ct36.polynomial.shape[1] + ct36.num_moduli = _ct36_m + ct36.degree_layout = (_ct36_r, _ct36_c) + ct36.r = _ct36_r + ct36.c = _ct36_c + ct36.moduli = list(_ct36_moduli)[:_ct36_m] + ct36.moduli_array = jnp.array( + ct36.moduli, dtype=getattr(ct36, "modulus_dtype", jnp.uint32) + ) + _ct36_rhs_data = ct35.polynomial if hasattr(ct35, "polynomial") else ct35 + _ct36_rhs_m_in = _ct36_rhs_data.shape[-1] + _ct36_rhs_m = _ct36_rhs_m_in + _ct36_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct36_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct36_rhs_r) + ) + _ct36_rhs_moduli = getattr(ct35, "moduli", v0.q_towers) + if isinstance(_ct36_rhs_moduli, (int, np.integer)): + _ct36_rhs_moduli = [int(_ct36_rhs_moduli)] + ct36_rhs = Polynomial( + { + "batch": _ct36_rhs_data.shape[0], + "num_elements": _ct36_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct36_rhs_m, + "precision": 32, + "degree_layout": (_ct36_rhs_r, _ct36_rhs_c), + }, + {"moduli": list(_ct36_rhs_moduli)[:_ct36_rhs_m]}, + ) + ct36_rhs.polynomial = _ct36_rhs_data.reshape( + _ct36_rhs_data.shape[0], + _ct36_rhs_data.shape[1], + _ct36_rhs_r, + _ct36_rhs_c, + _ct36_rhs_m_in, + )[..., :_ct36_rhs_m].copy() + ct36_rhs.batch = ct36_rhs.polynomial.shape[0] + ct36_rhs.num_elements = ct36_rhs.polynomial.shape[1] + ct36_rhs.num_moduli = _ct36_rhs_m + ct36_rhs.degree_layout = (_ct36_rhs_r, _ct36_rhs_c) + ct36_rhs.r = _ct36_rhs_r + ct36_rhs.c = _ct36_rhs_c + ct36_rhs.moduli = list(_ct36_rhs_moduli)[:_ct36_rhs_m] + ct36_rhs.moduli_array = jnp.array( + ct36_rhs.moduli, dtype=getattr(ct36_rhs, "modulus_dtype", jnp.uint32) + ) + ct36.add(ct36_rhs) + _moduli = jnp.array(ct36.moduli, dtype=jnp.uint32) + ct36.polynomial = jnp.where( + ct36.polynomial >= _moduli, ct36.polynomial - _moduli, ct36.polynomial + ) + _ct37_arg_data = ct36.polynomial if hasattr(ct36, "polynomial") else ct36 + _ct37_arg_m_in = _ct37_arg_data.shape[-1] + _ct37_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct37_arg_m_in + ) + _ct37_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct37_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct37_arg_r) + ) + _ct37_arg_moduli = getattr(ct36, "moduli", v0.q_towers) + if isinstance(_ct37_arg_moduli, (int, np.integer)): + _ct37_arg_moduli = [int(_ct37_arg_moduli)] + ct37_arg = Polynomial( + { + "batch": _ct37_arg_data.shape[0], + "num_elements": _ct37_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct37_arg_m, + "precision": 32, + "degree_layout": (_ct37_arg_r, _ct37_arg_c), + }, + {"moduli": list(_ct37_arg_moduli)[:_ct37_arg_m]}, + ) + ct37_arg.polynomial = _ct37_arg_data.reshape( + _ct37_arg_data.shape[0], + _ct37_arg_data.shape[1], + _ct37_arg_r, + _ct37_arg_c, + _ct37_arg_m_in, + )[..., :_ct37_arg_m].copy() + ct37_arg.batch = ct37_arg.polynomial.shape[0] + ct37_arg.num_elements = ct37_arg.polynomial.shape[1] + ct37_arg.num_moduli = _ct37_arg_m + ct37_arg.degree_layout = (_ct37_arg_r, _ct37_arg_c) + ct37_arg.r = _ct37_arg_r + ct37_arg.c = _ct37_arg_c + ct37_arg.moduli = list(_ct37_arg_moduli)[:_ct37_arg_m] + ct37_arg.moduli_array = jnp.array( + ct37_arg.moduli, dtype=getattr(ct37_arg, "modulus_dtype", jnp.uint32) + ) + ct37_raw = v0.he_rot[v0.max_level - 1, 6].rotate(ct37_arg) + _ct37_data = ( + ct37_raw.polynomial if hasattr(ct37_raw, "polynomial") else ct37_raw + ) + _ct37_m_in = _ct37_data.shape[-1] + _ct37_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct37_m_in + ) + _ct37_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct37_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct37_r) + ) + _ct37_moduli = getattr(ct37_raw, "moduli", v0.q_towers) + if isinstance(_ct37_moduli, (int, np.integer)): + _ct37_moduli = [int(_ct37_moduli)] + ct37 = Polynomial( + { + "batch": _ct37_data.shape[0], + "num_elements": _ct37_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct37_m, + "precision": 32, + "degree_layout": (_ct37_r, _ct37_c), + }, + {"moduli": list(_ct37_moduli)[:_ct37_m]}, + ) + ct37.polynomial = _ct37_data.reshape( + _ct37_data.shape[0], _ct37_data.shape[1], _ct37_r, _ct37_c, _ct37_m_in + )[..., :_ct37_m].copy() + ct37.batch = ct37.polynomial.shape[0] + ct37.num_elements = ct37.polynomial.shape[1] + ct37.num_moduli = _ct37_m + ct37.degree_layout = (_ct37_r, _ct37_c) + ct37.r = _ct37_r + ct37.c = _ct37_c + ct37.moduli = list(_ct37_moduli)[:_ct37_m] + ct37.moduli_array = jnp.array( + ct37.moduli, dtype=getattr(ct37, "modulus_dtype", jnp.uint32) + ) + _ct38_data = ct21.polynomial if hasattr(ct21, "polynomial") else ct21 + _ct38_m_in = _ct38_data.shape[-1] + _ct38_m = _ct38_m_in + _ct38_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct38_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct38_r) + ) + _ct38_moduli = getattr(ct21, "moduli", v0.q_towers) + if isinstance(_ct38_moduli, (int, np.integer)): + _ct38_moduli = [int(_ct38_moduli)] + ct38 = Polynomial( + { + "batch": _ct38_data.shape[0], + "num_elements": _ct38_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct38_m, + "precision": 32, + "degree_layout": (_ct38_r, _ct38_c), + }, + {"moduli": list(_ct38_moduli)[:_ct38_m]}, + ) + ct38.polynomial = _ct38_data.reshape( + _ct38_data.shape[0], _ct38_data.shape[1], _ct38_r, _ct38_c, _ct38_m_in + )[..., :_ct38_m].copy() + ct38.batch = ct38.polynomial.shape[0] + ct38.num_elements = ct38.polynomial.shape[1] + ct38.num_moduli = _ct38_m + ct38.degree_layout = (_ct38_r, _ct38_c) + ct38.r = _ct38_r + ct38.c = _ct38_c + ct38.moduli = list(_ct38_moduli)[:_ct38_m] + ct38.moduli_array = jnp.array( + ct38.moduli, dtype=getattr(ct38, "modulus_dtype", jnp.uint32) + ) + _ct38_rhs_data = ct24.polynomial if hasattr(ct24, "polynomial") else ct24 + _ct38_rhs_m_in = _ct38_rhs_data.shape[-1] + _ct38_rhs_m = _ct38_rhs_m_in + _ct38_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct38_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct38_rhs_r) + ) + _ct38_rhs_moduli = getattr(ct24, "moduli", v0.q_towers) + if isinstance(_ct38_rhs_moduli, (int, np.integer)): + _ct38_rhs_moduli = [int(_ct38_rhs_moduli)] + ct38_rhs = Polynomial( + { + "batch": _ct38_rhs_data.shape[0], + "num_elements": _ct38_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct38_rhs_m, + "precision": 32, + "degree_layout": (_ct38_rhs_r, _ct38_rhs_c), + }, + {"moduli": list(_ct38_rhs_moduli)[:_ct38_rhs_m]}, + ) + ct38_rhs.polynomial = _ct38_rhs_data.reshape( + _ct38_rhs_data.shape[0], + _ct38_rhs_data.shape[1], + _ct38_rhs_r, + _ct38_rhs_c, + _ct38_rhs_m_in, + )[..., :_ct38_rhs_m].copy() + ct38_rhs.batch = ct38_rhs.polynomial.shape[0] + ct38_rhs.num_elements = ct38_rhs.polynomial.shape[1] + ct38_rhs.num_moduli = _ct38_rhs_m + ct38_rhs.degree_layout = (_ct38_rhs_r, _ct38_rhs_c) + ct38_rhs.r = _ct38_rhs_r + ct38_rhs.c = _ct38_rhs_c + ct38_rhs.moduli = list(_ct38_rhs_moduli)[:_ct38_rhs_m] + ct38_rhs.moduli_array = jnp.array( + ct38_rhs.moduli, dtype=getattr(ct38_rhs, "modulus_dtype", jnp.uint32) + ) + ct38.add(ct38_rhs) + _moduli = jnp.array(ct38.moduli, dtype=jnp.uint32) + ct38.polynomial = jnp.where( + ct38.polynomial >= _moduli, ct38.polynomial - _moduli, ct38.polynomial + ) + _ct39_data = ct27.polynomial if hasattr(ct27, "polynomial") else ct27 + _ct39_m_in = _ct39_data.shape[-1] + _ct39_m = _ct39_m_in + _ct39_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct39_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct39_r) + ) + _ct39_moduli = getattr(ct27, "moduli", v0.q_towers) + if isinstance(_ct39_moduli, (int, np.integer)): + _ct39_moduli = [int(_ct39_moduli)] + ct39 = Polynomial( + { + "batch": _ct39_data.shape[0], + "num_elements": _ct39_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct39_m, + "precision": 32, + "degree_layout": (_ct39_r, _ct39_c), + }, + {"moduli": list(_ct39_moduli)[:_ct39_m]}, + ) + ct39.polynomial = _ct39_data.reshape( + _ct39_data.shape[0], _ct39_data.shape[1], _ct39_r, _ct39_c, _ct39_m_in + )[..., :_ct39_m].copy() + ct39.batch = ct39.polynomial.shape[0] + ct39.num_elements = ct39.polynomial.shape[1] + ct39.num_moduli = _ct39_m + ct39.degree_layout = (_ct39_r, _ct39_c) + ct39.r = _ct39_r + ct39.c = _ct39_c + ct39.moduli = list(_ct39_moduli)[:_ct39_m] + ct39.moduli_array = jnp.array( + ct39.moduli, dtype=getattr(ct39, "modulus_dtype", jnp.uint32) + ) + _ct39_rhs_data = ct33.polynomial if hasattr(ct33, "polynomial") else ct33 + _ct39_rhs_m_in = _ct39_rhs_data.shape[-1] + _ct39_rhs_m = _ct39_rhs_m_in + _ct39_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct39_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct39_rhs_r) + ) + _ct39_rhs_moduli = getattr(ct33, "moduli", v0.q_towers) + if isinstance(_ct39_rhs_moduli, (int, np.integer)): + _ct39_rhs_moduli = [int(_ct39_rhs_moduli)] + ct39_rhs = Polynomial( + { + "batch": _ct39_rhs_data.shape[0], + "num_elements": _ct39_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct39_rhs_m, + "precision": 32, + "degree_layout": (_ct39_rhs_r, _ct39_rhs_c), + }, + {"moduli": list(_ct39_rhs_moduli)[:_ct39_rhs_m]}, + ) + ct39_rhs.polynomial = _ct39_rhs_data.reshape( + _ct39_rhs_data.shape[0], + _ct39_rhs_data.shape[1], + _ct39_rhs_r, + _ct39_rhs_c, + _ct39_rhs_m_in, + )[..., :_ct39_rhs_m].copy() + ct39_rhs.batch = ct39_rhs.polynomial.shape[0] + ct39_rhs.num_elements = ct39_rhs.polynomial.shape[1] + ct39_rhs.num_moduli = _ct39_rhs_m + ct39_rhs.degree_layout = (_ct39_rhs_r, _ct39_rhs_c) + ct39_rhs.r = _ct39_rhs_r + ct39_rhs.c = _ct39_rhs_c + ct39_rhs.moduli = list(_ct39_rhs_moduli)[:_ct39_rhs_m] + ct39_rhs.moduli_array = jnp.array( + ct39_rhs.moduli, dtype=getattr(ct39_rhs, "modulus_dtype", jnp.uint32) + ) + ct39.add(ct39_rhs) + _moduli = jnp.array(ct39.moduli, dtype=jnp.uint32) + ct39.polynomial = jnp.where( + ct39.polynomial >= _moduli, ct39.polynomial - _moduli, ct39.polynomial + ) + _ct40_data = ct39.polynomial if hasattr(ct39, "polynomial") else ct39 + _ct40_m_in = _ct40_data.shape[-1] + _ct40_m = _ct40_m_in + _ct40_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct40_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct40_r) + ) + _ct40_moduli = getattr(ct39, "moduli", v0.q_towers) + if isinstance(_ct40_moduli, (int, np.integer)): + _ct40_moduli = [int(_ct40_moduli)] + ct40 = Polynomial( + { + "batch": _ct40_data.shape[0], + "num_elements": _ct40_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct40_m, + "precision": 32, + "degree_layout": (_ct40_r, _ct40_c), + }, + {"moduli": list(_ct40_moduli)[:_ct40_m]}, + ) + ct40.polynomial = _ct40_data.reshape( + _ct40_data.shape[0], _ct40_data.shape[1], _ct40_r, _ct40_c, _ct40_m_in + )[..., :_ct40_m].copy() + ct40.batch = ct40.polynomial.shape[0] + ct40.num_elements = ct40.polynomial.shape[1] + ct40.num_moduli = _ct40_m + ct40.degree_layout = (_ct40_r, _ct40_c) + ct40.r = _ct40_r + ct40.c = _ct40_c + ct40.moduli = list(_ct40_moduli)[:_ct40_m] + ct40.moduli_array = jnp.array( + ct40.moduli, dtype=getattr(ct40, "modulus_dtype", jnp.uint32) + ) + _ct40_rhs_data = ct37.polynomial if hasattr(ct37, "polynomial") else ct37 + _ct40_rhs_m_in = _ct40_rhs_data.shape[-1] + _ct40_rhs_m = _ct40_rhs_m_in + _ct40_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct40_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct40_rhs_r) + ) + _ct40_rhs_moduli = getattr(ct37, "moduli", v0.q_towers) + if isinstance(_ct40_rhs_moduli, (int, np.integer)): + _ct40_rhs_moduli = [int(_ct40_rhs_moduli)] + ct40_rhs = Polynomial( + { + "batch": _ct40_rhs_data.shape[0], + "num_elements": _ct40_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct40_rhs_m, + "precision": 32, + "degree_layout": (_ct40_rhs_r, _ct40_rhs_c), + }, + {"moduli": list(_ct40_rhs_moduli)[:_ct40_rhs_m]}, + ) + ct40_rhs.polynomial = _ct40_rhs_data.reshape( + _ct40_rhs_data.shape[0], + _ct40_rhs_data.shape[1], + _ct40_rhs_r, + _ct40_rhs_c, + _ct40_rhs_m_in, + )[..., :_ct40_rhs_m].copy() + ct40_rhs.batch = ct40_rhs.polynomial.shape[0] + ct40_rhs.num_elements = ct40_rhs.polynomial.shape[1] + ct40_rhs.num_moduli = _ct40_rhs_m + ct40_rhs.degree_layout = (_ct40_rhs_r, _ct40_rhs_c) + ct40_rhs.r = _ct40_rhs_r + ct40_rhs.c = _ct40_rhs_c + ct40_rhs.moduli = list(_ct40_rhs_moduli)[:_ct40_rhs_m] + ct40_rhs.moduli_array = jnp.array( + ct40_rhs.moduli, dtype=getattr(ct40_rhs, "modulus_dtype", jnp.uint32) + ) + ct40.add(ct40_rhs) + _moduli = jnp.array(ct40.moduli, dtype=jnp.uint32) + ct40.polynomial = jnp.where( + ct40.polynomial >= _moduli, ct40.polynomial - _moduli, ct40.polynomial + ) + _ct41_data = ct38.polynomial if hasattr(ct38, "polynomial") else ct38 + _ct41_m_in = _ct41_data.shape[-1] + _ct41_m = _ct41_m_in + _ct41_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct41_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct41_r) + ) + _ct41_moduli = getattr(ct38, "moduli", v0.q_towers) + if isinstance(_ct41_moduli, (int, np.integer)): + _ct41_moduli = [int(_ct41_moduli)] + ct41 = Polynomial( + { + "batch": _ct41_data.shape[0], + "num_elements": _ct41_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct41_m, + "precision": 32, + "degree_layout": (_ct41_r, _ct41_c), + }, + {"moduli": list(_ct41_moduli)[:_ct41_m]}, + ) + ct41.polynomial = _ct41_data.reshape( + _ct41_data.shape[0], _ct41_data.shape[1], _ct41_r, _ct41_c, _ct41_m_in + )[..., :_ct41_m].copy() + ct41.batch = ct41.polynomial.shape[0] + ct41.num_elements = ct41.polynomial.shape[1] + ct41.num_moduli = _ct41_m + ct41.degree_layout = (_ct41_r, _ct41_c) + ct41.r = _ct41_r + ct41.c = _ct41_c + ct41.moduli = list(_ct41_moduli)[:_ct41_m] + ct41.moduli_array = jnp.array( + ct41.moduli, dtype=getattr(ct41, "modulus_dtype", jnp.uint32) + ) + _ct41_rhs_data = ct40.polynomial if hasattr(ct40, "polynomial") else ct40 + _ct41_rhs_m_in = _ct41_rhs_data.shape[-1] + _ct41_rhs_m = _ct41_rhs_m_in + _ct41_rhs_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct41_rhs_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct41_rhs_r) + ) + _ct41_rhs_moduli = getattr(ct40, "moduli", v0.q_towers) + if isinstance(_ct41_rhs_moduli, (int, np.integer)): + _ct41_rhs_moduli = [int(_ct41_rhs_moduli)] + ct41_rhs = Polynomial( + { + "batch": _ct41_rhs_data.shape[0], + "num_elements": _ct41_rhs_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct41_rhs_m, + "precision": 32, + "degree_layout": (_ct41_rhs_r, _ct41_rhs_c), + }, + {"moduli": list(_ct41_rhs_moduli)[:_ct41_rhs_m]}, + ) + ct41_rhs.polynomial = _ct41_rhs_data.reshape( + _ct41_rhs_data.shape[0], + _ct41_rhs_data.shape[1], + _ct41_rhs_r, + _ct41_rhs_c, + _ct41_rhs_m_in, + )[..., :_ct41_rhs_m].copy() + ct41_rhs.batch = ct41_rhs.polynomial.shape[0] + ct41_rhs.num_elements = ct41_rhs.polynomial.shape[1] + ct41_rhs.num_moduli = _ct41_rhs_m + ct41_rhs.degree_layout = (_ct41_rhs_r, _ct41_rhs_c) + ct41_rhs.r = _ct41_rhs_r + ct41_rhs.c = _ct41_rhs_c + ct41_rhs.moduli = list(_ct41_rhs_moduli)[:_ct41_rhs_m] + ct41_rhs.moduli_array = jnp.array( + ct41_rhs.moduli, dtype=getattr(ct41_rhs, "modulus_dtype", jnp.uint32) + ) + ct41.add(ct41_rhs) + _moduli = jnp.array(ct41.moduli, dtype=jnp.uint32) + ct41.polynomial = jnp.where( + ct41.polynomial >= _moduli, ct41.polynomial - _moduli, ct41.polynomial + ) + v20 = [None] * 1 + _ct42_arg_data = ct41.polynomial if hasattr(ct41, "polynomial") else ct41 + _ct42_arg_m_in = _ct42_arg_data.shape[-1] + _ct42_arg_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 1) + if hasattr(v0, "_param_cache") + else _ct42_arg_m_in + ) + _ct42_arg_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct42_arg_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct42_arg_r) + ) + _ct42_arg_moduli = getattr(ct41, "moduli", v0.q_towers) + if isinstance(_ct42_arg_moduli, (int, np.integer)): + _ct42_arg_moduli = [int(_ct42_arg_moduli)] + ct42_arg = Polynomial( + { + "batch": _ct42_arg_data.shape[0], + "num_elements": _ct42_arg_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct42_arg_m, + "precision": 32, + "degree_layout": (_ct42_arg_r, _ct42_arg_c), + }, + {"moduli": list(_ct42_arg_moduli)[:_ct42_arg_m]}, + ) + ct42_arg.polynomial = _ct42_arg_data.reshape( + _ct42_arg_data.shape[0], + _ct42_arg_data.shape[1], + _ct42_arg_r, + _ct42_arg_c, + _ct42_arg_m_in, + )[..., :_ct42_arg_m].copy() + ct42_arg.batch = ct42_arg.polynomial.shape[0] + ct42_arg.num_elements = ct42_arg.polynomial.shape[1] + ct42_arg.num_moduli = _ct42_arg_m + ct42_arg.degree_layout = (_ct42_arg_r, _ct42_arg_c) + ct42_arg.r = _ct42_arg_r + ct42_arg.c = _ct42_arg_c + ct42_arg.moduli = list(_ct42_arg_moduli)[:_ct42_arg_m] + ct42_arg.moduli_array = jnp.array( + ct42_arg.moduli, dtype=getattr(ct42_arg, "modulus_dtype", jnp.uint32) + ) + ct42_raw = v0.he_rescale[v0.max_level - 1, v0.max_level - 2](ct42_arg) + _ct42_data = ( + ct42_raw.polynomial if hasattr(ct42_raw, "polynomial") else ct42_raw + ) + _ct42_m_in = _ct42_data.shape[-1] + _ct42_m = ( + v0._param_cache.num_q_at_level(v0.max_level - 2) + if hasattr(v0, "_param_cache") + else _ct42_m_in + ) + _ct42_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct42_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct42_r) + ) + _ct42_moduli = getattr(ct42_raw, "moduli", v0.q_towers) + if isinstance(_ct42_moduli, (int, np.integer)): + _ct42_moduli = [int(_ct42_moduli)] + ct42 = Polynomial( + { + "batch": _ct42_data.shape[0], + "num_elements": _ct42_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct42_m, + "precision": 32, + "degree_layout": (_ct42_r, _ct42_c), + }, + {"moduli": list(_ct42_moduli)[:_ct42_m]}, + ) + ct42.polynomial = _ct42_data.reshape( + _ct42_data.shape[0], _ct42_data.shape[1], _ct42_r, _ct42_c, _ct42_m_in + )[..., :_ct42_m].copy() + ct42.batch = ct42.polynomial.shape[0] + ct42.num_elements = ct42.polynomial.shape[1] + ct42.num_moduli = _ct42_m + ct42.degree_layout = (_ct42_r, _ct42_c) + ct42.r = _ct42_r + ct42.c = _ct42_c + ct42.moduli = list(_ct42_moduli)[:_ct42_m] + ct42.moduli_array = jnp.array( + ct42.moduli, dtype=getattr(ct42, "modulus_dtype", jnp.uint32) + ) + v20[0] = ct42 + v21 = v20 + return v21 + + +def matvec_chain( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, +) -> np.ndarray: + (v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) = ( + matvec_chain__preprocessing(v0, v1) + ) + v15 = matvec_chain__preprocessed( + v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 + ) + return v15 + + +def matvec_chain__encrypt__arg0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = np.full( + ( + 1, + 8, + ), + 0.000000e00, + dtype=np.float32, + ) + v6 = 0 + v7 = 1 + v8 = 8 + v9 = v5.copy() + for v10 in range(0, 8): + v12 = int(v10) + v13 = v2[v12] + v9[0, v12] = v13 + v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) + pt = v0.encode(v15) + v0.public_key = v3 + ct_raw = v0.encrypt(pt) + _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw + _ct_m_in = _ct_data.shape[-1] + _ct_m = _ct_m_in + _ct_r = ( + v0._param_cache.r + if hasattr(v0, "_param_cache") + else v0.parameters.get("r", int(np.sqrt(v0.degree))) + ) + _ct_c = ( + v0._param_cache.c + if hasattr(v0, "_param_cache") + else v0.parameters.get("c", v0.degree // _ct_r) + ) + _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) + if isinstance(_ct_moduli, (int, np.integer)): + _ct_moduli = [int(_ct_moduli)] + ct = Polynomial( + { + "batch": _ct_data.shape[0], + "num_elements": _ct_data.shape[1], + "degree": v0.degree, + "num_moduli": _ct_m, + "precision": 32, + "degree_layout": (_ct_r, _ct_c), + }, + {"moduli": list(_ct_moduli)[:_ct_m]}, + ) + ct.polynomial = _ct_data.reshape( + _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in + )[..., :_ct_m].copy() + ct.batch = ct.polynomial.shape[0] + ct.num_elements = ct.polynomial.shape[1] + ct.num_moduli = _ct_m + ct.degree_layout = (_ct_r, _ct_c) + ct.r = _ct_r + ct.c = _ct_c + ct.moduli = list(_ct_moduli)[:_ct_m] + ct.moduli_array = jnp.array( + ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) + ) + v16 = [ct] + return v16 + + +def matvec_chain__decrypt__result0( + v0: ckks.CKKSContext, + v1: dict, + v2: np.ndarray, + v3: np.ndarray, +) -> np.ndarray: + v4 = 0 + v5 = 8 + v6 = 1 + v7 = 0 + v8 = np.full((8,), 0.000000e00, dtype=np.float32) + ct = v2[0] + v0.secret_key = v3 + _num_moduli = ct.polynomial.shape[-1] + _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] + _ct_for_dec = Polynomial( + { + "batch": ct.polynomial.shape[0], + "num_elements": ct.polynomial.shape[1], + "degree": v0.degree, + "precision": 32, + "num_moduli": _num_moduli, + "degree_layout": (v0.degree,), + }, + {"moduli": _q_sub}, + ) + _ct_for_dec.set_batch_polynomial( + ct.polynomial.reshape( + ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli + ) + ) + pt = v0.decrypt(_ct_for_dec) + v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) + v10 = v8.copy() + for v11 in range(0, 8): + v13 = int(v11) + v14 = v9[0, v13] + v10[v13] = v14 + return v10 + + +def matvec_identity__generate_crypto_context( + public_key, + secret_key, + evaluation_key, +) -> ckks.CKKSContext: + params = { + "degree": 16, + "num_slots": 8, + "batch": 1, + "r": 4, + "c": 4, + "dnum": 3, + "numEvalMult": 1, + "scaling_factor": 563019763943521, + "q_towers": [1073742881, 1073742721, 1073741441, 1073741857, 524353], + "p_towers": [1073740609, 1073739937, 1073739649], + "composite_degree": 1, + "p": 30, + "max_bits_in_word": 61, + "max_bits_value": 9223372036854775295, + "noise_scale_degree": 1, + "CKKS_M_FACTOR": 1, + "public_key": public_key, + "secret_key": secret_key, + "evaluation_key": evaluation_key, + } + v0 = ckks.CKKSContext(params) + return v0 + + +def matvec_identity__configure_crypto_context( + v0: ckks.CKKSContext, +): + v0.program_initialization( + total_hemul_levels=v0.max_level, + total_rotation_indices=[1, 2, 3, 6], + dnum=3, + r=4, + c=4, + batch=1, + ) diff --git a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir index 6ba1dcd271..2258a2f372 100644 --- a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir +++ b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context.mlir @@ -33,10 +33,15 @@ module attributes {ckks.schemeParam = #ckks.scheme_param diff --git a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir index e0d1df3ceb..431f13425f 100644 --- a/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir +++ b/tests/Dialect/JaxiteWord/Transforms/configure_crypto_context_defaults.mlir @@ -32,6 +32,9 @@ module { } // CHECK: @simple_mul__generate_crypto_context +// CHECK-SAME: !jaxiteword.public_key +// CHECK-SAME: !jaxiteword.private_key +// CHECK-SAME: !jaxiteword.eval_key // CHECK: jaxiteword.gen_params // CHECK-SAME: batch = 1 : i32 // CHECK-SAME: c = 4 : i32 diff --git a/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir b/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir index cf1a999229..911993a2c8 100644 --- a/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir +++ b/tests/Emitter/JaxiteWord/emit_jaxiteword.mlir @@ -15,7 +15,7 @@ func.func @test_add(%ctx: !jaxiteword.crypto_context<>, %ct1 : !ct_L1, %ct2 : !c } // CHECK: def test_mul( -// CHECK: hemul( +// CHECK: .he_mul[ func.func @test_mul(%ctx: !jaxiteword.crypto_context<>, %ct1 : !ct_L1, %ct2 : !ct_L1) -> !ct_L1 { %pk, %sk = jaxiteword.gen_keypair %ctx : (!jaxiteword.crypto_context<>) -> (!jaxiteword.public_key<>, !jaxiteword.private_key<>) %ek = jaxiteword.gen_mulkey %ctx, %sk : (!jaxiteword.crypto_context<>, !jaxiteword.private_key<>) -> !jaxiteword.eval_key<> @@ -32,8 +32,14 @@ func.func @test_mul_no_relin(%ctx: !jaxiteword.crypto_context<>, %ct1 : !ct_L1, // CHECK: def test_gen_params( // CHECK: "scaling_factor": 563019763943521 -func.func @test_gen_params() -> !jaxiteword.crypto_context<> { - %ctx = jaxiteword.gen_params { +// CHECK: "public_key": +// CHECK: "secret_key": +// CHECK: "evaluation_key": +func.func @test_gen_params( + %pk: !jaxiteword.public_key<>, + %sk: !jaxiteword.private_key<>, + %ek: !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> { + %ctx = jaxiteword.gen_params %pk, %sk, %ek { degree = 8192 : i64, numSlots = 4096 : i64, scalingFactor = 563019763943521.0 : f64, @@ -45,6 +51,6 @@ func.func @test_gen_params() -> !jaxiteword.crypto_context<> { dnum = 3 : i32, numEvalMult = 2 : i32, compositeDegree = 1 : i32 - } : () -> !jaxiteword.crypto_context<> + } : (!jaxiteword.public_key<>, !jaxiteword.private_key<>, !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> return %ctx : !jaxiteword.crypto_context<> } From cfb3220f3c661b154654f8c3444a1971784dc984 Mon Sep 17 00:00:00 2001 From: Zohaib58 Date: Wed, 17 Jun 2026 06:04:47 +0000 Subject: [PATCH 2/3] Remove Bazel output symlink --- bazel-heir-private | 1 - 1 file changed, 1 deletion(-) delete mode 120000 bazel-heir-private diff --git a/bazel-heir-private b/bazel-heir-private deleted file mode 120000 index 1786659bb8..0000000000 --- a/bazel-heir-private +++ /dev/null @@ -1 +0,0 @@ -/home/zohaib/.cache/bazel/_bazel_zohaib/f8ad823d70143bc66e0160e8a7bf9f07/execroot/_main \ No newline at end of file From 720bb04320b2ba1e35f5ddec633ab8d1ab771d47 Mon Sep 17 00:00:00 2001 From: Zohaib58 Date: Tue, 23 Jun 2026 03:35:48 +0000 Subject: [PATCH 3/3] Remove matvec_8x8 validation artifacts from PR. --- matvec_8x8.mlir | 105 - matvec_8x8_cross.py | 2503 -------- matvec_8x8_jaxite.mlir | 465 -- matvec_8x8_jaxiteword.mlir | 0 matvec_8x8_jaxiteword.py | 11122 ----------------------------------- 5 files changed, 14195 deletions(-) delete mode 100644 matvec_8x8.mlir delete mode 100644 matvec_8x8_cross.py delete mode 100644 matvec_8x8_jaxite.mlir delete mode 100644 matvec_8x8_jaxiteword.mlir delete mode 100644 matvec_8x8_jaxiteword.py diff --git a/matvec_8x8.mlir b/matvec_8x8.mlir deleted file mode 100644 index fb41004848..0000000000 --- a/matvec_8x8.mlir +++ /dev/null @@ -1,105 +0,0 @@ -// MLIR mimicking CROSS_dev/jaxite_word/matvec_test.py (degree=16, 8 slots). -// Test vector (for later validation): [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0] -// -// CROSS crypto params (matvec_test.py setUp): -// degree=16, num_slots=8, dnum=3, r=4, c=4 -// scaling_factor=563019763943521 (pass flag at lowering time, see below) -// q_towers / p_towers preset via ckks.schemeParam on this module -// -// Lowering flags to match remaining CROSS params: -// --torch-linalg-to-ckks=ciphertext-degree=8 -// --jaxiteword-configure-crypto-context=entry-function=matvec_identity,dnum=3,r=4,c=4,scaling-factor=563019763943521 - -module attributes { - scheme.ckks, - ckks.schemeParam = #ckks.scheme_param< - logN = 4, - Q = [1073742881, 1073742721, 1073741441, 1073741857, 524353], - P = [1073740609, 1073739937, 1073739649], - logDefaultScale = 45 - > - -} { - func.func @matvec_identity(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { - %matrix = arith.constant dense<[ - [1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00] - ]> : tensor<8x8xf32> - %out = arith.constant dense<0.000000e+00> : tensor<8xf32> - %0 = linalg.matvec ins(%matrix, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out : tensor<8xf32>) -> tensor<8xf32> - return %0 : tensor<8xf32> - } - - // Shift matrix from CROSS matvec_test.py: result[i] = vector[(i + 1) % n]. - func.func @matvec_shift(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { - %matrix = arith.constant dense<[ - [0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00], - [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 1.000000e+00], - [1.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00] - ]> : tensor<8x8xf32> - %out = arith.constant dense<0.000000e+00> : tensor<8xf32> - %0 = linalg.matvec ins(%matrix, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out : tensor<8xf32>) -> tensor<8xf32> - return %0 : tensor<8xf32> - } - - - // Matrix from np.random.seed(42); np.random.uniform(0.1, 2.0, (8, 8)) - func.func @matvec_random(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { - %matrix = arith.constant dense<[ - [8.11626226e-01, 1.90635718e+00, 1.49078849e+00, 1.23745112e+00, 3.96435417e-01, 3.96389589e-01, 2.10358863e-01, 1.74573468e+00], - [1.24211852e+00, 1.44533790e+00, 1.39110539e-01, 1.94282872e+00, 1.68164102e+00, 5.03444310e-01, 4.45467438e-01, 4.48468569e-01], - [6.78060262e-01, 1.09703722e+00, 9.20695535e-01, 6.53335366e-01, 1.26252050e+00, 3.65038335e-01, 6.55074832e-01, 7.96087502e-01], - [9.66532970e-01, 1.59183433e+00, 4.79380186e-01, 1.07704543e+00, 1.22558768e+00, 1.88255784e-01, 1.25433522e+00, 4.23995835e-01], - [2.23598027e-01, 1.90288252e+00, 1.93470086e+00, 1.63595496e+00, 6.78766161e-01, 2.85577017e-01, 1.40004275e+00, 9.36289738e-01], - [3.31872646e-01, 1.04083613e+00, 1.65338190e-01, 1.82770876e+00, 5.91681965e-01, 1.35879234e+00, 6.92251045e-01, 1.08812924e+00], - [1.13874953e+00, 4.51223465e-01, 1.94221079e+00, 1.57275236e+00, 1.88504799e+00, 1.80017197e+00, 1.23600996e+00, 1.85156105e+00], - [2.68135754e-01, 4.72367439e-01, 1.85931849e-01, 7.18127628e-01, 8.38486850e-01, 6.15563160e-01, 1.67460127e+00, 7.77831321e-01] - ]> : tensor<8x8xf32> - %out = arith.constant dense<0.000000e+00> : tensor<8xf32> - %0 = linalg.matvec ins(%matrix, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out : tensor<8xf32>) -> tensor<8xf32> - return %0 : tensor<8xf32> - } - - // Matmat-vector chain from CROSS matvec_test.py: - // np.random.seed(123); A = round(uniform(0.5, 1.5), 2); - // B = round(uniform(0.5, 1.5), 2); computes A @ (B @ v). - func.func @matvec_chain(%arg0: tensor<8xf32> {secret.secret}) -> tensor<8xf32> { - %matrix_b = arith.constant dense<[ - [1.340000e+00, 5.800000e-01, 1.260000e+00, 7.400000e-01, 6.900000e-01, 1.070000e+00, 6.000000e-01, 1.390000e+00], - [1.130000e+00, 1.220000e+00, 5.200000e-01, 1.090000e+00, 1.060000e+00, 6.600000e-01, 6.500000e-01, 1.200000e+00], - [8.200000e-01, 1.190000e+00, 1.050000e+00, 8.900000e-01, 1.430000e+00, 1.340000e+00, 8.600000e-01, 5.400000e-01], - [8.000000e-01, 9.000000e-01, 1.200000e+00, 1.500000e+00, 8.600000e-01, 1.260000e+00, 1.090000e+00, 1.190000e+00], - [6.500000e-01, 9.000000e-01, 7.400000e-01, 8.400000e-01, 1.010000e+00, 1.170000e+00, 6.100000e-01, 6.300000e-01], - [8.200000e-01, 1.160000e+00, 1.350000e+00, 1.050000e+00, 1.350000e+00, 8.800000e-01, 8.200000e-01, 8.500000e-01], - [6.700000e-01, 1.330000e+00, 8.400000e-01, 1.050000e+00, 1.080000e+00, 1.020000e+00, 5.000000e-01, 1.490000e+00], - [1.410000e+00, 7.100000e-01, 7.900000e-01, 1.020000e+00, 1.400000e+00, 1.480000e+00, 7.600000e-01, 1.060000e+00] - ]> : tensor<8x8xf32> - %matrix_a = arith.constant dense<[ - [1.200000e+00, 7.900000e-01, 7.300000e-01, 1.050000e+00, 1.220000e+00, 9.200000e-01, 1.480000e+00, 1.180000e+00], - [9.800000e-01, 8.900000e-01, 8.400000e-01, 1.230000e+00, 9.400000e-01, 5.600000e-01, 9.000000e-01, 1.240000e+00], - [6.800000e-01, 6.800000e-01, 1.030000e+00, 1.030000e+00, 1.130000e+00, 1.350000e+00, 1.220000e+00, 1.110000e+00], - [1.220000e+00, 8.200000e-01, 8.600000e-01, 7.300000e-01, 7.900000e-01, 1.130000e+00, 5.900000e-01, 9.300000e-01], - [9.300000e-01, 9.900000e-01, 9.300000e-01, 8.100000e-01, 9.300000e-01, 1.390000e+00, 1.440000e+00, 1.000000e+00], - [1.120000e+00, 6.200000e-01, 8.200000e-01, 9.100000e-01, 1.370000e+00, 7.500000e-01, 9.800000e-01, 1.490000e+00], - [1.020000e+00, 1.110000e+00, 6.200000e-01, 1.330000e+00, 1.100000e+00, 1.050000e+00, 8.400000e-01, 8.000000e-01], - [9.200000e-01, 1.180000e+00, 1.380000e+00, 1.010000e+00, 1.170000e+00, 1.090000e+00, 1.120000e+00, 1.170000e+00] - ]> : tensor<8x8xf32> - %out_b = arith.constant dense<0.000000e+00> : tensor<8xf32> - %out_a = arith.constant dense<0.000000e+00> : tensor<8xf32> - %0 = linalg.matvec ins(%matrix_b, %arg0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out_b : tensor<8xf32>) -> tensor<8xf32> - %1 = linalg.matvec ins(%matrix_a, %0 : tensor<8x8xf32>, tensor<8xf32>) outs(%out_a : tensor<8xf32>) -> tensor<8xf32> - return %1 : tensor<8xf32> - } - -} diff --git a/matvec_8x8_cross.py b/matvec_8x8_cross.py deleted file mode 100644 index d3d34ae7d5..0000000000 --- a/matvec_8x8_cross.py +++ /dev/null @@ -1,2503 +0,0 @@ -import jax -import jax.numpy as jnp -import key_gen -import numpy as np -from ciphertext import Ciphertext -from polynomial import Polynomial -import ckks_ctx as ckks - - -def _ensure_poly(ctx, x, level=None): - _cache = ctx._param_cache - _r = _cache.r - _c = _cache.c - _m = _cache.num_q_at_level(level) if level is not None else None - - _data = x.polynomial if isinstance(x, Polynomial) else x - _m_in = _data.shape[-1] - if _m is None: - _m = _m_in - if _m > _m_in: - raise ValueError( - f"_ensure_poly: requested {_m} moduli but data only has {_m_in}" - ) - - if level is not None: - _moduli = _cache.q_moduli_at_level(level) - else: - _moduli_src = getattr(x, "moduli", ctx.q_towers) - if isinstance(_moduli_src, (int, np.integer)): - _moduli_src = [int(_moduli_src)] - _moduli = list(_moduli_src)[:_m] - - # Return a fresh wrapper even when x is already tiled: emitted add/sub paths - # mutate the result object, so aliasing the source would violate SSA semantics. - _out = Polynomial( - { - "batch": _data.shape[0], - "num_elements": _data.shape[1], - "degree": ctx.degree, - "num_moduli": _m, - "precision": 32, - "degree_layout": (_r, _c), - }, - {"moduli": _moduli}, - ) - _out.polynomial = _data.reshape( - _data.shape[0], _data.shape[1], _r, _c, _m_in - )[..., :_m] - return _out - - -def _assign_poly(dst, src): - for _attr in ( - "batch", - "num_elements", - "num_moduli", - "degree", - "precision", - "degree_layout", - "r", - "c", - "moduli", - "moduli_array", - "ntt_ctx", - "shape_in_ntt_all_limbs", - ): - if hasattr(src, _attr): - setattr(dst, _attr, getattr(src, _attr)) - dst.polynomial = src.polynomial - if hasattr(src, "extend_polynomial"): - dst.extend_polynomial = src.extend_polynomial - - -def matvec_identity__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> (np.ndarray, np.ndarray): - v2 = np.full((8,), 0.000000e00, dtype=np.float32) - v3 = np.full((8,), 1.000000e00, dtype=np.float32) - pt = v0.encode(v2) - pt1 = v0.encode(v3) - v4 = [pt] - v5 = [pt1] - return (v4, v5) - - -def matvec_identity__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, -) -> np.ndarray: - v5 = 1 - v6 = 2 - v7 = 3 - v8 = 6 - v9 = 0 - pt = v3[0] - pt1 = v4[0] - ct = v2[0] - ct1_arg = _ensure_poly(v0, ct, v0.max_level) - ct1 = v0.he_rot[v0.max_level, 1].rotate(ct1_arg) - ct2_arg = _ensure_poly(v0, ct1, v0.max_level) - ct2_pt_ntt = ( - pt.polynomial[0, 0, :, : ct2_arg.polynomial.shape[-1]] - .reshape(ct2_arg.r, ct2_arg.c, ct2_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct2_ptct = v0.ptct_mul[v0.max_level] - ct2_ptct.set_plaintext(ct2_pt_ntt) - ct2 = ct2_ptct.mul(ct2_arg, use_bat=False) - ct3_arg = _ensure_poly(v0, ct, v0.max_level) - ct3 = v0.he_rot[v0.max_level, 2].rotate(ct3_arg) - ct4_arg = _ensure_poly(v0, ct3, v0.max_level) - ct4_pt_ntt = ( - pt.polynomial[0, 0, :, : ct4_arg.polynomial.shape[-1]] - .reshape(ct4_arg.r, ct4_arg.c, ct4_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct4_ptct = v0.ptct_mul[v0.max_level] - ct4_ptct.set_plaintext(ct4_pt_ntt) - ct4 = ct4_ptct.mul(ct4_arg, use_bat=False) - ct5_arg = _ensure_poly(v0, ct, v0.max_level) - ct5_pt_ntt = ( - pt.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) - ct6_lhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - ct6_rhs = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - ct6_lhs = ct6_lhs.reshape( - ct6_lhs.shape[0], - ct6_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct6_lhs.shape[-1], - ) - ct6_rhs = ct6_rhs.reshape( - ct6_rhs.shape[0], - ct6_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct6_rhs.shape[-1], - ) - if ct6_lhs.shape != ct6_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct6_num_moduli = ct6_lhs.shape[-1] - if hasattr(ct5, "moduli") and hasattr(ct2, "moduli"): - if list(ct5.moduli)[:ct6_num_moduli] != list(ct2.moduli)[:ct6_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct6_moduli_src = getattr(ct5, "moduli", getattr(ct2, "moduli", v0.q_towers)) - if isinstance(ct6_moduli_src, (int, np.integer)): - ct6_moduli_src = [ct6_moduli_src] - ct6_moduli = jnp.array( - list(ct6_moduli_src)[:ct6_num_moduli], dtype=jnp.uint64 - ) - ct6_sum = ct6_lhs.astype(jnp.uint64) + ct6_rhs.astype(jnp.uint64) - ct6 = jnp.where(ct6_sum >= ct6_moduli, ct6_sum - ct6_moduli, ct6_sum).astype( - jnp.uint32 - ) - ct7_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - ct7_rhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - ct7_lhs = ct7_lhs.reshape( - ct7_lhs.shape[0], - ct7_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct7_lhs.shape[-1], - ) - ct7_rhs = ct7_rhs.reshape( - ct7_rhs.shape[0], - ct7_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct7_rhs.shape[-1], - ) - if ct7_lhs.shape != ct7_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct7_num_moduli = ct7_lhs.shape[-1] - if hasattr(ct6, "moduli") and hasattr(ct4, "moduli"): - if list(ct6.moduli)[:ct7_num_moduli] != list(ct4.moduli)[:ct7_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct7_moduli_src = getattr(ct6, "moduli", getattr(ct4, "moduli", v0.q_towers)) - if isinstance(ct7_moduli_src, (int, np.integer)): - ct7_moduli_src = [ct7_moduli_src] - ct7_moduli = jnp.array( - list(ct7_moduli_src)[:ct7_num_moduli], dtype=jnp.uint64 - ) - ct7_sum = ct7_lhs.astype(jnp.uint64) + ct7_rhs.astype(jnp.uint64) - ct7 = jnp.where(ct7_sum >= ct7_moduli, ct7_sum - ct7_moduli, ct7_sum).astype( - jnp.uint32 - ) - ct8_arg = _ensure_poly(v0, ct7, v0.max_level) - ct8 = v0.he_rot[v0.max_level, 3].rotate(ct8_arg) - ct9_arg = _ensure_poly(v0, ct6, v0.max_level) - ct9 = v0.he_rot[v0.max_level, 6].rotate(ct9_arg) - ct10_arg = _ensure_poly(v0, ct, v0.max_level) - ct10_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct10_arg.polynomial.shape[-1]] - .reshape(ct10_arg.r, ct10_arg.c, ct10_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct10_ptct = v0.ptct_mul[v0.max_level] - ct10_ptct.set_plaintext(ct10_pt_ntt) - ct10 = ct10_ptct.mul(ct10_arg, use_bat=False) - ct11_lhs = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 - ct11_rhs = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - ct11_lhs = ct11_lhs.reshape( - ct11_lhs.shape[0], - ct11_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct11_lhs.shape[-1], - ) - ct11_rhs = ct11_rhs.reshape( - ct11_rhs.shape[0], - ct11_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct11_rhs.shape[-1], - ) - if ct11_lhs.shape != ct11_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct11_num_moduli = ct11_lhs.shape[-1] - if hasattr(ct10, "moduli") and hasattr(ct2, "moduli"): - if ( - list(ct10.moduli)[:ct11_num_moduli] - != list(ct2.moduli)[:ct11_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct11_moduli_src = getattr(ct10, "moduli", getattr(ct2, "moduli", v0.q_towers)) - if isinstance(ct11_moduli_src, (int, np.integer)): - ct11_moduli_src = [ct11_moduli_src] - ct11_moduli = jnp.array( - list(ct11_moduli_src)[:ct11_num_moduli], dtype=jnp.uint64 - ) - ct11_sum = ct11_lhs.astype(jnp.uint64) + ct11_rhs.astype(jnp.uint64) - ct11 = jnp.where( - ct11_sum >= ct11_moduli, ct11_sum - ct11_moduli, ct11_sum - ).astype(jnp.uint32) - ct12_lhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - ct12_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - ct12_lhs = ct12_lhs.reshape( - ct12_lhs.shape[0], - ct12_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct12_lhs.shape[-1], - ) - ct12_rhs = ct12_rhs.reshape( - ct12_rhs.shape[0], - ct12_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct12_rhs.shape[-1], - ) - if ct12_lhs.shape != ct12_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct12_num_moduli = ct12_lhs.shape[-1] - if hasattr(ct4, "moduli") and hasattr(ct8, "moduli"): - if list(ct4.moduli)[:ct12_num_moduli] != list(ct8.moduli)[:ct12_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct12_moduli_src = getattr(ct4, "moduli", getattr(ct8, "moduli", v0.q_towers)) - if isinstance(ct12_moduli_src, (int, np.integer)): - ct12_moduli_src = [ct12_moduli_src] - ct12_moduli = jnp.array( - list(ct12_moduli_src)[:ct12_num_moduli], dtype=jnp.uint64 - ) - ct12_sum = ct12_lhs.astype(jnp.uint64) + ct12_rhs.astype(jnp.uint64) - ct12 = jnp.where( - ct12_sum >= ct12_moduli, ct12_sum - ct12_moduli, ct12_sum - ).astype(jnp.uint32) - ct13_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - ct13_rhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - ct13_lhs = ct13_lhs.reshape( - ct13_lhs.shape[0], - ct13_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct13_lhs.shape[-1], - ) - ct13_rhs = ct13_rhs.reshape( - ct13_rhs.shape[0], - ct13_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct13_rhs.shape[-1], - ) - if ct13_lhs.shape != ct13_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct13_num_moduli = ct13_lhs.shape[-1] - if hasattr(ct12, "moduli") and hasattr(ct9, "moduli"): - if ( - list(ct12.moduli)[:ct13_num_moduli] - != list(ct9.moduli)[:ct13_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct13_moduli_src = getattr(ct12, "moduli", getattr(ct9, "moduli", v0.q_towers)) - if isinstance(ct13_moduli_src, (int, np.integer)): - ct13_moduli_src = [ct13_moduli_src] - ct13_moduli = jnp.array( - list(ct13_moduli_src)[:ct13_num_moduli], dtype=jnp.uint64 - ) - ct13_sum = ct13_lhs.astype(jnp.uint64) + ct13_rhs.astype(jnp.uint64) - ct13 = jnp.where( - ct13_sum >= ct13_moduli, ct13_sum - ct13_moduli, ct13_sum - ).astype(jnp.uint32) - ct14_lhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - ct14_lhs = ct14_lhs.reshape( - ct14_lhs.shape[0], - ct14_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_lhs.shape[-1], - ) - ct14_rhs = ct14_rhs.reshape( - ct14_rhs.shape[0], - ct14_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_rhs.shape[-1], - ) - if ct14_lhs.shape != ct14_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct14_num_moduli = ct14_lhs.shape[-1] - if hasattr(ct11, "moduli") and hasattr(ct13, "moduli"): - if ( - list(ct11.moduli)[:ct14_num_moduli] - != list(ct13.moduli)[:ct14_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct14_moduli_src = getattr( - ct11, "moduli", getattr(ct13, "moduli", v0.q_towers) - ) - if isinstance(ct14_moduli_src, (int, np.integer)): - ct14_moduli_src = [ct14_moduli_src] - ct14_moduli = jnp.array( - list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 - ) - ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) - ct14 = jnp.where( - ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum - ).astype(jnp.uint32) - v10 = [None] * 1 - ct15_arg = _ensure_poly(v0, ct14, v0.max_level) - ct15 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct15_arg) - v10[0] = ct15 - v11 = v10 - return v11 - - -def matvec_identity( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4) = matvec_identity__preprocessing(v0, v1) - v5 = matvec_identity__preprocessed(v0, v1, v2, v3, v4) - return v5 - - -def matvec_identity__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - ct = _ensure_poly(v0, ct_raw) - v16 = [ct] - return v16 - - -def matvec_identity__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 7 - v8 = 0 - v9 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - pt_ct = _ensure_poly(v0, ct) - _num_moduli = pt_ct.polynomial.shape[-1] - _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": pt_ct.polynomial.shape[0], - "num_elements": pt_ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - pt_ct.polynomial.reshape( - pt_ct.polynomial.shape[0], - pt_ct.polynomial.shape[1], - v0.degree, - _num_moduli, - ) - ) - pt = v0.decrypt(_ct_for_dec) - v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v11 = v9.copy() - for v12 in range(0, 8): - v14 = v7 - v12 - v15 = int(v14) - v16 = v10[0, v15] - v11[v15] = v16 - return v11 - - -def matvec_shift__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> (np.ndarray, np.ndarray): - v2 = np.full((8,), 0.000000e00, dtype=np.float32) - v3 = np.full((8,), 1.000000e00, dtype=np.float32) - pt = v0.encode(v2) - pt1 = v0.encode(v3) - v4 = [pt] - v5 = [pt1] - return (v4, v5) - - -def matvec_shift__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, -) -> np.ndarray: - v5 = 1 - v6 = 2 - v7 = 3 - v8 = 6 - v9 = 0 - pt = v3[0] - pt1 = v4[0] - ct = v2[0] - ct1_arg = _ensure_poly(v0, ct, v0.max_level) - ct1_pt_ntt = ( - pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] - .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct1_ptct = v0.ptct_mul[v0.max_level] - ct1_ptct.set_plaintext(ct1_pt_ntt) - ct1 = ct1_ptct.mul(ct1_arg, use_bat=False) - ct2_arg = _ensure_poly(v0, ct, v0.max_level) - ct2 = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) - ct3_arg = _ensure_poly(v0, ct, v0.max_level) - ct3 = v0.he_rot[v0.max_level, 2].rotate(ct3_arg) - ct4_arg = _ensure_poly(v0, ct3, v0.max_level) - ct4_pt_ntt = ( - pt.polynomial[0, 0, :, : ct4_arg.polynomial.shape[-1]] - .reshape(ct4_arg.r, ct4_arg.c, ct4_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct4_ptct = v0.ptct_mul[v0.max_level] - ct4_ptct.set_plaintext(ct4_pt_ntt) - ct4 = ct4_ptct.mul(ct4_arg, use_bat=False) - ct5_arg = _ensure_poly(v0, ct2, v0.max_level) - ct5_pt_ntt = ( - pt.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) - ct6_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - ct6_rhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - ct6_lhs = ct6_lhs.reshape( - ct6_lhs.shape[0], - ct6_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct6_lhs.shape[-1], - ) - ct6_rhs = ct6_rhs.reshape( - ct6_rhs.shape[0], - ct6_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct6_rhs.shape[-1], - ) - if ct6_lhs.shape != ct6_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct6_num_moduli = ct6_lhs.shape[-1] - if hasattr(ct1, "moduli") and hasattr(ct5, "moduli"): - if list(ct1.moduli)[:ct6_num_moduli] != list(ct5.moduli)[:ct6_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct6_moduli_src = getattr(ct1, "moduli", getattr(ct5, "moduli", v0.q_towers)) - if isinstance(ct6_moduli_src, (int, np.integer)): - ct6_moduli_src = [ct6_moduli_src] - ct6_moduli = jnp.array( - list(ct6_moduli_src)[:ct6_num_moduli], dtype=jnp.uint64 - ) - ct6_sum = ct6_lhs.astype(jnp.uint64) + ct6_rhs.astype(jnp.uint64) - ct6 = jnp.where(ct6_sum >= ct6_moduli, ct6_sum - ct6_moduli, ct6_sum).astype( - jnp.uint32 - ) - ct7_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - ct7_rhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - ct7_lhs = ct7_lhs.reshape( - ct7_lhs.shape[0], - ct7_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct7_lhs.shape[-1], - ) - ct7_rhs = ct7_rhs.reshape( - ct7_rhs.shape[0], - ct7_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct7_rhs.shape[-1], - ) - if ct7_lhs.shape != ct7_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct7_num_moduli = ct7_lhs.shape[-1] - if hasattr(ct6, "moduli") and hasattr(ct4, "moduli"): - if list(ct6.moduli)[:ct7_num_moduli] != list(ct4.moduli)[:ct7_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct7_moduli_src = getattr(ct6, "moduli", getattr(ct4, "moduli", v0.q_towers)) - if isinstance(ct7_moduli_src, (int, np.integer)): - ct7_moduli_src = [ct7_moduli_src] - ct7_moduli = jnp.array( - list(ct7_moduli_src)[:ct7_num_moduli], dtype=jnp.uint64 - ) - ct7_sum = ct7_lhs.astype(jnp.uint64) + ct7_rhs.astype(jnp.uint64) - ct7 = jnp.where(ct7_sum >= ct7_moduli, ct7_sum - ct7_moduli, ct7_sum).astype( - jnp.uint32 - ) - ct8_arg = _ensure_poly(v0, ct7, v0.max_level) - ct8 = v0.he_rot[v0.max_level, 3].rotate(ct8_arg) - ct9_arg = _ensure_poly(v0, ct6, v0.max_level) - ct9 = v0.he_rot[v0.max_level, 6].rotate(ct9_arg) - ct10_arg = _ensure_poly(v0, ct2, v0.max_level) - ct10_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct10_arg.polynomial.shape[-1]] - .reshape(ct10_arg.r, ct10_arg.c, ct10_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct10_ptct = v0.ptct_mul[v0.max_level] - ct10_ptct.set_plaintext(ct10_pt_ntt) - ct10 = ct10_ptct.mul(ct10_arg, use_bat=False) - ct11_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - ct11_rhs = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 - ct11_lhs = ct11_lhs.reshape( - ct11_lhs.shape[0], - ct11_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct11_lhs.shape[-1], - ) - ct11_rhs = ct11_rhs.reshape( - ct11_rhs.shape[0], - ct11_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct11_rhs.shape[-1], - ) - if ct11_lhs.shape != ct11_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct11_num_moduli = ct11_lhs.shape[-1] - if hasattr(ct1, "moduli") and hasattr(ct10, "moduli"): - if ( - list(ct1.moduli)[:ct11_num_moduli] - != list(ct10.moduli)[:ct11_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct11_moduli_src = getattr(ct1, "moduli", getattr(ct10, "moduli", v0.q_towers)) - if isinstance(ct11_moduli_src, (int, np.integer)): - ct11_moduli_src = [ct11_moduli_src] - ct11_moduli = jnp.array( - list(ct11_moduli_src)[:ct11_num_moduli], dtype=jnp.uint64 - ) - ct11_sum = ct11_lhs.astype(jnp.uint64) + ct11_rhs.astype(jnp.uint64) - ct11 = jnp.where( - ct11_sum >= ct11_moduli, ct11_sum - ct11_moduli, ct11_sum - ).astype(jnp.uint32) - ct12_lhs = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - ct12_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - ct12_lhs = ct12_lhs.reshape( - ct12_lhs.shape[0], - ct12_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct12_lhs.shape[-1], - ) - ct12_rhs = ct12_rhs.reshape( - ct12_rhs.shape[0], - ct12_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct12_rhs.shape[-1], - ) - if ct12_lhs.shape != ct12_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct12_num_moduli = ct12_lhs.shape[-1] - if hasattr(ct4, "moduli") and hasattr(ct8, "moduli"): - if list(ct4.moduli)[:ct12_num_moduli] != list(ct8.moduli)[:ct12_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct12_moduli_src = getattr(ct4, "moduli", getattr(ct8, "moduli", v0.q_towers)) - if isinstance(ct12_moduli_src, (int, np.integer)): - ct12_moduli_src = [ct12_moduli_src] - ct12_moduli = jnp.array( - list(ct12_moduli_src)[:ct12_num_moduli], dtype=jnp.uint64 - ) - ct12_sum = ct12_lhs.astype(jnp.uint64) + ct12_rhs.astype(jnp.uint64) - ct12 = jnp.where( - ct12_sum >= ct12_moduli, ct12_sum - ct12_moduli, ct12_sum - ).astype(jnp.uint32) - ct13_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - ct13_rhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - ct13_lhs = ct13_lhs.reshape( - ct13_lhs.shape[0], - ct13_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct13_lhs.shape[-1], - ) - ct13_rhs = ct13_rhs.reshape( - ct13_rhs.shape[0], - ct13_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct13_rhs.shape[-1], - ) - if ct13_lhs.shape != ct13_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct13_num_moduli = ct13_lhs.shape[-1] - if hasattr(ct12, "moduli") and hasattr(ct9, "moduli"): - if ( - list(ct12.moduli)[:ct13_num_moduli] - != list(ct9.moduli)[:ct13_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct13_moduli_src = getattr(ct12, "moduli", getattr(ct9, "moduli", v0.q_towers)) - if isinstance(ct13_moduli_src, (int, np.integer)): - ct13_moduli_src = [ct13_moduli_src] - ct13_moduli = jnp.array( - list(ct13_moduli_src)[:ct13_num_moduli], dtype=jnp.uint64 - ) - ct13_sum = ct13_lhs.astype(jnp.uint64) + ct13_rhs.astype(jnp.uint64) - ct13 = jnp.where( - ct13_sum >= ct13_moduli, ct13_sum - ct13_moduli, ct13_sum - ).astype(jnp.uint32) - ct14_lhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - ct14_lhs = ct14_lhs.reshape( - ct14_lhs.shape[0], - ct14_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_lhs.shape[-1], - ) - ct14_rhs = ct14_rhs.reshape( - ct14_rhs.shape[0], - ct14_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_rhs.shape[-1], - ) - if ct14_lhs.shape != ct14_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct14_num_moduli = ct14_lhs.shape[-1] - if hasattr(ct11, "moduli") and hasattr(ct13, "moduli"): - if ( - list(ct11.moduli)[:ct14_num_moduli] - != list(ct13.moduli)[:ct14_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct14_moduli_src = getattr( - ct11, "moduli", getattr(ct13, "moduli", v0.q_towers) - ) - if isinstance(ct14_moduli_src, (int, np.integer)): - ct14_moduli_src = [ct14_moduli_src] - ct14_moduli = jnp.array( - list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 - ) - ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) - ct14 = jnp.where( - ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum - ).astype(jnp.uint32) - v10 = [None] * 1 - ct15_arg = _ensure_poly(v0, ct14, v0.max_level) - ct15 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct15_arg) - v10[0] = ct15 - v11 = v10 - return v11 - - -def matvec_shift( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4) = matvec_shift__preprocessing(v0, v1) - v5 = matvec_shift__preprocessed(v0, v1, v2, v3, v4) - return v5 - - -def matvec_shift__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - ct = _ensure_poly(v0, ct_raw) - v16 = [ct] - return v16 - - -def matvec_shift__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 7 - v8 = 0 - v9 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - pt_ct = _ensure_poly(v0, ct) - _num_moduli = pt_ct.polynomial.shape[-1] - _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": pt_ct.polynomial.shape[0], - "num_elements": pt_ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - pt_ct.polynomial.reshape( - pt_ct.polynomial.shape[0], - pt_ct.polynomial.shape[1], - v0.degree, - _num_moduli, - ) - ) - pt = v0.decrypt(_ct_for_dec) - v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v11 = v9.copy() - for v12 in range(0, 8): - v14 = v7 - v12 - v15 = int(v14) - v16 = v10[0, v15] - v11[v15] = v16 - return v11 - - -def matvec_random__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> ( - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -): - v2 = np.array( - [ - 8.116263e-01, - 1.445338e00, - 9.206955e-01, - 1.077045e00, - 6.787661e-01, - 1.358792e00, - 1.236010e00, - 7.778313e-01, - ], - dtype=np.float32, - ) - v3 = np.array( - [ - 1.906357e00, - 1.391105e-01, - 6.533354e-01, - 1.225588e00, - 2.855770e-01, - 6.922510e-01, - 1.851561e00, - 2.681358e-01, - ], - dtype=np.float32, - ) - v4 = np.array( - [ - 1.490788e00, - 1.942829e00, - 1.262521e00, - 1.882558e-01, - 1.400043e00, - 1.088129e00, - 1.138749e00, - 4.723674e-01, - ], - dtype=np.float32, - ) - v5 = np.array( - [ - 3.318726e-01, - 4.512235e-01, - 1.859318e-01, - 1.237451e00, - 1.681641e00, - 3.650383e-01, - 1.254335e00, - 9.362897e-01, - ], - dtype=np.float32, - ) - v6 = np.array( - [ - 1.040836e00, - 1.942211e00, - 7.181276e-01, - 3.964354e-01, - 5.034443e-01, - 6.550748e-01, - 4.239958e-01, - 2.235980e-01, - ], - dtype=np.float32, - ) - v7 = np.array( - [ - 1.653382e-01, - 1.572752e00, - 8.384869e-01, - 3.963896e-01, - 4.454674e-01, - 7.960875e-01, - 9.665329e-01, - 1.902883e00, - ], - dtype=np.float32, - ) - v8 = np.array( - [ - 6.780602e-01, - 1.591834e00, - 1.934701e00, - 1.827709e00, - 1.885048e00, - 6.155632e-01, - 2.103589e-01, - 4.484686e-01, - ], - dtype=np.float32, - ) - v9 = np.array( - [ - 1.097037e00, - 4.793802e-01, - 1.635955e00, - 5.916820e-01, - 1.800172e00, - 1.674601e00, - 1.745735e00, - 1.242118e00, - ], - dtype=np.float32, - ) - pt = v0.encode(v2) - pt1 = v0.encode(v3) - pt2 = v0.encode(v4) - pt3 = v0.encode(v5) - pt4 = v0.encode(v6) - pt5 = v0.encode(v7) - pt6 = v0.encode(v8) - pt7 = v0.encode(v9) - v10 = [pt] - v11 = [pt1] - v12 = [pt2] - v13 = [pt3] - v14 = [pt4] - v15 = [pt5] - v16 = [pt6] - v17 = [pt7] - return (v10, v11, v12, v13, v14, v15, v16, v17) - - -def matvec_random__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, - v5: np.ndarray, - v6: np.ndarray, - v7: np.ndarray, - v8: np.ndarray, - v9: np.ndarray, - v10: np.ndarray, -) -> np.ndarray: - v11 = 1 - v12 = 2 - v13 = 3 - v14 = 6 - v15 = 0 - pt = v3[0] - pt1 = v4[0] - pt2 = v5[0] - pt3 = v6[0] - pt4 = v7[0] - pt5 = v8[0] - pt6 = v9[0] - pt7 = v10[0] - ct = v2[0] - ct1_arg = _ensure_poly(v0, ct, v0.max_level) - ct1_pt_ntt = ( - pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] - .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct1_ptct = v0.ptct_mul[v0.max_level] - ct1_ptct.set_plaintext(ct1_pt_ntt) - ct1 = ct1_ptct.mul(ct1_arg, use_bat=False) - ct2_arg = _ensure_poly(v0, ct, v0.max_level) - ct2 = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) - ct3_arg = _ensure_poly(v0, ct2, v0.max_level) - ct3_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] - .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct3_ptct = v0.ptct_mul[v0.max_level] - ct3_ptct.set_plaintext(ct3_pt_ntt) - ct3 = ct3_ptct.mul(ct3_arg, use_bat=False) - ct4_arg = _ensure_poly(v0, ct, v0.max_level) - ct4 = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) - ct5_arg = _ensure_poly(v0, ct4, v0.max_level) - ct5_pt_ntt = ( - pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) - ct6_arg = _ensure_poly(v0, ct, v0.max_level) - ct6_pt_ntt = ( - pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] - .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct6_ptct = v0.ptct_mul[v0.max_level] - ct6_ptct.set_plaintext(ct6_pt_ntt) - ct6 = ct6_ptct.mul(ct6_arg, use_bat=False) - ct7_arg = _ensure_poly(v0, ct2, v0.max_level) - ct7_pt_ntt = ( - pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] - .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct7_ptct = v0.ptct_mul[v0.max_level] - ct7_ptct.set_plaintext(ct7_pt_ntt) - ct7 = ct7_ptct.mul(ct7_arg, use_bat=False) - ct8_arg = _ensure_poly(v0, ct4, v0.max_level) - ct8_pt_ntt = ( - pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] - .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct8_ptct = v0.ptct_mul[v0.max_level] - ct8_ptct.set_plaintext(ct8_pt_ntt) - ct8 = ct8_ptct.mul(ct8_arg, use_bat=False) - ct9_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - ct9_rhs = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 - ct9_lhs = ct9_lhs.reshape( - ct9_lhs.shape[0], - ct9_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct9_lhs.shape[-1], - ) - ct9_rhs = ct9_rhs.reshape( - ct9_rhs.shape[0], - ct9_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct9_rhs.shape[-1], - ) - if ct9_lhs.shape != ct9_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct9_num_moduli = ct9_lhs.shape[-1] - if hasattr(ct6, "moduli") and hasattr(ct7, "moduli"): - if list(ct6.moduli)[:ct9_num_moduli] != list(ct7.moduli)[:ct9_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct9_moduli_src = getattr(ct6, "moduli", getattr(ct7, "moduli", v0.q_towers)) - if isinstance(ct9_moduli_src, (int, np.integer)): - ct9_moduli_src = [ct9_moduli_src] - ct9_moduli = jnp.array( - list(ct9_moduli_src)[:ct9_num_moduli], dtype=jnp.uint64 - ) - ct9_sum = ct9_lhs.astype(jnp.uint64) + ct9_rhs.astype(jnp.uint64) - ct9 = jnp.where(ct9_sum >= ct9_moduli, ct9_sum - ct9_moduli, ct9_sum).astype( - jnp.uint32 - ) - ct10_lhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - ct10_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - ct10_lhs = ct10_lhs.reshape( - ct10_lhs.shape[0], - ct10_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct10_lhs.shape[-1], - ) - ct10_rhs = ct10_rhs.reshape( - ct10_rhs.shape[0], - ct10_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct10_rhs.shape[-1], - ) - if ct10_lhs.shape != ct10_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct10_num_moduli = ct10_lhs.shape[-1] - if hasattr(ct9, "moduli") and hasattr(ct8, "moduli"): - if list(ct9.moduli)[:ct10_num_moduli] != list(ct8.moduli)[:ct10_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct10_moduli_src = getattr(ct9, "moduli", getattr(ct8, "moduli", v0.q_towers)) - if isinstance(ct10_moduli_src, (int, np.integer)): - ct10_moduli_src = [ct10_moduli_src] - ct10_moduli = jnp.array( - list(ct10_moduli_src)[:ct10_num_moduli], dtype=jnp.uint64 - ) - ct10_sum = ct10_lhs.astype(jnp.uint64) + ct10_rhs.astype(jnp.uint64) - ct10 = jnp.where( - ct10_sum >= ct10_moduli, ct10_sum - ct10_moduli, ct10_sum - ).astype(jnp.uint32) - ct11_arg = _ensure_poly(v0, ct10, v0.max_level) - ct11 = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) - ct12_arg = _ensure_poly(v0, ct, v0.max_level) - ct12_pt_ntt = ( - pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] - .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct12_ptct = v0.ptct_mul[v0.max_level] - ct12_ptct.set_plaintext(ct12_pt_ntt) - ct12 = ct12_ptct.mul(ct12_arg, use_bat=False) - ct13_arg = _ensure_poly(v0, ct2, v0.max_level) - ct13_pt_ntt = ( - pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] - .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct13_ptct = v0.ptct_mul[v0.max_level] - ct13_ptct.set_plaintext(ct13_pt_ntt) - ct13 = ct13_ptct.mul(ct13_arg, use_bat=False) - ct14_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - ct14_lhs = ct14_lhs.reshape( - ct14_lhs.shape[0], - ct14_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_lhs.shape[-1], - ) - ct14_rhs = ct14_rhs.reshape( - ct14_rhs.shape[0], - ct14_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_rhs.shape[-1], - ) - if ct14_lhs.shape != ct14_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct14_num_moduli = ct14_lhs.shape[-1] - if hasattr(ct12, "moduli") and hasattr(ct13, "moduli"): - if ( - list(ct12.moduli)[:ct14_num_moduli] - != list(ct13.moduli)[:ct14_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct14_moduli_src = getattr( - ct12, "moduli", getattr(ct13, "moduli", v0.q_towers) - ) - if isinstance(ct14_moduli_src, (int, np.integer)): - ct14_moduli_src = [ct14_moduli_src] - ct14_moduli = jnp.array( - list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 - ) - ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) - ct14 = jnp.where( - ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum - ).astype(jnp.uint32) - ct15_arg = _ensure_poly(v0, ct14, v0.max_level) - ct15 = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) - ct16_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - ct16_rhs = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 - ct16_lhs = ct16_lhs.reshape( - ct16_lhs.shape[0], - ct16_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct16_lhs.shape[-1], - ) - ct16_rhs = ct16_rhs.reshape( - ct16_rhs.shape[0], - ct16_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct16_rhs.shape[-1], - ) - if ct16_lhs.shape != ct16_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct16_num_moduli = ct16_lhs.shape[-1] - if hasattr(ct1, "moduli") and hasattr(ct3, "moduli"): - if list(ct1.moduli)[:ct16_num_moduli] != list(ct3.moduli)[:ct16_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct16_moduli_src = getattr(ct1, "moduli", getattr(ct3, "moduli", v0.q_towers)) - if isinstance(ct16_moduli_src, (int, np.integer)): - ct16_moduli_src = [ct16_moduli_src] - ct16_moduli = jnp.array( - list(ct16_moduli_src)[:ct16_num_moduli], dtype=jnp.uint64 - ) - ct16_sum = ct16_lhs.astype(jnp.uint64) + ct16_rhs.astype(jnp.uint64) - ct16 = jnp.where( - ct16_sum >= ct16_moduli, ct16_sum - ct16_moduli, ct16_sum - ).astype(jnp.uint32) - ct17_lhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - ct17_rhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - ct17_lhs = ct17_lhs.reshape( - ct17_lhs.shape[0], - ct17_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct17_lhs.shape[-1], - ) - ct17_rhs = ct17_rhs.reshape( - ct17_rhs.shape[0], - ct17_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct17_rhs.shape[-1], - ) - if ct17_lhs.shape != ct17_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct17_num_moduli = ct17_lhs.shape[-1] - if hasattr(ct5, "moduli") and hasattr(ct11, "moduli"): - if ( - list(ct5.moduli)[:ct17_num_moduli] - != list(ct11.moduli)[:ct17_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct17_moduli_src = getattr(ct5, "moduli", getattr(ct11, "moduli", v0.q_towers)) - if isinstance(ct17_moduli_src, (int, np.integer)): - ct17_moduli_src = [ct17_moduli_src] - ct17_moduli = jnp.array( - list(ct17_moduli_src)[:ct17_num_moduli], dtype=jnp.uint64 - ) - ct17_sum = ct17_lhs.astype(jnp.uint64) + ct17_rhs.astype(jnp.uint64) - ct17 = jnp.where( - ct17_sum >= ct17_moduli, ct17_sum - ct17_moduli, ct17_sum - ).astype(jnp.uint32) - ct18_lhs = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 - ct18_rhs = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 - ct18_lhs = ct18_lhs.reshape( - ct18_lhs.shape[0], - ct18_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct18_lhs.shape[-1], - ) - ct18_rhs = ct18_rhs.reshape( - ct18_rhs.shape[0], - ct18_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct18_rhs.shape[-1], - ) - if ct18_lhs.shape != ct18_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct18_num_moduli = ct18_lhs.shape[-1] - if hasattr(ct17, "moduli") and hasattr(ct15, "moduli"): - if ( - list(ct17.moduli)[:ct18_num_moduli] - != list(ct15.moduli)[:ct18_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct18_moduli_src = getattr( - ct17, "moduli", getattr(ct15, "moduli", v0.q_towers) - ) - if isinstance(ct18_moduli_src, (int, np.integer)): - ct18_moduli_src = [ct18_moduli_src] - ct18_moduli = jnp.array( - list(ct18_moduli_src)[:ct18_num_moduli], dtype=jnp.uint64 - ) - ct18_sum = ct18_lhs.astype(jnp.uint64) + ct18_rhs.astype(jnp.uint64) - ct18 = jnp.where( - ct18_sum >= ct18_moduli, ct18_sum - ct18_moduli, ct18_sum - ).astype(jnp.uint32) - ct19_lhs = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 - ct19_rhs = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 - ct19_lhs = ct19_lhs.reshape( - ct19_lhs.shape[0], - ct19_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct19_lhs.shape[-1], - ) - ct19_rhs = ct19_rhs.reshape( - ct19_rhs.shape[0], - ct19_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct19_rhs.shape[-1], - ) - if ct19_lhs.shape != ct19_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct19_num_moduli = ct19_lhs.shape[-1] - if hasattr(ct16, "moduli") and hasattr(ct18, "moduli"): - if ( - list(ct16.moduli)[:ct19_num_moduli] - != list(ct18.moduli)[:ct19_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct19_moduli_src = getattr( - ct16, "moduli", getattr(ct18, "moduli", v0.q_towers) - ) - if isinstance(ct19_moduli_src, (int, np.integer)): - ct19_moduli_src = [ct19_moduli_src] - ct19_moduli = jnp.array( - list(ct19_moduli_src)[:ct19_num_moduli], dtype=jnp.uint64 - ) - ct19_sum = ct19_lhs.astype(jnp.uint64) + ct19_rhs.astype(jnp.uint64) - ct19 = jnp.where( - ct19_sum >= ct19_moduli, ct19_sum - ct19_moduli, ct19_sum - ).astype(jnp.uint32) - v16 = [None] * 1 - ct20_arg = _ensure_poly(v0, ct19, v0.max_level) - ct20 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) - v16[0] = ct20 - v17 = v16 - return v17 - - -def matvec_random( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_random__preprocessing(v0, v1) - v11 = matvec_random__preprocessed(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) - return v11 - - -def matvec_random__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - ct = _ensure_poly(v0, ct_raw) - v16 = [ct] - return v16 - - -def matvec_random__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 7 - v8 = 0 - v9 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - pt_ct = _ensure_poly(v0, ct) - _num_moduli = pt_ct.polynomial.shape[-1] - _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": pt_ct.polynomial.shape[0], - "num_elements": pt_ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - pt_ct.polynomial.reshape( - pt_ct.polynomial.shape[0], - pt_ct.polynomial.shape[1], - v0.degree, - _num_moduli, - ) - ) - pt = v0.decrypt(_ct_for_dec) - v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v11 = v9.copy() - for v12 in range(0, 8): - v14 = v7 - v12 - v15 = int(v14) - v16 = v10[0, v15] - v11[v15] = v16 - return v11 - - -def matvec_chain__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> ( - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -): - v2 = np.array( - [ - 1.340000e00, - 1.220000e00, - 1.050000e00, - 1.500000e00, - 1.010000e00, - 8.800000e-01, - 5.000000e-01, - 1.060000e00, - ], - dtype=np.float32, - ) - v3 = np.array( - [ - 5.800000e-01, - 5.200000e-01, - 8.900000e-01, - 8.600000e-01, - 1.170000e00, - 8.200000e-01, - 1.490000e00, - 1.410000e00, - ], - dtype=np.float32, - ) - v4 = np.array( - [ - 1.260000e00, - 1.090000e00, - 1.430000e00, - 1.260000e00, - 6.100000e-01, - 8.500000e-01, - 6.700000e-01, - 7.100000e-01, - ], - dtype=np.float32, - ) - v5 = np.array( - [ - 8.200000e-01, - 1.330000e00, - 7.900000e-01, - 7.400000e-01, - 1.060000e00, - 1.340000e00, - 1.090000e00, - 6.300000e-01, - ], - dtype=np.float32, - ) - v6 = np.array( - [ - 1.160000e00, - 8.400000e-01, - 1.020000e00, - 6.900000e-01, - 6.600000e-01, - 8.600000e-01, - 1.190000e00, - 6.500000e-01, - ], - dtype=np.float32, - ) - v7 = np.array( - [ - 1.350000e00, - 1.050000e00, - 1.400000e00, - 1.070000e00, - 6.500000e-01, - 5.400000e-01, - 8.000000e-01, - 9.000000e-01, - ], - dtype=np.float32, - ) - v8 = np.array( - [ - 8.200000e-01, - 9.000000e-01, - 7.400000e-01, - 1.050000e00, - 1.080000e00, - 1.480000e00, - 6.000000e-01, - 1.200000e00, - ], - dtype=np.float32, - ) - v9 = np.array( - [ - 1.190000e00, - 1.200000e00, - 8.400000e-01, - 1.350000e00, - 1.020000e00, - 7.600000e-01, - 1.390000e00, - 1.130000e00, - ], - dtype=np.float32, - ) - v10 = np.array( - [ - 1.200000e00, - 8.900000e-01, - 1.030000e00, - 7.300000e-01, - 9.300000e-01, - 7.500000e-01, - 8.400000e-01, - 1.170000e00, - ], - dtype=np.float32, - ) - v11 = np.array( - [ - 7.900000e-01, - 8.400000e-01, - 1.030000e00, - 7.900000e-01, - 1.390000e00, - 9.800000e-01, - 8.000000e-01, - 9.200000e-01, - ], - dtype=np.float32, - ) - v12 = np.array( - [ - 7.300000e-01, - 1.230000e00, - 1.130000e00, - 1.130000e00, - 1.440000e00, - 1.490000e00, - 1.020000e00, - 1.180000e00, - ], - dtype=np.float32, - ) - v13 = np.array( - [ - 1.120000e00, - 1.110000e00, - 1.380000e00, - 1.050000e00, - 9.400000e-01, - 1.350000e00, - 5.900000e-01, - 1.000000e00, - ], - dtype=np.float32, - ) - v14 = np.array( - [ - 6.200000e-01, - 6.200000e-01, - 1.010000e00, - 1.220000e00, - 5.600000e-01, - 1.220000e00, - 9.300000e-01, - 9.300000e-01, - ], - dtype=np.float32, - ) - v15 = np.array( - [ - 8.200000e-01, - 1.330000e00, - 1.170000e00, - 9.200000e-01, - 9.000000e-01, - 1.110000e00, - 1.220000e00, - 9.900000e-01, - ], - dtype=np.float32, - ) - v16 = np.array( - [ - 6.800000e-01, - 8.200000e-01, - 9.300000e-01, - 9.100000e-01, - 1.100000e00, - 1.090000e00, - 1.480000e00, - 1.240000e00, - ], - dtype=np.float32, - ) - v17 = np.array( - [ - 6.800000e-01, - 8.600000e-01, - 8.100000e-01, - 1.370000e00, - 1.050000e00, - 1.120000e00, - 1.180000e00, - 9.800000e-01, - ], - dtype=np.float32, - ) - pt = v0.encode(v2) - pt1 = v0.encode(v3) - pt2 = v0.encode(v4) - pt3 = v0.encode(v5) - pt4 = v0.encode(v6) - pt5 = v0.encode(v7) - pt6 = v0.encode(v8) - pt7 = v0.encode(v9) - pt8 = v0.encode(v10) - pt9 = v0.encode(v11) - pt10 = v0.encode(v12) - pt11 = v0.encode(v13) - pt12 = v0.encode(v14) - pt13 = v0.encode(v15) - pt14 = v0.encode(v16) - pt15 = v0.encode(v17) - v18 = [pt] - v19 = [pt1] - v20 = [pt2] - v21 = [pt3] - v22 = [pt4] - v23 = [pt5] - v24 = [pt6] - v25 = [pt7] - v26 = [pt8, pt9] - v27 = [pt10, pt11] - v28 = [pt12, pt13] - v29 = [pt14, pt15] - return (v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) - - -def matvec_chain__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, - v5: np.ndarray, - v6: np.ndarray, - v7: np.ndarray, - v8: np.ndarray, - v9: np.ndarray, - v10: np.ndarray, - v11: np.ndarray, - v12: np.ndarray, - v13: np.ndarray, - v14: np.ndarray, -) -> np.ndarray: - v15 = 1 - v16 = 2 - v17 = 3 - v18 = 6 - v19 = 0 - pt = v3[0] - pt1 = v4[0] - pt2 = v5[0] - pt3 = v6[0] - pt4 = v7[0] - pt5 = v8[0] - pt6 = v9[0] - pt7 = v10[0] - pt8 = v11[0] - pt9 = v11[1] - pt10 = v12[0] - pt11 = v12[1] - pt12 = v13[0] - pt13 = v13[1] - pt14 = v14[0] - pt15 = v14[1] - ct = v2[0] - ct1_arg = _ensure_poly(v0, ct, v0.max_level) - ct1_pt_ntt = ( - pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] - .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct1_ptct = v0.ptct_mul[v0.max_level] - ct1_ptct.set_plaintext(ct1_pt_ntt) - ct1 = ct1_ptct.mul(ct1_arg, use_bat=False) - ct2_arg = _ensure_poly(v0, ct, v0.max_level) - ct2 = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) - ct3_arg = _ensure_poly(v0, ct2, v0.max_level) - ct3_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] - .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct3_ptct = v0.ptct_mul[v0.max_level] - ct3_ptct.set_plaintext(ct3_pt_ntt) - ct3 = ct3_ptct.mul(ct3_arg, use_bat=False) - ct4_arg = _ensure_poly(v0, ct, v0.max_level) - ct4 = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) - ct5_arg = _ensure_poly(v0, ct4, v0.max_level) - ct5_pt_ntt = ( - pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5 = ct5_ptct.mul(ct5_arg, use_bat=False) - ct6_arg = _ensure_poly(v0, ct, v0.max_level) - ct6_pt_ntt = ( - pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] - .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct6_ptct = v0.ptct_mul[v0.max_level] - ct6_ptct.set_plaintext(ct6_pt_ntt) - ct6 = ct6_ptct.mul(ct6_arg, use_bat=False) - ct7_arg = _ensure_poly(v0, ct2, v0.max_level) - ct7_pt_ntt = ( - pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] - .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct7_ptct = v0.ptct_mul[v0.max_level] - ct7_ptct.set_plaintext(ct7_pt_ntt) - ct7 = ct7_ptct.mul(ct7_arg, use_bat=False) - ct8_arg = _ensure_poly(v0, ct4, v0.max_level) - ct8_pt_ntt = ( - pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] - .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct8_ptct = v0.ptct_mul[v0.max_level] - ct8_ptct.set_plaintext(ct8_pt_ntt) - ct8 = ct8_ptct.mul(ct8_arg, use_bat=False) - ct9_lhs = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - ct9_rhs = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 - ct9_lhs = ct9_lhs.reshape( - ct9_lhs.shape[0], - ct9_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct9_lhs.shape[-1], - ) - ct9_rhs = ct9_rhs.reshape( - ct9_rhs.shape[0], - ct9_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct9_rhs.shape[-1], - ) - if ct9_lhs.shape != ct9_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct9_num_moduli = ct9_lhs.shape[-1] - if hasattr(ct6, "moduli") and hasattr(ct7, "moduli"): - if list(ct6.moduli)[:ct9_num_moduli] != list(ct7.moduli)[:ct9_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct9_moduli_src = getattr(ct6, "moduli", getattr(ct7, "moduli", v0.q_towers)) - if isinstance(ct9_moduli_src, (int, np.integer)): - ct9_moduli_src = [ct9_moduli_src] - ct9_moduli = jnp.array( - list(ct9_moduli_src)[:ct9_num_moduli], dtype=jnp.uint64 - ) - ct9_sum = ct9_lhs.astype(jnp.uint64) + ct9_rhs.astype(jnp.uint64) - ct9 = jnp.where(ct9_sum >= ct9_moduli, ct9_sum - ct9_moduli, ct9_sum).astype( - jnp.uint32 - ) - ct10_lhs = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - ct10_rhs = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - ct10_lhs = ct10_lhs.reshape( - ct10_lhs.shape[0], - ct10_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct10_lhs.shape[-1], - ) - ct10_rhs = ct10_rhs.reshape( - ct10_rhs.shape[0], - ct10_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct10_rhs.shape[-1], - ) - if ct10_lhs.shape != ct10_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct10_num_moduli = ct10_lhs.shape[-1] - if hasattr(ct9, "moduli") and hasattr(ct8, "moduli"): - if list(ct9.moduli)[:ct10_num_moduli] != list(ct8.moduli)[:ct10_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct10_moduli_src = getattr(ct9, "moduli", getattr(ct8, "moduli", v0.q_towers)) - if isinstance(ct10_moduli_src, (int, np.integer)): - ct10_moduli_src = [ct10_moduli_src] - ct10_moduli = jnp.array( - list(ct10_moduli_src)[:ct10_num_moduli], dtype=jnp.uint64 - ) - ct10_sum = ct10_lhs.astype(jnp.uint64) + ct10_rhs.astype(jnp.uint64) - ct10 = jnp.where( - ct10_sum >= ct10_moduli, ct10_sum - ct10_moduli, ct10_sum - ).astype(jnp.uint32) - ct11_arg = _ensure_poly(v0, ct10, v0.max_level) - ct11 = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) - ct12_arg = _ensure_poly(v0, ct, v0.max_level) - ct12_pt_ntt = ( - pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] - .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct12_ptct = v0.ptct_mul[v0.max_level] - ct12_ptct.set_plaintext(ct12_pt_ntt) - ct12 = ct12_ptct.mul(ct12_arg, use_bat=False) - ct13_arg = _ensure_poly(v0, ct2, v0.max_level) - ct13_pt_ntt = ( - pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] - .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct13_ptct = v0.ptct_mul[v0.max_level] - ct13_ptct.set_plaintext(ct13_pt_ntt) - ct13 = ct13_ptct.mul(ct13_arg, use_bat=False) - ct14_lhs = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - ct14_rhs = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - ct14_lhs = ct14_lhs.reshape( - ct14_lhs.shape[0], - ct14_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_lhs.shape[-1], - ) - ct14_rhs = ct14_rhs.reshape( - ct14_rhs.shape[0], - ct14_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct14_rhs.shape[-1], - ) - if ct14_lhs.shape != ct14_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct14_num_moduli = ct14_lhs.shape[-1] - if hasattr(ct12, "moduli") and hasattr(ct13, "moduli"): - if ( - list(ct12.moduli)[:ct14_num_moduli] - != list(ct13.moduli)[:ct14_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct14_moduli_src = getattr( - ct12, "moduli", getattr(ct13, "moduli", v0.q_towers) - ) - if isinstance(ct14_moduli_src, (int, np.integer)): - ct14_moduli_src = [ct14_moduli_src] - ct14_moduli = jnp.array( - list(ct14_moduli_src)[:ct14_num_moduli], dtype=jnp.uint64 - ) - ct14_sum = ct14_lhs.astype(jnp.uint64) + ct14_rhs.astype(jnp.uint64) - ct14 = jnp.where( - ct14_sum >= ct14_moduli, ct14_sum - ct14_moduli, ct14_sum - ).astype(jnp.uint32) - ct15_arg = _ensure_poly(v0, ct14, v0.max_level) - ct15 = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) - ct16_lhs = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - ct16_rhs = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 - ct16_lhs = ct16_lhs.reshape( - ct16_lhs.shape[0], - ct16_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct16_lhs.shape[-1], - ) - ct16_rhs = ct16_rhs.reshape( - ct16_rhs.shape[0], - ct16_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct16_rhs.shape[-1], - ) - if ct16_lhs.shape != ct16_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct16_num_moduli = ct16_lhs.shape[-1] - if hasattr(ct1, "moduli") and hasattr(ct3, "moduli"): - if list(ct1.moduli)[:ct16_num_moduli] != list(ct3.moduli)[:ct16_num_moduli]: - raise ValueError("ciphertext add modulus mismatch") - ct16_moduli_src = getattr(ct1, "moduli", getattr(ct3, "moduli", v0.q_towers)) - if isinstance(ct16_moduli_src, (int, np.integer)): - ct16_moduli_src = [ct16_moduli_src] - ct16_moduli = jnp.array( - list(ct16_moduli_src)[:ct16_num_moduli], dtype=jnp.uint64 - ) - ct16_sum = ct16_lhs.astype(jnp.uint64) + ct16_rhs.astype(jnp.uint64) - ct16 = jnp.where( - ct16_sum >= ct16_moduli, ct16_sum - ct16_moduli, ct16_sum - ).astype(jnp.uint32) - ct17_lhs = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - ct17_rhs = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - ct17_lhs = ct17_lhs.reshape( - ct17_lhs.shape[0], - ct17_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct17_lhs.shape[-1], - ) - ct17_rhs = ct17_rhs.reshape( - ct17_rhs.shape[0], - ct17_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct17_rhs.shape[-1], - ) - if ct17_lhs.shape != ct17_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct17_num_moduli = ct17_lhs.shape[-1] - if hasattr(ct5, "moduli") and hasattr(ct11, "moduli"): - if ( - list(ct5.moduli)[:ct17_num_moduli] - != list(ct11.moduli)[:ct17_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct17_moduli_src = getattr(ct5, "moduli", getattr(ct11, "moduli", v0.q_towers)) - if isinstance(ct17_moduli_src, (int, np.integer)): - ct17_moduli_src = [ct17_moduli_src] - ct17_moduli = jnp.array( - list(ct17_moduli_src)[:ct17_num_moduli], dtype=jnp.uint64 - ) - ct17_sum = ct17_lhs.astype(jnp.uint64) + ct17_rhs.astype(jnp.uint64) - ct17 = jnp.where( - ct17_sum >= ct17_moduli, ct17_sum - ct17_moduli, ct17_sum - ).astype(jnp.uint32) - ct18_lhs = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 - ct18_rhs = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 - ct18_lhs = ct18_lhs.reshape( - ct18_lhs.shape[0], - ct18_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct18_lhs.shape[-1], - ) - ct18_rhs = ct18_rhs.reshape( - ct18_rhs.shape[0], - ct18_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct18_rhs.shape[-1], - ) - if ct18_lhs.shape != ct18_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct18_num_moduli = ct18_lhs.shape[-1] - if hasattr(ct17, "moduli") and hasattr(ct15, "moduli"): - if ( - list(ct17.moduli)[:ct18_num_moduli] - != list(ct15.moduli)[:ct18_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct18_moduli_src = getattr( - ct17, "moduli", getattr(ct15, "moduli", v0.q_towers) - ) - if isinstance(ct18_moduli_src, (int, np.integer)): - ct18_moduli_src = [ct18_moduli_src] - ct18_moduli = jnp.array( - list(ct18_moduli_src)[:ct18_num_moduli], dtype=jnp.uint64 - ) - ct18_sum = ct18_lhs.astype(jnp.uint64) + ct18_rhs.astype(jnp.uint64) - ct18 = jnp.where( - ct18_sum >= ct18_moduli, ct18_sum - ct18_moduli, ct18_sum - ).astype(jnp.uint32) - ct19_lhs = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 - ct19_rhs = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 - ct19_lhs = ct19_lhs.reshape( - ct19_lhs.shape[0], - ct19_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct19_lhs.shape[-1], - ) - ct19_rhs = ct19_rhs.reshape( - ct19_rhs.shape[0], - ct19_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct19_rhs.shape[-1], - ) - if ct19_lhs.shape != ct19_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct19_num_moduli = ct19_lhs.shape[-1] - if hasattr(ct16, "moduli") and hasattr(ct18, "moduli"): - if ( - list(ct16.moduli)[:ct19_num_moduli] - != list(ct18.moduli)[:ct19_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct19_moduli_src = getattr( - ct16, "moduli", getattr(ct18, "moduli", v0.q_towers) - ) - if isinstance(ct19_moduli_src, (int, np.integer)): - ct19_moduli_src = [ct19_moduli_src] - ct19_moduli = jnp.array( - list(ct19_moduli_src)[:ct19_num_moduli], dtype=jnp.uint64 - ) - ct19_sum = ct19_lhs.astype(jnp.uint64) + ct19_rhs.astype(jnp.uint64) - ct19 = jnp.where( - ct19_sum >= ct19_moduli, ct19_sum - ct19_moduli, ct19_sum - ).astype(jnp.uint32) - ct20_arg = _ensure_poly(v0, ct19, v0.max_level) - ct20 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) - ct21_arg = _ensure_poly(v0, ct20, v0.max_level - 1) - ct21_pt_ntt = ( - pt8.polynomial[0, 0, :, : ct21_arg.polynomial.shape[-1]] - .reshape(ct21_arg.r, ct21_arg.c, ct21_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct21_ptct = v0.ptct_mul[v0.max_level - 1] - ct21_ptct.set_plaintext(ct21_pt_ntt) - ct21 = ct21_ptct.mul(ct21_arg, use_bat=False) - ct22_arg = _ensure_poly(v0, ct19, v0.max_level) - ct22 = v0.he_rot[v0.max_level, 1].rotate(ct22_arg) - ct23_arg = _ensure_poly(v0, ct22, v0.max_level) - ct23 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct23_arg) - ct24_arg = _ensure_poly(v0, ct23, v0.max_level - 1) - ct24_pt_ntt = ( - pt9.polynomial[0, 0, :, : ct24_arg.polynomial.shape[-1]] - .reshape(ct24_arg.r, ct24_arg.c, ct24_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct24_ptct = v0.ptct_mul[v0.max_level - 1] - ct24_ptct.set_plaintext(ct24_pt_ntt) - ct24 = ct24_ptct.mul(ct24_arg, use_bat=False) - ct25_arg = _ensure_poly(v0, ct19, v0.max_level) - ct25 = v0.he_rot[v0.max_level, 2].rotate(ct25_arg) - ct26_arg = _ensure_poly(v0, ct25, v0.max_level) - ct26 = v0.he_rescale[v0.max_level, v0.max_level - 1](ct26_arg) - ct27_arg = _ensure_poly(v0, ct26, v0.max_level - 1) - ct27_pt_ntt = ( - pt10.polynomial[0, 0, :, : ct27_arg.polynomial.shape[-1]] - .reshape(ct27_arg.r, ct27_arg.c, ct27_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct27_ptct = v0.ptct_mul[v0.max_level - 1] - ct27_ptct.set_plaintext(ct27_pt_ntt) - ct27 = ct27_ptct.mul(ct27_arg, use_bat=False) - ct28_arg = _ensure_poly(v0, ct20, v0.max_level - 1) - ct28_pt_ntt = ( - pt11.polynomial[0, 0, :, : ct28_arg.polynomial.shape[-1]] - .reshape(ct28_arg.r, ct28_arg.c, ct28_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct28_ptct = v0.ptct_mul[v0.max_level - 1] - ct28_ptct.set_plaintext(ct28_pt_ntt) - ct28 = ct28_ptct.mul(ct28_arg, use_bat=False) - ct29_arg = _ensure_poly(v0, ct23, v0.max_level - 1) - ct29_pt_ntt = ( - pt12.polynomial[0, 0, :, : ct29_arg.polynomial.shape[-1]] - .reshape(ct29_arg.r, ct29_arg.c, ct29_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct29_ptct = v0.ptct_mul[v0.max_level - 1] - ct29_ptct.set_plaintext(ct29_pt_ntt) - ct29 = ct29_ptct.mul(ct29_arg, use_bat=False) - ct30_arg = _ensure_poly(v0, ct26, v0.max_level - 1) - ct30_pt_ntt = ( - pt13.polynomial[0, 0, :, : ct30_arg.polynomial.shape[-1]] - .reshape(ct30_arg.r, ct30_arg.c, ct30_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct30_ptct = v0.ptct_mul[v0.max_level - 1] - ct30_ptct.set_plaintext(ct30_pt_ntt) - ct30 = ct30_ptct.mul(ct30_arg, use_bat=False) - ct31_lhs = ct28.polynomial if hasattr(ct28, "polynomial") else ct28 - ct31_rhs = ct29.polynomial if hasattr(ct29, "polynomial") else ct29 - ct31_lhs = ct31_lhs.reshape( - ct31_lhs.shape[0], - ct31_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct31_lhs.shape[-1], - ) - ct31_rhs = ct31_rhs.reshape( - ct31_rhs.shape[0], - ct31_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct31_rhs.shape[-1], - ) - if ct31_lhs.shape != ct31_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct31_num_moduli = ct31_lhs.shape[-1] - if hasattr(ct28, "moduli") and hasattr(ct29, "moduli"): - if ( - list(ct28.moduli)[:ct31_num_moduli] - != list(ct29.moduli)[:ct31_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct31_moduli_src = getattr( - ct28, "moduli", getattr(ct29, "moduli", v0.q_towers) - ) - if isinstance(ct31_moduli_src, (int, np.integer)): - ct31_moduli_src = [ct31_moduli_src] - ct31_moduli = jnp.array( - list(ct31_moduli_src)[:ct31_num_moduli], dtype=jnp.uint64 - ) - ct31_sum = ct31_lhs.astype(jnp.uint64) + ct31_rhs.astype(jnp.uint64) - ct31 = jnp.where( - ct31_sum >= ct31_moduli, ct31_sum - ct31_moduli, ct31_sum - ).astype(jnp.uint32) - ct32_lhs = ct31.polynomial if hasattr(ct31, "polynomial") else ct31 - ct32_rhs = ct30.polynomial if hasattr(ct30, "polynomial") else ct30 - ct32_lhs = ct32_lhs.reshape( - ct32_lhs.shape[0], - ct32_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct32_lhs.shape[-1], - ) - ct32_rhs = ct32_rhs.reshape( - ct32_rhs.shape[0], - ct32_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct32_rhs.shape[-1], - ) - if ct32_lhs.shape != ct32_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct32_num_moduli = ct32_lhs.shape[-1] - if hasattr(ct31, "moduli") and hasattr(ct30, "moduli"): - if ( - list(ct31.moduli)[:ct32_num_moduli] - != list(ct30.moduli)[:ct32_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct32_moduli_src = getattr( - ct31, "moduli", getattr(ct30, "moduli", v0.q_towers) - ) - if isinstance(ct32_moduli_src, (int, np.integer)): - ct32_moduli_src = [ct32_moduli_src] - ct32_moduli = jnp.array( - list(ct32_moduli_src)[:ct32_num_moduli], dtype=jnp.uint64 - ) - ct32_sum = ct32_lhs.astype(jnp.uint64) + ct32_rhs.astype(jnp.uint64) - ct32 = jnp.where( - ct32_sum >= ct32_moduli, ct32_sum - ct32_moduli, ct32_sum - ).astype(jnp.uint32) - ct33_arg = _ensure_poly(v0, ct32, v0.max_level - 1) - ct33 = v0.he_rot[v0.max_level - 1, 3].rotate(ct33_arg) - ct34_arg = _ensure_poly(v0, ct20, v0.max_level - 1) - ct34_pt_ntt = ( - pt14.polynomial[0, 0, :, : ct34_arg.polynomial.shape[-1]] - .reshape(ct34_arg.r, ct34_arg.c, ct34_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct34_ptct = v0.ptct_mul[v0.max_level - 1] - ct34_ptct.set_plaintext(ct34_pt_ntt) - ct34 = ct34_ptct.mul(ct34_arg, use_bat=False) - ct35_arg = _ensure_poly(v0, ct23, v0.max_level - 1) - ct35_pt_ntt = ( - pt15.polynomial[0, 0, :, : ct35_arg.polynomial.shape[-1]] - .reshape(ct35_arg.r, ct35_arg.c, ct35_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct35_ptct = v0.ptct_mul[v0.max_level - 1] - ct35_ptct.set_plaintext(ct35_pt_ntt) - ct35 = ct35_ptct.mul(ct35_arg, use_bat=False) - ct36_lhs = ct34.polynomial if hasattr(ct34, "polynomial") else ct34 - ct36_rhs = ct35.polynomial if hasattr(ct35, "polynomial") else ct35 - ct36_lhs = ct36_lhs.reshape( - ct36_lhs.shape[0], - ct36_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct36_lhs.shape[-1], - ) - ct36_rhs = ct36_rhs.reshape( - ct36_rhs.shape[0], - ct36_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct36_rhs.shape[-1], - ) - if ct36_lhs.shape != ct36_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct36_num_moduli = ct36_lhs.shape[-1] - if hasattr(ct34, "moduli") and hasattr(ct35, "moduli"): - if ( - list(ct34.moduli)[:ct36_num_moduli] - != list(ct35.moduli)[:ct36_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct36_moduli_src = getattr( - ct34, "moduli", getattr(ct35, "moduli", v0.q_towers) - ) - if isinstance(ct36_moduli_src, (int, np.integer)): - ct36_moduli_src = [ct36_moduli_src] - ct36_moduli = jnp.array( - list(ct36_moduli_src)[:ct36_num_moduli], dtype=jnp.uint64 - ) - ct36_sum = ct36_lhs.astype(jnp.uint64) + ct36_rhs.astype(jnp.uint64) - ct36 = jnp.where( - ct36_sum >= ct36_moduli, ct36_sum - ct36_moduli, ct36_sum - ).astype(jnp.uint32) - ct37_arg = _ensure_poly(v0, ct36, v0.max_level - 1) - ct37 = v0.he_rot[v0.max_level - 1, 6].rotate(ct37_arg) - ct38_lhs = ct21.polynomial if hasattr(ct21, "polynomial") else ct21 - ct38_rhs = ct24.polynomial if hasattr(ct24, "polynomial") else ct24 - ct38_lhs = ct38_lhs.reshape( - ct38_lhs.shape[0], - ct38_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct38_lhs.shape[-1], - ) - ct38_rhs = ct38_rhs.reshape( - ct38_rhs.shape[0], - ct38_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct38_rhs.shape[-1], - ) - if ct38_lhs.shape != ct38_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct38_num_moduli = ct38_lhs.shape[-1] - if hasattr(ct21, "moduli") and hasattr(ct24, "moduli"): - if ( - list(ct21.moduli)[:ct38_num_moduli] - != list(ct24.moduli)[:ct38_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct38_moduli_src = getattr( - ct21, "moduli", getattr(ct24, "moduli", v0.q_towers) - ) - if isinstance(ct38_moduli_src, (int, np.integer)): - ct38_moduli_src = [ct38_moduli_src] - ct38_moduli = jnp.array( - list(ct38_moduli_src)[:ct38_num_moduli], dtype=jnp.uint64 - ) - ct38_sum = ct38_lhs.astype(jnp.uint64) + ct38_rhs.astype(jnp.uint64) - ct38 = jnp.where( - ct38_sum >= ct38_moduli, ct38_sum - ct38_moduli, ct38_sum - ).astype(jnp.uint32) - ct39_lhs = ct27.polynomial if hasattr(ct27, "polynomial") else ct27 - ct39_rhs = ct33.polynomial if hasattr(ct33, "polynomial") else ct33 - ct39_lhs = ct39_lhs.reshape( - ct39_lhs.shape[0], - ct39_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct39_lhs.shape[-1], - ) - ct39_rhs = ct39_rhs.reshape( - ct39_rhs.shape[0], - ct39_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct39_rhs.shape[-1], - ) - if ct39_lhs.shape != ct39_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct39_num_moduli = ct39_lhs.shape[-1] - if hasattr(ct27, "moduli") and hasattr(ct33, "moduli"): - if ( - list(ct27.moduli)[:ct39_num_moduli] - != list(ct33.moduli)[:ct39_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct39_moduli_src = getattr( - ct27, "moduli", getattr(ct33, "moduli", v0.q_towers) - ) - if isinstance(ct39_moduli_src, (int, np.integer)): - ct39_moduli_src = [ct39_moduli_src] - ct39_moduli = jnp.array( - list(ct39_moduli_src)[:ct39_num_moduli], dtype=jnp.uint64 - ) - ct39_sum = ct39_lhs.astype(jnp.uint64) + ct39_rhs.astype(jnp.uint64) - ct39 = jnp.where( - ct39_sum >= ct39_moduli, ct39_sum - ct39_moduli, ct39_sum - ).astype(jnp.uint32) - ct40_lhs = ct39.polynomial if hasattr(ct39, "polynomial") else ct39 - ct40_rhs = ct37.polynomial if hasattr(ct37, "polynomial") else ct37 - ct40_lhs = ct40_lhs.reshape( - ct40_lhs.shape[0], - ct40_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct40_lhs.shape[-1], - ) - ct40_rhs = ct40_rhs.reshape( - ct40_rhs.shape[0], - ct40_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct40_rhs.shape[-1], - ) - if ct40_lhs.shape != ct40_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct40_num_moduli = ct40_lhs.shape[-1] - if hasattr(ct39, "moduli") and hasattr(ct37, "moduli"): - if ( - list(ct39.moduli)[:ct40_num_moduli] - != list(ct37.moduli)[:ct40_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct40_moduli_src = getattr( - ct39, "moduli", getattr(ct37, "moduli", v0.q_towers) - ) - if isinstance(ct40_moduli_src, (int, np.integer)): - ct40_moduli_src = [ct40_moduli_src] - ct40_moduli = jnp.array( - list(ct40_moduli_src)[:ct40_num_moduli], dtype=jnp.uint64 - ) - ct40_sum = ct40_lhs.astype(jnp.uint64) + ct40_rhs.astype(jnp.uint64) - ct40 = jnp.where( - ct40_sum >= ct40_moduli, ct40_sum - ct40_moduli, ct40_sum - ).astype(jnp.uint32) - ct41_lhs = ct38.polynomial if hasattr(ct38, "polynomial") else ct38 - ct41_rhs = ct40.polynomial if hasattr(ct40, "polynomial") else ct40 - ct41_lhs = ct41_lhs.reshape( - ct41_lhs.shape[0], - ct41_lhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct41_lhs.shape[-1], - ) - ct41_rhs = ct41_rhs.reshape( - ct41_rhs.shape[0], - ct41_rhs.shape[1], - v0._param_cache.r, - v0._param_cache.c, - ct41_rhs.shape[-1], - ) - if ct41_lhs.shape != ct41_rhs.shape: - raise ValueError("ciphertext add shape mismatch") - ct41_num_moduli = ct41_lhs.shape[-1] - if hasattr(ct38, "moduli") and hasattr(ct40, "moduli"): - if ( - list(ct38.moduli)[:ct41_num_moduli] - != list(ct40.moduli)[:ct41_num_moduli] - ): - raise ValueError("ciphertext add modulus mismatch") - ct41_moduli_src = getattr( - ct38, "moduli", getattr(ct40, "moduli", v0.q_towers) - ) - if isinstance(ct41_moduli_src, (int, np.integer)): - ct41_moduli_src = [ct41_moduli_src] - ct41_moduli = jnp.array( - list(ct41_moduli_src)[:ct41_num_moduli], dtype=jnp.uint64 - ) - ct41_sum = ct41_lhs.astype(jnp.uint64) + ct41_rhs.astype(jnp.uint64) - ct41 = jnp.where( - ct41_sum >= ct41_moduli, ct41_sum - ct41_moduli, ct41_sum - ).astype(jnp.uint32) - v20 = [None] * 1 - ct42_arg = _ensure_poly(v0, ct41, v0.max_level - 1) - ct42 = v0.he_rescale[v0.max_level - 1, v0.max_level - 2](ct42_arg) - v20[0] = ct42 - v21 = v20 - return v21 - - -def matvec_chain( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) = ( - matvec_chain__preprocessing(v0, v1) - ) - v15 = matvec_chain__preprocessed( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 - ) - return v15 - - -def matvec_chain__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - ct = _ensure_poly(v0, ct_raw) - v16 = [ct] - return v16 - - -def matvec_chain__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 7 - v8 = 0 - v9 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - pt_ct = _ensure_poly(v0, ct) - _num_moduli = pt_ct.polynomial.shape[-1] - _q_sub = list(getattr(pt_ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": pt_ct.polynomial.shape[0], - "num_elements": pt_ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - pt_ct.polynomial.reshape( - pt_ct.polynomial.shape[0], - pt_ct.polynomial.shape[1], - v0.degree, - _num_moduli, - ) - ) - pt = v0.decrypt(_ct_for_dec) - v10 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v11 = v9.copy() - for v12 in range(0, 8): - v14 = v7 - v12 - v15 = int(v14) - v16 = v10[0, v15] - v11[v15] = v16 - return v11 - - -def matvec_identity__generate_crypto_context( - v0: np.ndarray, - v1: np.ndarray, - v2: dict, -) -> ckks.CKKSContext: - params = { - "degree": 16, - "num_slots": 8, - "batch": 1, - "r": 4, - "c": 4, - "dnum": 3, - "numEvalMult": 1, - "scaling_factor": 35184372088832, - "q_towers": [1073742881, 1073742721, 1073741441, 1073741857, 524353], - "p_towers": [1073740609, 1073739937, 1073739649], - "composite_degree": 1, - "p": 30, - "max_bits_in_word": 61, - "max_bits_value": 9223372036854775295, - "noise_scale_degree": 1, - "CKKS_M_FACTOR": 1, - "public_key": v0, - "secret_key": v1, - "evaluation_key": v2, - } - v3 = ckks.CKKSContext(params) - return v3 - - -def matvec_identity__configure_crypto_context( - v0: ckks.CKKSContext, -): - v0.program_initialization( - total_hemul_levels=1, - total_rotation_indices=[1, 2, 3, 6], - dnum=3, - r=4, - c=4, - batch=1, - ) diff --git a/matvec_8x8_jaxite.mlir b/matvec_8x8_jaxite.mlir deleted file mode 100644 index 2dc3ac06d1..0000000000 --- a/matvec_8x8_jaxite.mlir +++ /dev/null @@ -1,465 +0,0 @@ -!Z1073741441_i64 = !mod_arith.int<1073741441 : i64> -!Z1073742721_i64 = !mod_arith.int<1073742721 : i64> -!Z1073742881_i64 = !mod_arith.int<1073742881 : i64> -#inverse_canonical_encoding = #lwe.inverse_canonical_encoding -#inverse_canonical_encoding1 = #lwe.inverse_canonical_encoding -#inverse_canonical_encoding2 = #lwe.inverse_canonical_encoding -#inverse_canonical_encoding3 = #lwe.inverse_canonical_encoding -#key = #lwe.key<> -#layout = #tensor_ext.layout<"{ [i0] -> [ct, slot] : ct = 0 and (-i0 + slot) mod 8 = 0 and 0 <= i0 <= 7 and 0 <= slot <= 7 }"> -#modulus_chain_L4_C0 = #lwe.modulus_chain, current = 0> -#modulus_chain_L4_C1 = #lwe.modulus_chain, current = 1> -#modulus_chain_L4_C2 = #lwe.modulus_chain, current = 2> -#ring_f64_1_x8 = #polynomial.ring> -!rns_L0 = !rns.rns -!rns_L1 = !rns.rns -!rns_L2 = !rns.rns -#original_type = #tensor_ext.original_type, layout = #layout> -!pt = !lwe.lwe_plaintext> -!pt1 = !lwe.lwe_plaintext> -!pt2 = !lwe.lwe_plaintext> -#ring_rns_L0_1_x8 = #polynomial.ring> -#ring_rns_L1_1_x8 = #polynomial.ring> -#ring_rns_L2_1_x8 = #polynomial.ring> -#ciphertext_space_L0 = #lwe.ciphertext_space -#ciphertext_space_L1 = #lwe.ciphertext_space -#ciphertext_space_L2 = #lwe.ciphertext_space -!ct_L0 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L0, key = #key, modulus_chain = #modulus_chain_L4_C0> -!ct_L1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L4_C1> -!ct_L1_1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L4_C1> -!ct_L2 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L4_C2> -!ct_L2_1 = !lwe.lwe_ciphertext, ciphertext_space = #ciphertext_space_L2, key = #key, modulus_chain = #modulus_chain_L4_C2> -module attributes {scheme.ckks} { - func.func @matvec_identity__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) attributes {client.pack_func = {func_name = "matvec_identity"}} { - %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<8xf32> - %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_1 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %from_elements = tensor.from_elements %pt : tensor<1x!pt> - %from_elements_2 = tensor.from_elements %pt_1 : tensor<1x!pt> - return %from_elements, %from_elements_2 : tensor<1x!pt>, tensor<1x!pt> - } - func.func @matvec_identity__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>) -> tensor<1x!ct_L1> attributes {client.preprocessed_func = {func_name = "matvec_identity"}} { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c6 = arith.constant 6 : index - %c0 = arith.constant 0 : index - %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> - %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> - %extracted_1 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> - %ct = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_2 = jaxiteword.mul_plain %arg0, %ct, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_3 = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_4 = jaxiteword.mul_plain %arg0, %ct_3, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_5 = jaxiteword.mul_plain %arg0, %extracted_1, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_6 = jaxiteword.add %arg0, %ct_5, %ct_2 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_7 = jaxiteword.add %arg0, %ct_6, %ct_4 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_8 = jaxiteword.rot %arg0, %ct_7, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_9 = jaxiteword.rot %arg0, %ct_6, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_10 = jaxiteword.mul_plain %arg0, %extracted_1, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_11 = jaxiteword.add %arg0, %ct_10, %ct_2 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_12 = jaxiteword.add %arg0, %ct_4, %ct_8 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_13 = jaxiteword.add %arg0, %ct_12, %ct_9 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_14 = jaxiteword.add %arg0, %ct_11, %ct_13 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %0 = tensor.empty() : tensor<1x!ct_L1> - %ct_15 = jaxiteword.mod_reduce %arg0, %ct_14 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 - %inserted = tensor.insert %ct_15 into %0[%c0] : tensor<1x!ct_L1> - return %inserted : tensor<1x!ct_L1> - } - func.func @matvec_identity(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L1> {tensor_ext.original_type = #original_type}) { - %0:2 = call @matvec_identity__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) - %1 = call @matvec_identity__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>) -> tensor<1x!ct_L1> - return %1 : tensor<1x!ct_L1> - } - func.func @matvec_identity__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_identity", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { - %1 = arith.index_cast %arg4 : i32 to index - %extracted = tensor.extract %arg2[%1] : tensor<8xf32> - %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> - scf.yield %inserted : tensor<1x8xf32> - } - %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> - %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 - %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> - return %from_elements : tensor<1x!ct_L2> - } - func.func @matvec_identity__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L1>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_identity", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %c8_i32 = arith.constant 8 : i32 - %c1_i32 = arith.constant 1 : i32 - %c7_i32 = arith.constant 7 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> - %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L1> - %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L1, !jaxiteword.private_key<>) -> !pt1 - %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt1) -> tensor<1x8xf32> - %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { - %2 = arith.subi %c7_i32, %arg4 : i32 - %3 = arith.index_cast %2 : i32 to index - %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> - %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> - scf.yield %inserted : tensor<8xf32> - } - return %1 : tensor<8xf32> - } - func.func @matvec_shift__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) attributes {client.pack_func = {func_name = "matvec_shift"}} { - %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> - %cst_0 = arith.constant dense<1.000000e+00> : tensor<8xf32> - %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_1 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %from_elements = tensor.from_elements %pt : tensor<1x!pt> - %from_elements_2 = tensor.from_elements %pt_1 : tensor<1x!pt> - return %from_elements, %from_elements_2 : tensor<1x!pt>, tensor<1x!pt> - } - func.func @matvec_shift__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>) -> tensor<1x!ct_L1> attributes {client.preprocessed_func = {func_name = "matvec_shift"}} { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c6 = arith.constant 6 : index - %c0 = arith.constant 0 : index - %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> - %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> - %extracted_1 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> - %ct = jaxiteword.mul_plain %arg0, %extracted_1, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_2 = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_3 = jaxiteword.rot %arg0, %extracted_1, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_4 = jaxiteword.mul_plain %arg0, %ct_3, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_5 = jaxiteword.mul_plain %arg0, %ct_2, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_6 = jaxiteword.add %arg0, %ct, %ct_5 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_7 = jaxiteword.add %arg0, %ct_6, %ct_4 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_8 = jaxiteword.rot %arg0, %ct_7, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_9 = jaxiteword.rot %arg0, %ct_6, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_10 = jaxiteword.mul_plain %arg0, %ct_2, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_11 = jaxiteword.add %arg0, %ct, %ct_10 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_12 = jaxiteword.add %arg0, %ct_4, %ct_8 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_13 = jaxiteword.add %arg0, %ct_12, %ct_9 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_14 = jaxiteword.add %arg0, %ct_11, %ct_13 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %0 = tensor.empty() : tensor<1x!ct_L1> - %ct_15 = jaxiteword.mod_reduce %arg0, %ct_14 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 - %inserted = tensor.insert %ct_15 into %0[%c0] : tensor<1x!ct_L1> - return %inserted : tensor<1x!ct_L1> - } - func.func @matvec_shift(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L1> {tensor_ext.original_type = #original_type}) { - %0:2 = call @matvec_shift__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>) - %1 = call @matvec_shift__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>) -> tensor<1x!ct_L1> - return %1 : tensor<1x!ct_L1> - } - func.func @matvec_shift__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_shift", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { - %1 = arith.index_cast %arg4 : i32 to index - %extracted = tensor.extract %arg2[%1] : tensor<8xf32> - %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> - scf.yield %inserted : tensor<1x8xf32> - } - %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> - %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 - %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> - return %from_elements : tensor<1x!ct_L2> - } - func.func @matvec_shift__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L1>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_shift", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %c8_i32 = arith.constant 8 : i32 - %c1_i32 = arith.constant 1 : i32 - %c7_i32 = arith.constant 7 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> - %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L1> - %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L1, !jaxiteword.private_key<>) -> !pt1 - %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt1) -> tensor<1x8xf32> - %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { - %2 = arith.subi %c7_i32, %arg4 : i32 - %3 = arith.index_cast %2 : i32 to index - %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> - %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> - scf.yield %inserted : tensor<8xf32> - } - return %1 : tensor<8xf32> - } - func.func @matvec_random__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>) attributes {client.pack_func = {func_name = "matvec_random"}} { - %cst = arith.constant dense<[0.811626255, 1.44533789, 0.920695543, 1.07704544, 0.678766131, 1.3587923, 1.236010e+00, 0.777831316]> : tensor<8xf32> - %cst_0 = arith.constant dense<[1.90635717, 0.139110535, 0.653335392, 1.22558773, 0.285577029, 6.922510e-01, 1.85156107, 0.268135756]> : tensor<8xf32> - %cst_1 = arith.constant dense<[1.49078846, 1.94282877, 1.26252055, 0.188255787, 1.40004277, 1.08812928, 1.13874948, 0.472367436]> : tensor<8xf32> - %cst_2 = arith.constant dense<[0.331872642, 0.451223463, 0.185931846, 1.23745108, 1.68164098, 0.365038335, 1.25433517, 0.936289727]> : tensor<8xf32> - %cst_3 = arith.constant dense<[1.0408361, 1.94221079, 0.718127608, 0.39643541, 0.503444314, 0.655074835, 0.423995823, 0.223598033]> : tensor<8xf32> - %cst_4 = arith.constant dense<[0.165338188, 1.57275236, 0.83848685, 0.396389604, 0.445467442, 0.796087503, 0.966532945, 1.90288258]> : tensor<8xf32> - %cst_5 = arith.constant dense<[0.678060233, 1.59183431, 1.93470085, 1.82770872, 1.88504803, 0.615563154, 0.210358858, 0.448468566]> : tensor<8xf32> - %cst_6 = arith.constant dense<[1.0970372, 0.47938019, 1.63595498, 0.591681957, 1.80017197, 1.67460132, 1.74573469, 1.24211848]> : tensor<8xf32> - %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_7 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_8 = jaxiteword.encode %arg0, %cst_1 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_9 = jaxiteword.encode %arg0, %cst_2 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_10 = jaxiteword.encode %arg0, %cst_3 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_11 = jaxiteword.encode %arg0, %cst_4 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_12 = jaxiteword.encode %arg0, %cst_5 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_13 = jaxiteword.encode %arg0, %cst_6 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %from_elements = tensor.from_elements %pt : tensor<1x!pt> - %from_elements_14 = tensor.from_elements %pt_7 : tensor<1x!pt> - %from_elements_15 = tensor.from_elements %pt_8 : tensor<1x!pt> - %from_elements_16 = tensor.from_elements %pt_9 : tensor<1x!pt> - %from_elements_17 = tensor.from_elements %pt_10 : tensor<1x!pt> - %from_elements_18 = tensor.from_elements %pt_11 : tensor<1x!pt> - %from_elements_19 = tensor.from_elements %pt_12 : tensor<1x!pt> - %from_elements_20 = tensor.from_elements %pt_13 : tensor<1x!pt> - return %from_elements, %from_elements_14, %from_elements_15, %from_elements_16, %from_elements_17, %from_elements_18, %from_elements_19, %from_elements_20 : tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt> - } - func.func @matvec_random__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>, %arg5: tensor<1x!pt>, %arg6: tensor<1x!pt>, %arg7: tensor<1x!pt>, %arg8: tensor<1x!pt>, %arg9: tensor<1x!pt>, %arg10: tensor<1x!pt>) -> tensor<1x!ct_L1> attributes {client.preprocessed_func = {func_name = "matvec_random"}} { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c6 = arith.constant 6 : index - %c0 = arith.constant 0 : index - %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> - %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> - %extracted_1 = tensor.extract %arg5[%c0] : tensor<1x!pt> - %extracted_2 = tensor.extract %arg6[%c0] : tensor<1x!pt> - %extracted_3 = tensor.extract %arg7[%c0] : tensor<1x!pt> - %extracted_4 = tensor.extract %arg8[%c0] : tensor<1x!pt> - %extracted_5 = tensor.extract %arg9[%c0] : tensor<1x!pt> - %extracted_6 = tensor.extract %arg10[%c0] : tensor<1x!pt> - %extracted_7 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> - %ct = jaxiteword.mul_plain %arg0, %extracted_7, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_8 = jaxiteword.rot %arg0, %extracted_7, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_9 = jaxiteword.mul_plain %arg0, %ct_8, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_10 = jaxiteword.rot %arg0, %extracted_7, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_11 = jaxiteword.mul_plain %arg0, %ct_10, %extracted_1 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_12 = jaxiteword.mul_plain %arg0, %extracted_7, %extracted_2 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_13 = jaxiteword.mul_plain %arg0, %ct_8, %extracted_3 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_14 = jaxiteword.mul_plain %arg0, %ct_10, %extracted_4 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_15 = jaxiteword.add %arg0, %ct_12, %ct_13 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_16 = jaxiteword.add %arg0, %ct_15, %ct_14 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_17 = jaxiteword.rot %arg0, %ct_16, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_18 = jaxiteword.mul_plain %arg0, %extracted_7, %extracted_5 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_19 = jaxiteword.mul_plain %arg0, %ct_8, %extracted_6 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_20 = jaxiteword.add %arg0, %ct_18, %ct_19 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_21 = jaxiteword.rot %arg0, %ct_20, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_22 = jaxiteword.add %arg0, %ct, %ct_9 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_23 = jaxiteword.add %arg0, %ct_11, %ct_17 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_24 = jaxiteword.add %arg0, %ct_23, %ct_21 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_25 = jaxiteword.add %arg0, %ct_22, %ct_24 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %0 = tensor.empty() : tensor<1x!ct_L1> - %ct_26 = jaxiteword.mod_reduce %arg0, %ct_25 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 - %inserted = tensor.insert %ct_26 into %0[%c0] : tensor<1x!ct_L1> - return %inserted : tensor<1x!ct_L1> - } - func.func @matvec_random(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L1> {tensor_ext.original_type = #original_type}) { - %0:8 = call @matvec_random__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>) - %1 = call @matvec_random__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>) -> tensor<1x!ct_L1> - return %1 : tensor<1x!ct_L1> - } - func.func @matvec_random__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_random", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { - %1 = arith.index_cast %arg4 : i32 to index - %extracted = tensor.extract %arg2[%1] : tensor<8xf32> - %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> - scf.yield %inserted : tensor<1x8xf32> - } - %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> - %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 - %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> - return %from_elements : tensor<1x!ct_L2> - } - func.func @matvec_random__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L1>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_random", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %c8_i32 = arith.constant 8 : i32 - %c1_i32 = arith.constant 1 : i32 - %c7_i32 = arith.constant 7 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> - %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L1> - %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L1, !jaxiteword.private_key<>) -> !pt1 - %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt1) -> tensor<1x8xf32> - %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { - %2 = arith.subi %c7_i32, %arg4 : i32 - %3 = arith.index_cast %2 : i32 to index - %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> - %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> - scf.yield %inserted : tensor<8xf32> - } - return %1 : tensor<8xf32> - } - func.func @matvec_chain__preprocessing(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>) attributes {client.pack_func = {func_name = "matvec_chain"}} { - %cst = arith.constant dense<[1.340000e+00, 1.220000e+00, 1.050000e+00, 1.500000e+00, 1.010000e+00, 0.879999995, 5.000000e-01, 1.060000e+00]> : tensor<8xf32> - %cst_0 = arith.constant dense<[5.800000e-01, 5.200000e-01, 0.889999985, 8.600000e-01, 1.170000e+00, 0.819999992, 1.490000e+00, 1.410000e+00]> : tensor<8xf32> - %cst_1 = arith.constant dense<[1.260000e+00, 1.090000e+00, 1.430000e+00, 1.260000e+00, 6.100000e-01, 8.500000e-01, 6.700000e-01, 0.709999978]> : tensor<8xf32> - %cst_2 = arith.constant dense<[0.819999992, 1.330000e+00, 7.900000e-01, 7.400000e-01, 1.060000e+00, 1.340000e+00, 1.090000e+00, 6.300000e-01]> : tensor<8xf32> - %cst_3 = arith.constant dense<[1.160000e+00, 0.839999973, 1.020000e+00, 0.689999997, 6.600000e-01, 8.600000e-01, 1.190000e+00, 6.500000e-01]> : tensor<8xf32> - %cst_4 = arith.constant dense<[1.350000e+00, 1.050000e+00, 1.400000e+00, 1.070000e+00, 6.500000e-01, 5.400000e-01, 8.000000e-01, 0.899999976]> : tensor<8xf32> - %cst_5 = arith.constant dense<[0.819999992, 0.899999976, 7.400000e-01, 1.050000e+00, 1.080000e+00, 1.480000e+00, 6.000000e-01, 1.200000e+00]> : tensor<8xf32> - %cst_6 = arith.constant dense<[1.190000e+00, 1.200000e+00, 0.839999973, 1.350000e+00, 1.020000e+00, 7.600000e-01, 1.390000e+00, 1.130000e+00]> : tensor<8xf32> - %cst_7 = arith.constant dense<[1.200000e+00, 0.889999985, 1.030000e+00, 7.300000e-01, 9.300000e-01, 7.500000e-01, 0.839999973, 1.170000e+00]> : tensor<8xf32> - %cst_8 = arith.constant dense<[7.900000e-01, 0.839999973, 1.030000e+00, 7.900000e-01, 1.390000e+00, 9.800000e-01, 8.000000e-01, 9.200000e-01]> : tensor<8xf32> - %cst_9 = arith.constant dense<[7.300000e-01, 1.230000e+00, 1.130000e+00, 1.130000e+00, 1.440000e+00, 1.490000e+00, 1.020000e+00, 1.180000e+00]> : tensor<8xf32> - %cst_10 = arith.constant dense<[1.120000e+00, 1.110000e+00, 1.380000e+00, 1.050000e+00, 0.939999997, 1.350000e+00, 5.900000e-01, 1.000000e+00]> : tensor<8xf32> - %cst_11 = arith.constant dense<[6.200000e-01, 6.200000e-01, 1.010000e+00, 1.220000e+00, 5.600000e-01, 1.220000e+00, 9.300000e-01, 9.300000e-01]> : tensor<8xf32> - %cst_12 = arith.constant dense<[0.819999992, 1.330000e+00, 1.170000e+00, 9.200000e-01, 0.899999976, 1.110000e+00, 1.220000e+00, 9.900000e-01]> : tensor<8xf32> - %cst_13 = arith.constant dense<[6.800000e-01, 0.819999992, 9.300000e-01, 9.100000e-01, 1.100000e+00, 1.090000e+00, 1.480000e+00, 1.240000e+00]> : tensor<8xf32> - %cst_14 = arith.constant dense<[6.800000e-01, 8.600000e-01, 8.100000e-01, 1.370000e+00, 1.050000e+00, 1.120000e+00, 1.180000e+00, 9.800000e-01]> : tensor<8xf32> - %pt = jaxiteword.encode %arg0, %cst : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_15 = jaxiteword.encode %arg0, %cst_0 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_16 = jaxiteword.encode %arg0, %cst_1 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_17 = jaxiteword.encode %arg0, %cst_2 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_18 = jaxiteword.encode %arg0, %cst_3 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_19 = jaxiteword.encode %arg0, %cst_4 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_20 = jaxiteword.encode %arg0, %cst_5 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_21 = jaxiteword.encode %arg0, %cst_6 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %pt_22 = jaxiteword.encode %arg0, %cst_7 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %pt_23 = jaxiteword.encode %arg0, %cst_8 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %pt_24 = jaxiteword.encode %arg0, %cst_9 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %pt_25 = jaxiteword.encode %arg0, %cst_10 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %pt_26 = jaxiteword.encode %arg0, %cst_11 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %pt_27 = jaxiteword.encode %arg0, %cst_12 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %pt_28 = jaxiteword.encode %arg0, %cst_13 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %pt_29 = jaxiteword.encode %arg0, %cst_14 : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt1 - %from_elements = tensor.from_elements %pt : tensor<1x!pt> - %from_elements_30 = tensor.from_elements %pt_15 : tensor<1x!pt> - %from_elements_31 = tensor.from_elements %pt_16 : tensor<1x!pt> - %from_elements_32 = tensor.from_elements %pt_17 : tensor<1x!pt> - %from_elements_33 = tensor.from_elements %pt_18 : tensor<1x!pt> - %from_elements_34 = tensor.from_elements %pt_19 : tensor<1x!pt> - %from_elements_35 = tensor.from_elements %pt_20 : tensor<1x!pt> - %from_elements_36 = tensor.from_elements %pt_21 : tensor<1x!pt> - %from_elements_37 = tensor.from_elements %pt_22, %pt_23 : tensor<2x!pt1> - %from_elements_38 = tensor.from_elements %pt_24, %pt_25 : tensor<2x!pt1> - %from_elements_39 = tensor.from_elements %pt_26, %pt_27 : tensor<2x!pt1> - %from_elements_40 = tensor.from_elements %pt_28, %pt_29 : tensor<2x!pt1> - return %from_elements, %from_elements_30, %from_elements_31, %from_elements_32, %from_elements_33, %from_elements_34, %from_elements_35, %from_elements_36, %from_elements_37, %from_elements_38, %from_elements_39, %from_elements_40 : tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1> - } - func.func @matvec_chain__preprocessed(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2>, %arg3: tensor<1x!pt>, %arg4: tensor<1x!pt>, %arg5: tensor<1x!pt>, %arg6: tensor<1x!pt>, %arg7: tensor<1x!pt>, %arg8: tensor<1x!pt>, %arg9: tensor<1x!pt>, %arg10: tensor<1x!pt>, %arg11: tensor<2x!pt1>, %arg12: tensor<2x!pt1>, %arg13: tensor<2x!pt1>, %arg14: tensor<2x!pt1>) -> tensor<1x!ct_L0> attributes {client.preprocessed_func = {func_name = "matvec_chain"}} { - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c6 = arith.constant 6 : index - %c0 = arith.constant 0 : index - %extracted = tensor.extract %arg3[%c0] : tensor<1x!pt> - %extracted_0 = tensor.extract %arg4[%c0] : tensor<1x!pt> - %extracted_1 = tensor.extract %arg5[%c0] : tensor<1x!pt> - %extracted_2 = tensor.extract %arg6[%c0] : tensor<1x!pt> - %extracted_3 = tensor.extract %arg7[%c0] : tensor<1x!pt> - %extracted_4 = tensor.extract %arg8[%c0] : tensor<1x!pt> - %extracted_5 = tensor.extract %arg9[%c0] : tensor<1x!pt> - %extracted_6 = tensor.extract %arg10[%c0] : tensor<1x!pt> - %extracted_7 = tensor.extract %arg11[%c0] : tensor<2x!pt1> - %extracted_8 = tensor.extract %arg11[%c1] : tensor<2x!pt1> - %extracted_9 = tensor.extract %arg12[%c0] : tensor<2x!pt1> - %extracted_10 = tensor.extract %arg12[%c1] : tensor<2x!pt1> - %extracted_11 = tensor.extract %arg13[%c0] : tensor<2x!pt1> - %extracted_12 = tensor.extract %arg13[%c1] : tensor<2x!pt1> - %extracted_13 = tensor.extract %arg14[%c0] : tensor<2x!pt1> - %extracted_14 = tensor.extract %arg14[%c1] : tensor<2x!pt1> - %extracted_15 = tensor.extract %arg2[%c0] : tensor<1x!ct_L2> - %ct = jaxiteword.mul_plain %arg0, %extracted_15, %extracted : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_16 = jaxiteword.rot %arg0, %extracted_15, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_17 = jaxiteword.mul_plain %arg0, %ct_16, %extracted_0 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_18 = jaxiteword.rot %arg0, %extracted_15, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2, !jaxiteword.eval_key<>) -> !ct_L2 - %ct_19 = jaxiteword.mul_plain %arg0, %ct_18, %extracted_1 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_20 = jaxiteword.mul_plain %arg0, %extracted_15, %extracted_2 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_21 = jaxiteword.mul_plain %arg0, %ct_16, %extracted_3 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_22 = jaxiteword.mul_plain %arg0, %ct_18, %extracted_4 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_23 = jaxiteword.add %arg0, %ct_20, %ct_21 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_24 = jaxiteword.add %arg0, %ct_23, %ct_22 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_25 = jaxiteword.rot %arg0, %ct_24, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_26 = jaxiteword.mul_plain %arg0, %extracted_15, %extracted_5 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_27 = jaxiteword.mul_plain %arg0, %ct_16, %extracted_6 : (!jaxiteword.crypto_context<>, !ct_L2, !pt) -> !ct_L2_1 - %ct_28 = jaxiteword.add %arg0, %ct_26, %ct_27 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_29 = jaxiteword.rot %arg0, %ct_28, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_30 = jaxiteword.add %arg0, %ct, %ct_17 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_31 = jaxiteword.add %arg0, %ct_19, %ct_25 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_32 = jaxiteword.add %arg0, %ct_31, %ct_29 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_33 = jaxiteword.add %arg0, %ct_30, %ct_32 : (!jaxiteword.crypto_context<>, !ct_L2_1, !ct_L2_1) -> !ct_L2_1 - %ct_34 = jaxiteword.mod_reduce %arg0, %ct_33 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 - %ct_35 = jaxiteword.mul_plain %arg0, %ct_34, %extracted_7 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_36 = jaxiteword.rot %arg0, %ct_33, %arg1 {index = 1 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_37 = jaxiteword.mod_reduce %arg0, %ct_36 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 - %ct_38 = jaxiteword.mul_plain %arg0, %ct_37, %extracted_8 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_39 = jaxiteword.rot %arg0, %ct_33, %arg1 {index = 2 : i64} : (!jaxiteword.crypto_context<>, !ct_L2_1, !jaxiteword.eval_key<>) -> !ct_L2_1 - %ct_40 = jaxiteword.mod_reduce %arg0, %ct_39 : (!jaxiteword.crypto_context<>, !ct_L2_1) -> !ct_L1 - %ct_41 = jaxiteword.mul_plain %arg0, %ct_40, %extracted_9 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_42 = jaxiteword.mul_plain %arg0, %ct_34, %extracted_10 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_43 = jaxiteword.mul_plain %arg0, %ct_37, %extracted_11 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_44 = jaxiteword.mul_plain %arg0, %ct_40, %extracted_12 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_45 = jaxiteword.add %arg0, %ct_42, %ct_43 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 - %ct_46 = jaxiteword.add %arg0, %ct_45, %ct_44 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 - %ct_47 = jaxiteword.rot %arg0, %ct_46, %arg1 {index = 3 : i64} : (!jaxiteword.crypto_context<>, !ct_L1_1, !jaxiteword.eval_key<>) -> !ct_L1_1 - %ct_48 = jaxiteword.mul_plain %arg0, %ct_34, %extracted_13 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_49 = jaxiteword.mul_plain %arg0, %ct_37, %extracted_14 : (!jaxiteword.crypto_context<>, !ct_L1, !pt1) -> !ct_L1_1 - %ct_50 = jaxiteword.add %arg0, %ct_48, %ct_49 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 - %ct_51 = jaxiteword.rot %arg0, %ct_50, %arg1 {index = 6 : i64} : (!jaxiteword.crypto_context<>, !ct_L1_1, !jaxiteword.eval_key<>) -> !ct_L1_1 - %ct_52 = jaxiteword.add %arg0, %ct_35, %ct_38 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 - %ct_53 = jaxiteword.add %arg0, %ct_41, %ct_47 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 - %ct_54 = jaxiteword.add %arg0, %ct_53, %ct_51 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 - %ct_55 = jaxiteword.add %arg0, %ct_52, %ct_54 : (!jaxiteword.crypto_context<>, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 - %0 = tensor.empty() : tensor<1x!ct_L0> - %ct_56 = jaxiteword.mod_reduce %arg0, %ct_55 : (!jaxiteword.crypto_context<>, !ct_L1_1) -> !ct_L0 - %inserted = tensor.insert %ct_56 into %0[%c0] : tensor<1x!ct_L0> - return %inserted : tensor<1x!ct_L0> - } - func.func @matvec_chain(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L2> {tensor_ext.original_type = #original_type}) -> (tensor<1x!ct_L0> {tensor_ext.original_type = #original_type}) { - %0:12 = call @matvec_chain__preprocessing(%arg0, %arg1) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>) -> (tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>) - %1 = call @matvec_chain__preprocessed(%arg0, %arg1, %arg2, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11) : (!jaxiteword.crypto_context<>, !jaxiteword.eval_key<>, tensor<1x!ct_L2>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<1x!pt>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>, tensor<2x!pt1>) -> tensor<1x!ct_L0> - return %1 : tensor<1x!ct_L0> - } - func.func @matvec_chain__encrypt__arg0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<8xf32>, %arg3: !jaxiteword.public_key<>) -> tensor<1x!ct_L2> attributes {client.enc_func = {func_name = "matvec_chain", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %cst = arith.constant dense<0.000000e+00> : tensor<1x8xf32> - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<1x8xf32>) : i32 { - %1 = arith.index_cast %arg4 : i32 to index - %extracted = tensor.extract %arg2[%1] : tensor<8xf32> - %inserted = tensor.insert %extracted into %arg5[%c0, %1] : tensor<1x8xf32> - scf.yield %inserted : tensor<1x8xf32> - } - %extracted_slice = tensor.extract_slice %0[0, 0] [1, 8] [1, 1] : tensor<1x8xf32> to tensor<8xf32> - %pt = jaxiteword.encode %arg0, %extracted_slice : (!jaxiteword.crypto_context<>, tensor<8xf32>) -> !pt - %ct = jaxiteword.encrypt %arg0, %pt, %arg3 : (!jaxiteword.crypto_context<>, !pt, !jaxiteword.public_key<>) -> !ct_L2 - %from_elements = tensor.from_elements %ct : tensor<1x!ct_L2> - return %from_elements : tensor<1x!ct_L2> - } - func.func @matvec_chain__decrypt__result0(%arg0: !jaxiteword.crypto_context<>, %arg1: !jaxiteword.eval_key<>, %arg2: tensor<1x!ct_L0>, %arg3: !jaxiteword.private_key<>) -> tensor<8xf32> attributes {client.dec_func = {func_name = "matvec_chain", index = 0 : i64}} { - %c0 = arith.constant 0 : index - %c8_i32 = arith.constant 8 : i32 - %c1_i32 = arith.constant 1 : i32 - %c7_i32 = arith.constant 7 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<8xf32> - %extracted = tensor.extract %arg2[%c0] : tensor<1x!ct_L0> - %pt = jaxiteword.decrypt %arg0, %extracted, %arg3 : (!jaxiteword.crypto_context<>, !ct_L0, !jaxiteword.private_key<>) -> !pt2 - %0 = jaxiteword.decode %arg0, %pt : (!jaxiteword.crypto_context<>, !pt2) -> tensor<1x8xf32> - %1 = scf.for %arg4 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<8xf32>) : i32 { - %2 = arith.subi %c7_i32, %arg4 : i32 - %3 = arith.index_cast %2 : i32 to index - %extracted_0 = tensor.extract %0[%c0, %3] : tensor<1x8xf32> - %inserted = tensor.insert %extracted_0 into %arg5[%3] : tensor<8xf32> - scf.yield %inserted : tensor<8xf32> - } - return %1 : tensor<8xf32> - } - func.func @matvec_identity__generate_crypto_context(%arg0: !jaxiteword.public_key<>, %arg1: !jaxiteword.private_key<>, %arg2: !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> { - %0 = jaxiteword.gen_params %arg0, %arg1, %arg2 {batch = 1 : i32, c = 4 : i32, compositeDegree = 1 : i32, degree = 16 : i64, dnum = 3 : i32, numEvalMult = 1 : i32, numSlots = 8 : i64, pTowers = array, qTowers = array, r = 4 : i32, scalingFactor = 0x42C0000000000000 : f64} : (!jaxiteword.public_key<>, !jaxiteword.private_key<>, !jaxiteword.eval_key<>) -> !jaxiteword.crypto_context<> - return %0 : !jaxiteword.crypto_context<> - } - func.func @matvec_identity__configure_crypto_context(%arg0: !jaxiteword.crypto_context<>) { - jaxiteword.program_initialization %arg0 {batch = 1 : i32, c = 4 : i32, dnum = 3 : i32, r = 4 : i32, totalHemulLevels = 1 : i64, totalRotationIndices = array} : (!jaxiteword.crypto_context<>) -> () - return - } -} diff --git a/matvec_8x8_jaxiteword.mlir b/matvec_8x8_jaxiteword.mlir deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/matvec_8x8_jaxiteword.py b/matvec_8x8_jaxiteword.py deleted file mode 100644 index 03cc11bf62..0000000000 --- a/matvec_8x8_jaxiteword.py +++ /dev/null @@ -1,11122 +0,0 @@ -import jax -import jax.numpy as jnp -import numpy as np -from ciphertext import Ciphertext -from polynomial import Polynomial -import ckks_ctx as ckks - - -def _assign_layout_15335824159471298539( - v0: np.ndarray, -) -> np.ndarray: - v1 = 8 - v2 = np.full( - ( - 8, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v3 = 0 - v4 = 1 - v5 = v2.copy() - for v6 in range(0, 8): - for v9 in range(0, 8): - v11 = v6 + v9 - v12 = v11 % v1 - v13 = int(v9) - v14 = int(v12) - v15 = v0[v13, v14] - v16 = int(v6) - v5[v16, v13] = v15 - return v5 - - -def matvec_identity__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> ( - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -): - v2 = np.array( - [ - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - ], - dtype=np.float32, - ).reshape(8, 8) - v3 = _assign_layout_15335824159471298539(v2) - v4 = v3[3 : 3 + 1, 0 : 0 + 5] - v5 = v3[3 : 3 + 1, 5 : 5 + 3] - v6 = np.zeros( - ( - 1, - 8, - ), - dtype=np.float32, - ) - v7 = v6.copy() - v7[0 : 0 + 1, 3 : 3 + 5] = v4 - v8 = v7.copy() - v8[0 : 0 + 1, 0 : 0 + 3] = v5 - v9 = v3[4 : 4 + 1, 0 : 0 + 5] - v10 = v3[4 : 4 + 1, 5 : 5 + 3] - v11 = v6.copy() - v11[0 : 0 + 1, 3 : 3 + 5] = v9 - v12 = v11.copy() - v12[0 : 0 + 1, 0 : 0 + 3] = v10 - v13 = v3[5 : 5 + 1, 0 : 0 + 5] - v14 = v3[5 : 5 + 1, 5 : 5 + 3] - v15 = v6.copy() - v15[0 : 0 + 1, 3 : 3 + 5] = v13 - v16 = v15.copy() - v16[0 : 0 + 1, 0 : 0 + 3] = v14 - v17 = v3[6 : 6 + 1, 0 : 0 + 2] - v18 = v3[6 : 6 + 1, 2 : 2 + 6] - v19 = v6.copy() - v19[0 : 0 + 1, 6 : 6 + 2] = v17 - v20 = v19.copy() - v20[0 : 0 + 1, 0 : 0 + 6] = v18 - v21 = v3[7 : 7 + 1, 0 : 0 + 2] - v22 = v3[7 : 7 + 1, 2 : 2 + 6] - v23 = v6.copy() - v23[0 : 0 + 1, 6 : 6 + 2] = v21 - v24 = v23.copy() - v24[0 : 0 + 1, 0 : 0 + 6] = v22 - v25 = v3[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v25) - v26 = v3[1 : 1 + 1, 0 : 0 + 8].reshape(8) - pt1 = v0.encode(v26) - v27 = v3[2 : 2 + 1, 0 : 0 + 8].reshape(8) - pt2 = v0.encode(v27) - v28 = v8[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt3 = v0.encode(v28) - v29 = v12[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt4 = v0.encode(v29) - v30 = v16[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt5 = v0.encode(v30) - v31 = v20[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt6 = v0.encode(v31) - v32 = v24[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt7 = v0.encode(v32) - v33 = [pt] - v34 = [pt1] - v35 = [pt2] - v36 = [pt3] - v37 = [pt4] - v38 = [pt5] - v39 = [pt6] - v40 = [pt7] - return (v33, v34, v35, v36, v37, v38, v39, v40) - - -def matvec_identity__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, - v5: np.ndarray, - v6: np.ndarray, - v7: np.ndarray, - v8: np.ndarray, - v9: np.ndarray, - v10: np.ndarray, -) -> np.ndarray: - v11 = 1 - v12 = 2 - v13 = 3 - v14 = 6 - v15 = 0 - pt = v3[0] - pt1 = v4[0] - pt2 = v5[0] - pt3 = v6[0] - pt4 = v7[0] - pt5 = v8[0] - pt6 = v9[0] - pt7 = v10[0] - ct = v2[0] - _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct1_arg_m_in = _ct1_arg_data.shape[-1] - _ct1_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_arg_m_in - ) - _ct1_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_arg_r) - ) - _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct1_arg_moduli, (int, np.integer)): - _ct1_arg_moduli = [int(_ct1_arg_moduli)] - ct1_arg = Polynomial( - { - "batch": _ct1_arg_data.shape[0], - "num_elements": _ct1_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_arg_m, - "precision": 32, - "degree_layout": (_ct1_arg_r, _ct1_arg_c), - }, - {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, - ) - ct1_arg.polynomial = _ct1_arg_data.reshape( - _ct1_arg_data.shape[0], - _ct1_arg_data.shape[1], - _ct1_arg_r, - _ct1_arg_c, - _ct1_arg_m_in, - )[..., :_ct1_arg_m].copy() - ct1_arg.batch = ct1_arg.polynomial.shape[0] - ct1_arg.num_elements = ct1_arg.polynomial.shape[1] - ct1_arg.num_moduli = _ct1_arg_m - ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) - ct1_arg.r = _ct1_arg_r - ct1_arg.c = _ct1_arg_c - ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] - ct1_arg.moduli_array = jnp.array( - ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) - ) - ct1_pt_ntt = ( - pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] - .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct1_ptct = v0.ptct_mul[v0.max_level] - ct1_ptct.set_plaintext(ct1_pt_ntt) - ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) - _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw - _ct1_m_in = _ct1_data.shape[-1] - _ct1_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_m_in - ) - _ct1_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_r) - ) - _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) - if isinstance(_ct1_moduli, (int, np.integer)): - _ct1_moduli = [int(_ct1_moduli)] - ct1 = Polynomial( - { - "batch": _ct1_data.shape[0], - "num_elements": _ct1_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_m, - "precision": 32, - "degree_layout": (_ct1_r, _ct1_c), - }, - {"moduli": list(_ct1_moduli)[:_ct1_m]}, - ) - ct1.polynomial = _ct1_data.reshape( - _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in - )[..., :_ct1_m].copy() - ct1.batch = ct1.polynomial.shape[0] - ct1.num_elements = ct1.polynomial.shape[1] - ct1.num_moduli = _ct1_m - ct1.degree_layout = (_ct1_r, _ct1_c) - ct1.r = _ct1_r - ct1.c = _ct1_c - ct1.moduli = list(_ct1_moduli)[:_ct1_m] - ct1.moduli_array = jnp.array( - ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) - ) - _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct2_arg_m_in = _ct2_arg_data.shape[-1] - _ct2_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_arg_m_in - ) - _ct2_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_arg_r) - ) - _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct2_arg_moduli, (int, np.integer)): - _ct2_arg_moduli = [int(_ct2_arg_moduli)] - ct2_arg = Polynomial( - { - "batch": _ct2_arg_data.shape[0], - "num_elements": _ct2_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_arg_m, - "precision": 32, - "degree_layout": (_ct2_arg_r, _ct2_arg_c), - }, - {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, - ) - ct2_arg.polynomial = _ct2_arg_data.reshape( - _ct2_arg_data.shape[0], - _ct2_arg_data.shape[1], - _ct2_arg_r, - _ct2_arg_c, - _ct2_arg_m_in, - )[..., :_ct2_arg_m].copy() - ct2_arg.batch = ct2_arg.polynomial.shape[0] - ct2_arg.num_elements = ct2_arg.polynomial.shape[1] - ct2_arg.num_moduli = _ct2_arg_m - ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) - ct2_arg.r = _ct2_arg_r - ct2_arg.c = _ct2_arg_c - ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] - ct2_arg.moduli_array = jnp.array( - ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) - ) - ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) - _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw - _ct2_m_in = _ct2_data.shape[-1] - _ct2_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_m_in - ) - _ct2_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_r) - ) - _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) - if isinstance(_ct2_moduli, (int, np.integer)): - _ct2_moduli = [int(_ct2_moduli)] - ct2 = Polynomial( - { - "batch": _ct2_data.shape[0], - "num_elements": _ct2_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_m, - "precision": 32, - "degree_layout": (_ct2_r, _ct2_c), - }, - {"moduli": list(_ct2_moduli)[:_ct2_m]}, - ) - ct2.polynomial = _ct2_data.reshape( - _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in - )[..., :_ct2_m].copy() - ct2.batch = ct2.polynomial.shape[0] - ct2.num_elements = ct2.polynomial.shape[1] - ct2.num_moduli = _ct2_m - ct2.degree_layout = (_ct2_r, _ct2_c) - ct2.r = _ct2_r - ct2.c = _ct2_c - ct2.moduli = list(_ct2_moduli)[:_ct2_m] - ct2.moduli_array = jnp.array( - ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) - ) - _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct3_arg_m_in = _ct3_arg_data.shape[-1] - _ct3_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_arg_m_in - ) - _ct3_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_arg_r) - ) - _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct3_arg_moduli, (int, np.integer)): - _ct3_arg_moduli = [int(_ct3_arg_moduli)] - ct3_arg = Polynomial( - { - "batch": _ct3_arg_data.shape[0], - "num_elements": _ct3_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_arg_m, - "precision": 32, - "degree_layout": (_ct3_arg_r, _ct3_arg_c), - }, - {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, - ) - ct3_arg.polynomial = _ct3_arg_data.reshape( - _ct3_arg_data.shape[0], - _ct3_arg_data.shape[1], - _ct3_arg_r, - _ct3_arg_c, - _ct3_arg_m_in, - )[..., :_ct3_arg_m].copy() - ct3_arg.batch = ct3_arg.polynomial.shape[0] - ct3_arg.num_elements = ct3_arg.polynomial.shape[1] - ct3_arg.num_moduli = _ct3_arg_m - ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) - ct3_arg.r = _ct3_arg_r - ct3_arg.c = _ct3_arg_c - ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] - ct3_arg.moduli_array = jnp.array( - ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) - ) - ct3_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] - .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct3_ptct = v0.ptct_mul[v0.max_level] - ct3_ptct.set_plaintext(ct3_pt_ntt) - ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) - _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw - _ct3_m_in = _ct3_data.shape[-1] - _ct3_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_m_in - ) - _ct3_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_r) - ) - _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) - if isinstance(_ct3_moduli, (int, np.integer)): - _ct3_moduli = [int(_ct3_moduli)] - ct3 = Polynomial( - { - "batch": _ct3_data.shape[0], - "num_elements": _ct3_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_m, - "precision": 32, - "degree_layout": (_ct3_r, _ct3_c), - }, - {"moduli": list(_ct3_moduli)[:_ct3_m]}, - ) - ct3.polynomial = _ct3_data.reshape( - _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in - )[..., :_ct3_m].copy() - ct3.batch = ct3.polynomial.shape[0] - ct3.num_elements = ct3.polynomial.shape[1] - ct3.num_moduli = _ct3_m - ct3.degree_layout = (_ct3_r, _ct3_c) - ct3.r = _ct3_r - ct3.c = _ct3_c - ct3.moduli = list(_ct3_moduli)[:_ct3_m] - ct3.moduli_array = jnp.array( - ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) - ) - _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct4_arg_m_in = _ct4_arg_data.shape[-1] - _ct4_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_arg_m_in - ) - _ct4_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_arg_r) - ) - _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct4_arg_moduli, (int, np.integer)): - _ct4_arg_moduli = [int(_ct4_arg_moduli)] - ct4_arg = Polynomial( - { - "batch": _ct4_arg_data.shape[0], - "num_elements": _ct4_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_arg_m, - "precision": 32, - "degree_layout": (_ct4_arg_r, _ct4_arg_c), - }, - {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, - ) - ct4_arg.polynomial = _ct4_arg_data.reshape( - _ct4_arg_data.shape[0], - _ct4_arg_data.shape[1], - _ct4_arg_r, - _ct4_arg_c, - _ct4_arg_m_in, - )[..., :_ct4_arg_m].copy() - ct4_arg.batch = ct4_arg.polynomial.shape[0] - ct4_arg.num_elements = ct4_arg.polynomial.shape[1] - ct4_arg.num_moduli = _ct4_arg_m - ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) - ct4_arg.r = _ct4_arg_r - ct4_arg.c = _ct4_arg_c - ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] - ct4_arg.moduli_array = jnp.array( - ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) - ) - ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) - _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw - _ct4_m_in = _ct4_data.shape[-1] - _ct4_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_m_in - ) - _ct4_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_r) - ) - _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) - if isinstance(_ct4_moduli, (int, np.integer)): - _ct4_moduli = [int(_ct4_moduli)] - ct4 = Polynomial( - { - "batch": _ct4_data.shape[0], - "num_elements": _ct4_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_m, - "precision": 32, - "degree_layout": (_ct4_r, _ct4_c), - }, - {"moduli": list(_ct4_moduli)[:_ct4_m]}, - ) - ct4.polynomial = _ct4_data.reshape( - _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in - )[..., :_ct4_m].copy() - ct4.batch = ct4.polynomial.shape[0] - ct4.num_elements = ct4.polynomial.shape[1] - ct4.num_moduli = _ct4_m - ct4.degree_layout = (_ct4_r, _ct4_c) - ct4.r = _ct4_r - ct4.c = _ct4_c - ct4.moduli = list(_ct4_moduli)[:_ct4_m] - ct4.moduli_array = jnp.array( - ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) - ) - _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct5_arg_m_in = _ct5_arg_data.shape[-1] - _ct5_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_arg_m_in - ) - _ct5_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_arg_r) - ) - _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct5_arg_moduli, (int, np.integer)): - _ct5_arg_moduli = [int(_ct5_arg_moduli)] - ct5_arg = Polynomial( - { - "batch": _ct5_arg_data.shape[0], - "num_elements": _ct5_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_arg_m, - "precision": 32, - "degree_layout": (_ct5_arg_r, _ct5_arg_c), - }, - {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, - ) - ct5_arg.polynomial = _ct5_arg_data.reshape( - _ct5_arg_data.shape[0], - _ct5_arg_data.shape[1], - _ct5_arg_r, - _ct5_arg_c, - _ct5_arg_m_in, - )[..., :_ct5_arg_m].copy() - ct5_arg.batch = ct5_arg.polynomial.shape[0] - ct5_arg.num_elements = ct5_arg.polynomial.shape[1] - ct5_arg.num_moduli = _ct5_arg_m - ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) - ct5_arg.r = _ct5_arg_r - ct5_arg.c = _ct5_arg_c - ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] - ct5_arg.moduli_array = jnp.array( - ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) - ) - ct5_pt_ntt = ( - pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) - _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw - _ct5_m_in = _ct5_data.shape[-1] - _ct5_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_m_in - ) - _ct5_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_r) - ) - _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) - if isinstance(_ct5_moduli, (int, np.integer)): - _ct5_moduli = [int(_ct5_moduli)] - ct5 = Polynomial( - { - "batch": _ct5_data.shape[0], - "num_elements": _ct5_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_m, - "precision": 32, - "degree_layout": (_ct5_r, _ct5_c), - }, - {"moduli": list(_ct5_moduli)[:_ct5_m]}, - ) - ct5.polynomial = _ct5_data.reshape( - _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in - )[..., :_ct5_m].copy() - ct5.batch = ct5.polynomial.shape[0] - ct5.num_elements = ct5.polynomial.shape[1] - ct5.num_moduli = _ct5_m - ct5.degree_layout = (_ct5_r, _ct5_c) - ct5.r = _ct5_r - ct5.c = _ct5_c - ct5.moduli = list(_ct5_moduli)[:_ct5_m] - ct5.moduli_array = jnp.array( - ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) - ) - _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct6_arg_m_in = _ct6_arg_data.shape[-1] - _ct6_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_arg_m_in - ) - _ct6_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_arg_r) - ) - _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct6_arg_moduli, (int, np.integer)): - _ct6_arg_moduli = [int(_ct6_arg_moduli)] - ct6_arg = Polynomial( - { - "batch": _ct6_arg_data.shape[0], - "num_elements": _ct6_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_arg_m, - "precision": 32, - "degree_layout": (_ct6_arg_r, _ct6_arg_c), - }, - {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, - ) - ct6_arg.polynomial = _ct6_arg_data.reshape( - _ct6_arg_data.shape[0], - _ct6_arg_data.shape[1], - _ct6_arg_r, - _ct6_arg_c, - _ct6_arg_m_in, - )[..., :_ct6_arg_m].copy() - ct6_arg.batch = ct6_arg.polynomial.shape[0] - ct6_arg.num_elements = ct6_arg.polynomial.shape[1] - ct6_arg.num_moduli = _ct6_arg_m - ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) - ct6_arg.r = _ct6_arg_r - ct6_arg.c = _ct6_arg_c - ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] - ct6_arg.moduli_array = jnp.array( - ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) - ) - ct6_pt_ntt = ( - pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] - .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct6_ptct = v0.ptct_mul[v0.max_level] - ct6_ptct.set_plaintext(ct6_pt_ntt) - ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) - _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw - _ct6_m_in = _ct6_data.shape[-1] - _ct6_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_m_in - ) - _ct6_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_r) - ) - _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) - if isinstance(_ct6_moduli, (int, np.integer)): - _ct6_moduli = [int(_ct6_moduli)] - ct6 = Polynomial( - { - "batch": _ct6_data.shape[0], - "num_elements": _ct6_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_m, - "precision": 32, - "degree_layout": (_ct6_r, _ct6_c), - }, - {"moduli": list(_ct6_moduli)[:_ct6_m]}, - ) - ct6.polynomial = _ct6_data.reshape( - _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in - )[..., :_ct6_m].copy() - ct6.batch = ct6.polynomial.shape[0] - ct6.num_elements = ct6.polynomial.shape[1] - ct6.num_moduli = _ct6_m - ct6.degree_layout = (_ct6_r, _ct6_c) - ct6.r = _ct6_r - ct6.c = _ct6_c - ct6.moduli = list(_ct6_moduli)[:_ct6_m] - ct6.moduli_array = jnp.array( - ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) - ) - _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct7_arg_m_in = _ct7_arg_data.shape[-1] - _ct7_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_arg_m_in - ) - _ct7_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_arg_r) - ) - _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct7_arg_moduli, (int, np.integer)): - _ct7_arg_moduli = [int(_ct7_arg_moduli)] - ct7_arg = Polynomial( - { - "batch": _ct7_arg_data.shape[0], - "num_elements": _ct7_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_arg_m, - "precision": 32, - "degree_layout": (_ct7_arg_r, _ct7_arg_c), - }, - {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, - ) - ct7_arg.polynomial = _ct7_arg_data.reshape( - _ct7_arg_data.shape[0], - _ct7_arg_data.shape[1], - _ct7_arg_r, - _ct7_arg_c, - _ct7_arg_m_in, - )[..., :_ct7_arg_m].copy() - ct7_arg.batch = ct7_arg.polynomial.shape[0] - ct7_arg.num_elements = ct7_arg.polynomial.shape[1] - ct7_arg.num_moduli = _ct7_arg_m - ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) - ct7_arg.r = _ct7_arg_r - ct7_arg.c = _ct7_arg_c - ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] - ct7_arg.moduli_array = jnp.array( - ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) - ) - ct7_pt_ntt = ( - pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] - .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct7_ptct = v0.ptct_mul[v0.max_level] - ct7_ptct.set_plaintext(ct7_pt_ntt) - ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) - _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw - _ct7_m_in = _ct7_data.shape[-1] - _ct7_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_m_in - ) - _ct7_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_r) - ) - _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) - if isinstance(_ct7_moduli, (int, np.integer)): - _ct7_moduli = [int(_ct7_moduli)] - ct7 = Polynomial( - { - "batch": _ct7_data.shape[0], - "num_elements": _ct7_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_m, - "precision": 32, - "degree_layout": (_ct7_r, _ct7_c), - }, - {"moduli": list(_ct7_moduli)[:_ct7_m]}, - ) - ct7.polynomial = _ct7_data.reshape( - _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in - )[..., :_ct7_m].copy() - ct7.batch = ct7.polynomial.shape[0] - ct7.num_elements = ct7.polynomial.shape[1] - ct7.num_moduli = _ct7_m - ct7.degree_layout = (_ct7_r, _ct7_c) - ct7.r = _ct7_r - ct7.c = _ct7_c - ct7.moduli = list(_ct7_moduli)[:_ct7_m] - ct7.moduli_array = jnp.array( - ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) - ) - _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct8_arg_m_in = _ct8_arg_data.shape[-1] - _ct8_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_arg_m_in - ) - _ct8_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_arg_r) - ) - _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct8_arg_moduli, (int, np.integer)): - _ct8_arg_moduli = [int(_ct8_arg_moduli)] - ct8_arg = Polynomial( - { - "batch": _ct8_arg_data.shape[0], - "num_elements": _ct8_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_arg_m, - "precision": 32, - "degree_layout": (_ct8_arg_r, _ct8_arg_c), - }, - {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, - ) - ct8_arg.polynomial = _ct8_arg_data.reshape( - _ct8_arg_data.shape[0], - _ct8_arg_data.shape[1], - _ct8_arg_r, - _ct8_arg_c, - _ct8_arg_m_in, - )[..., :_ct8_arg_m].copy() - ct8_arg.batch = ct8_arg.polynomial.shape[0] - ct8_arg.num_elements = ct8_arg.polynomial.shape[1] - ct8_arg.num_moduli = _ct8_arg_m - ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) - ct8_arg.r = _ct8_arg_r - ct8_arg.c = _ct8_arg_c - ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] - ct8_arg.moduli_array = jnp.array( - ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) - ) - ct8_pt_ntt = ( - pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] - .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct8_ptct = v0.ptct_mul[v0.max_level] - ct8_ptct.set_plaintext(ct8_pt_ntt) - ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) - _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw - _ct8_m_in = _ct8_data.shape[-1] - _ct8_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_m_in - ) - _ct8_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_r) - ) - _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) - if isinstance(_ct8_moduli, (int, np.integer)): - _ct8_moduli = [int(_ct8_moduli)] - ct8 = Polynomial( - { - "batch": _ct8_data.shape[0], - "num_elements": _ct8_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_m, - "precision": 32, - "degree_layout": (_ct8_r, _ct8_c), - }, - {"moduli": list(_ct8_moduli)[:_ct8_m]}, - ) - ct8.polynomial = _ct8_data.reshape( - _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in - )[..., :_ct8_m].copy() - ct8.batch = ct8.polynomial.shape[0] - ct8.num_elements = ct8.polynomial.shape[1] - ct8.num_moduli = _ct8_m - ct8.degree_layout = (_ct8_r, _ct8_c) - ct8.r = _ct8_r - ct8.c = _ct8_c - ct8.moduli = list(_ct8_moduli)[:_ct8_m] - ct8.moduli_array = jnp.array( - ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) - ) - _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - _ct9_m_in = _ct9_data.shape[-1] - _ct9_m = _ct9_m_in - _ct9_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_r) - ) - _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) - if isinstance(_ct9_moduli, (int, np.integer)): - _ct9_moduli = [int(_ct9_moduli)] - ct9 = Polynomial( - { - "batch": _ct9_data.shape[0], - "num_elements": _ct9_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_m, - "precision": 32, - "degree_layout": (_ct9_r, _ct9_c), - }, - {"moduli": list(_ct9_moduli)[:_ct9_m]}, - ) - ct9.polynomial = _ct9_data.reshape( - _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in - )[..., :_ct9_m].copy() - ct9.batch = ct9.polynomial.shape[0] - ct9.num_elements = ct9.polynomial.shape[1] - ct9.num_moduli = _ct9_m - ct9.degree_layout = (_ct9_r, _ct9_c) - ct9.r = _ct9_r - ct9.c = _ct9_c - ct9.moduli = list(_ct9_moduli)[:_ct9_m] - ct9.moduli_array = jnp.array( - ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) - ) - _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 - _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] - _ct9_rhs_m = _ct9_rhs_m_in - _ct9_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_rhs_r) - ) - _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) - if isinstance(_ct9_rhs_moduli, (int, np.integer)): - _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] - ct9_rhs = Polynomial( - { - "batch": _ct9_rhs_data.shape[0], - "num_elements": _ct9_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_rhs_m, - "precision": 32, - "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), - }, - {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, - ) - ct9_rhs.polynomial = _ct9_rhs_data.reshape( - _ct9_rhs_data.shape[0], - _ct9_rhs_data.shape[1], - _ct9_rhs_r, - _ct9_rhs_c, - _ct9_rhs_m_in, - )[..., :_ct9_rhs_m].copy() - ct9_rhs.batch = ct9_rhs.polynomial.shape[0] - ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] - ct9_rhs.num_moduli = _ct9_rhs_m - ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) - ct9_rhs.r = _ct9_rhs_r - ct9_rhs.c = _ct9_rhs_c - ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] - ct9_rhs.moduli_array = jnp.array( - ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) - ) - ct9.add(ct9_rhs) - _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) - ct9.polynomial = jnp.where( - ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial - ) - _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - _ct10_m_in = _ct10_data.shape[-1] - _ct10_m = _ct10_m_in - _ct10_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_r) - ) - _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) - if isinstance(_ct10_moduli, (int, np.integer)): - _ct10_moduli = [int(_ct10_moduli)] - ct10 = Polynomial( - { - "batch": _ct10_data.shape[0], - "num_elements": _ct10_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_m, - "precision": 32, - "degree_layout": (_ct10_r, _ct10_c), - }, - {"moduli": list(_ct10_moduli)[:_ct10_m]}, - ) - ct10.polynomial = _ct10_data.reshape( - _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in - )[..., :_ct10_m].copy() - ct10.batch = ct10.polynomial.shape[0] - ct10.num_elements = ct10.polynomial.shape[1] - ct10.num_moduli = _ct10_m - ct10.degree_layout = (_ct10_r, _ct10_c) - ct10.r = _ct10_r - ct10.c = _ct10_c - ct10.moduli = list(_ct10_moduli)[:_ct10_m] - ct10.moduli_array = jnp.array( - ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) - ) - _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] - _ct10_rhs_m = _ct10_rhs_m_in - _ct10_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_rhs_r) - ) - _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) - if isinstance(_ct10_rhs_moduli, (int, np.integer)): - _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] - ct10_rhs = Polynomial( - { - "batch": _ct10_rhs_data.shape[0], - "num_elements": _ct10_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_rhs_m, - "precision": 32, - "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), - }, - {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, - ) - ct10_rhs.polynomial = _ct10_rhs_data.reshape( - _ct10_rhs_data.shape[0], - _ct10_rhs_data.shape[1], - _ct10_rhs_r, - _ct10_rhs_c, - _ct10_rhs_m_in, - )[..., :_ct10_rhs_m].copy() - ct10_rhs.batch = ct10_rhs.polynomial.shape[0] - ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] - ct10_rhs.num_moduli = _ct10_rhs_m - ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) - ct10_rhs.r = _ct10_rhs_r - ct10_rhs.c = _ct10_rhs_c - ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] - ct10_rhs.moduli_array = jnp.array( - ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) - ) - ct10.add(ct10_rhs) - _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) - ct10.polynomial = jnp.where( - ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial - ) - _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 - _ct11_arg_m_in = _ct11_arg_data.shape[-1] - _ct11_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_arg_m_in - ) - _ct11_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_arg_r) - ) - _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) - if isinstance(_ct11_arg_moduli, (int, np.integer)): - _ct11_arg_moduli = [int(_ct11_arg_moduli)] - ct11_arg = Polynomial( - { - "batch": _ct11_arg_data.shape[0], - "num_elements": _ct11_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_arg_m, - "precision": 32, - "degree_layout": (_ct11_arg_r, _ct11_arg_c), - }, - {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, - ) - ct11_arg.polynomial = _ct11_arg_data.reshape( - _ct11_arg_data.shape[0], - _ct11_arg_data.shape[1], - _ct11_arg_r, - _ct11_arg_c, - _ct11_arg_m_in, - )[..., :_ct11_arg_m].copy() - ct11_arg.batch = ct11_arg.polynomial.shape[0] - ct11_arg.num_elements = ct11_arg.polynomial.shape[1] - ct11_arg.num_moduli = _ct11_arg_m - ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) - ct11_arg.r = _ct11_arg_r - ct11_arg.c = _ct11_arg_c - ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] - ct11_arg.moduli_array = jnp.array( - ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) - ) - ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) - _ct11_data = ( - ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw - ) - _ct11_m_in = _ct11_data.shape[-1] - _ct11_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_m_in - ) - _ct11_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_r) - ) - _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) - if isinstance(_ct11_moduli, (int, np.integer)): - _ct11_moduli = [int(_ct11_moduli)] - ct11 = Polynomial( - { - "batch": _ct11_data.shape[0], - "num_elements": _ct11_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_m, - "precision": 32, - "degree_layout": (_ct11_r, _ct11_c), - }, - {"moduli": list(_ct11_moduli)[:_ct11_m]}, - ) - ct11.polynomial = _ct11_data.reshape( - _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in - )[..., :_ct11_m].copy() - ct11.batch = ct11.polynomial.shape[0] - ct11.num_elements = ct11.polynomial.shape[1] - ct11.num_moduli = _ct11_m - ct11.degree_layout = (_ct11_r, _ct11_c) - ct11.r = _ct11_r - ct11.c = _ct11_c - ct11.moduli = list(_ct11_moduli)[:_ct11_m] - ct11.moduli_array = jnp.array( - ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) - ) - _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct12_arg_m_in = _ct12_arg_data.shape[-1] - _ct12_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_arg_m_in - ) - _ct12_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_arg_r) - ) - _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct12_arg_moduli, (int, np.integer)): - _ct12_arg_moduli = [int(_ct12_arg_moduli)] - ct12_arg = Polynomial( - { - "batch": _ct12_arg_data.shape[0], - "num_elements": _ct12_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_arg_m, - "precision": 32, - "degree_layout": (_ct12_arg_r, _ct12_arg_c), - }, - {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, - ) - ct12_arg.polynomial = _ct12_arg_data.reshape( - _ct12_arg_data.shape[0], - _ct12_arg_data.shape[1], - _ct12_arg_r, - _ct12_arg_c, - _ct12_arg_m_in, - )[..., :_ct12_arg_m].copy() - ct12_arg.batch = ct12_arg.polynomial.shape[0] - ct12_arg.num_elements = ct12_arg.polynomial.shape[1] - ct12_arg.num_moduli = _ct12_arg_m - ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) - ct12_arg.r = _ct12_arg_r - ct12_arg.c = _ct12_arg_c - ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] - ct12_arg.moduli_array = jnp.array( - ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) - ) - ct12_pt_ntt = ( - pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] - .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct12_ptct = v0.ptct_mul[v0.max_level] - ct12_ptct.set_plaintext(ct12_pt_ntt) - ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) - _ct12_data = ( - ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw - ) - _ct12_m_in = _ct12_data.shape[-1] - _ct12_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_m_in - ) - _ct12_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_r) - ) - _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) - if isinstance(_ct12_moduli, (int, np.integer)): - _ct12_moduli = [int(_ct12_moduli)] - ct12 = Polynomial( - { - "batch": _ct12_data.shape[0], - "num_elements": _ct12_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_m, - "precision": 32, - "degree_layout": (_ct12_r, _ct12_c), - }, - {"moduli": list(_ct12_moduli)[:_ct12_m]}, - ) - ct12.polynomial = _ct12_data.reshape( - _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in - )[..., :_ct12_m].copy() - ct12.batch = ct12.polynomial.shape[0] - ct12.num_elements = ct12.polynomial.shape[1] - ct12.num_moduli = _ct12_m - ct12.degree_layout = (_ct12_r, _ct12_c) - ct12.r = _ct12_r - ct12.c = _ct12_c - ct12.moduli = list(_ct12_moduli)[:_ct12_m] - ct12.moduli_array = jnp.array( - ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) - ) - _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct13_arg_m_in = _ct13_arg_data.shape[-1] - _ct13_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_arg_m_in - ) - _ct13_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_arg_r) - ) - _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct13_arg_moduli, (int, np.integer)): - _ct13_arg_moduli = [int(_ct13_arg_moduli)] - ct13_arg = Polynomial( - { - "batch": _ct13_arg_data.shape[0], - "num_elements": _ct13_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_arg_m, - "precision": 32, - "degree_layout": (_ct13_arg_r, _ct13_arg_c), - }, - {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, - ) - ct13_arg.polynomial = _ct13_arg_data.reshape( - _ct13_arg_data.shape[0], - _ct13_arg_data.shape[1], - _ct13_arg_r, - _ct13_arg_c, - _ct13_arg_m_in, - )[..., :_ct13_arg_m].copy() - ct13_arg.batch = ct13_arg.polynomial.shape[0] - ct13_arg.num_elements = ct13_arg.polynomial.shape[1] - ct13_arg.num_moduli = _ct13_arg_m - ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) - ct13_arg.r = _ct13_arg_r - ct13_arg.c = _ct13_arg_c - ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] - ct13_arg.moduli_array = jnp.array( - ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) - ) - ct13_pt_ntt = ( - pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] - .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct13_ptct = v0.ptct_mul[v0.max_level] - ct13_ptct.set_plaintext(ct13_pt_ntt) - ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) - _ct13_data = ( - ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw - ) - _ct13_m_in = _ct13_data.shape[-1] - _ct13_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_m_in - ) - _ct13_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_r) - ) - _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) - if isinstance(_ct13_moduli, (int, np.integer)): - _ct13_moduli = [int(_ct13_moduli)] - ct13 = Polynomial( - { - "batch": _ct13_data.shape[0], - "num_elements": _ct13_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_m, - "precision": 32, - "degree_layout": (_ct13_r, _ct13_c), - }, - {"moduli": list(_ct13_moduli)[:_ct13_m]}, - ) - ct13.polynomial = _ct13_data.reshape( - _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in - )[..., :_ct13_m].copy() - ct13.batch = ct13.polynomial.shape[0] - ct13.num_elements = ct13.polynomial.shape[1] - ct13.num_moduli = _ct13_m - ct13.degree_layout = (_ct13_r, _ct13_c) - ct13.r = _ct13_r - ct13.c = _ct13_c - ct13.moduli = list(_ct13_moduli)[:_ct13_m] - ct13.moduli_array = jnp.array( - ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) - ) - _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - _ct14_m_in = _ct14_data.shape[-1] - _ct14_m = _ct14_m_in - _ct14_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_r) - ) - _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) - if isinstance(_ct14_moduli, (int, np.integer)): - _ct14_moduli = [int(_ct14_moduli)] - ct14 = Polynomial( - { - "batch": _ct14_data.shape[0], - "num_elements": _ct14_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_m, - "precision": 32, - "degree_layout": (_ct14_r, _ct14_c), - }, - {"moduli": list(_ct14_moduli)[:_ct14_m]}, - ) - ct14.polynomial = _ct14_data.reshape( - _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in - )[..., :_ct14_m].copy() - ct14.batch = ct14.polynomial.shape[0] - ct14.num_elements = ct14.polynomial.shape[1] - ct14.num_moduli = _ct14_m - ct14.degree_layout = (_ct14_r, _ct14_c) - ct14.r = _ct14_r - ct14.c = _ct14_c - ct14.moduli = list(_ct14_moduli)[:_ct14_m] - ct14.moduli_array = jnp.array( - ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) - ) - _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] - _ct14_rhs_m = _ct14_rhs_m_in - _ct14_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_rhs_r) - ) - _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) - if isinstance(_ct14_rhs_moduli, (int, np.integer)): - _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] - ct14_rhs = Polynomial( - { - "batch": _ct14_rhs_data.shape[0], - "num_elements": _ct14_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_rhs_m, - "precision": 32, - "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), - }, - {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, - ) - ct14_rhs.polynomial = _ct14_rhs_data.reshape( - _ct14_rhs_data.shape[0], - _ct14_rhs_data.shape[1], - _ct14_rhs_r, - _ct14_rhs_c, - _ct14_rhs_m_in, - )[..., :_ct14_rhs_m].copy() - ct14_rhs.batch = ct14_rhs.polynomial.shape[0] - ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] - ct14_rhs.num_moduli = _ct14_rhs_m - ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) - ct14_rhs.r = _ct14_rhs_r - ct14_rhs.c = _ct14_rhs_c - ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] - ct14_rhs.moduli_array = jnp.array( - ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) - ) - ct14.add(ct14_rhs) - _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) - ct14.polynomial = jnp.where( - ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial - ) - _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 - _ct15_arg_m_in = _ct15_arg_data.shape[-1] - _ct15_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_arg_m_in - ) - _ct15_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_arg_r) - ) - _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) - if isinstance(_ct15_arg_moduli, (int, np.integer)): - _ct15_arg_moduli = [int(_ct15_arg_moduli)] - ct15_arg = Polynomial( - { - "batch": _ct15_arg_data.shape[0], - "num_elements": _ct15_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_arg_m, - "precision": 32, - "degree_layout": (_ct15_arg_r, _ct15_arg_c), - }, - {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, - ) - ct15_arg.polynomial = _ct15_arg_data.reshape( - _ct15_arg_data.shape[0], - _ct15_arg_data.shape[1], - _ct15_arg_r, - _ct15_arg_c, - _ct15_arg_m_in, - )[..., :_ct15_arg_m].copy() - ct15_arg.batch = ct15_arg.polynomial.shape[0] - ct15_arg.num_elements = ct15_arg.polynomial.shape[1] - ct15_arg.num_moduli = _ct15_arg_m - ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) - ct15_arg.r = _ct15_arg_r - ct15_arg.c = _ct15_arg_c - ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] - ct15_arg.moduli_array = jnp.array( - ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) - ) - ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) - _ct15_data = ( - ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw - ) - _ct15_m_in = _ct15_data.shape[-1] - _ct15_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_m_in - ) - _ct15_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_r) - ) - _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) - if isinstance(_ct15_moduli, (int, np.integer)): - _ct15_moduli = [int(_ct15_moduli)] - ct15 = Polynomial( - { - "batch": _ct15_data.shape[0], - "num_elements": _ct15_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_m, - "precision": 32, - "degree_layout": (_ct15_r, _ct15_c), - }, - {"moduli": list(_ct15_moduli)[:_ct15_m]}, - ) - ct15.polynomial = _ct15_data.reshape( - _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in - )[..., :_ct15_m].copy() - ct15.batch = ct15.polynomial.shape[0] - ct15.num_elements = ct15.polynomial.shape[1] - ct15.num_moduli = _ct15_m - ct15.degree_layout = (_ct15_r, _ct15_c) - ct15.r = _ct15_r - ct15.c = _ct15_c - ct15.moduli = list(_ct15_moduli)[:_ct15_m] - ct15.moduli_array = jnp.array( - ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) - ) - _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - _ct16_m_in = _ct16_data.shape[-1] - _ct16_m = _ct16_m_in - _ct16_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_r) - ) - _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) - if isinstance(_ct16_moduli, (int, np.integer)): - _ct16_moduli = [int(_ct16_moduli)] - ct16 = Polynomial( - { - "batch": _ct16_data.shape[0], - "num_elements": _ct16_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_m, - "precision": 32, - "degree_layout": (_ct16_r, _ct16_c), - }, - {"moduli": list(_ct16_moduli)[:_ct16_m]}, - ) - ct16.polynomial = _ct16_data.reshape( - _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in - )[..., :_ct16_m].copy() - ct16.batch = ct16.polynomial.shape[0] - ct16.num_elements = ct16.polynomial.shape[1] - ct16.num_moduli = _ct16_m - ct16.degree_layout = (_ct16_r, _ct16_c) - ct16.r = _ct16_r - ct16.c = _ct16_c - ct16.moduli = list(_ct16_moduli)[:_ct16_m] - ct16.moduli_array = jnp.array( - ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) - ) - _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 - _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] - _ct16_rhs_m = _ct16_rhs_m_in - _ct16_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_rhs_r) - ) - _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) - if isinstance(_ct16_rhs_moduli, (int, np.integer)): - _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] - ct16_rhs = Polynomial( - { - "batch": _ct16_rhs_data.shape[0], - "num_elements": _ct16_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_rhs_m, - "precision": 32, - "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), - }, - {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, - ) - ct16_rhs.polynomial = _ct16_rhs_data.reshape( - _ct16_rhs_data.shape[0], - _ct16_rhs_data.shape[1], - _ct16_rhs_r, - _ct16_rhs_c, - _ct16_rhs_m_in, - )[..., :_ct16_rhs_m].copy() - ct16_rhs.batch = ct16_rhs.polynomial.shape[0] - ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] - ct16_rhs.num_moduli = _ct16_rhs_m - ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) - ct16_rhs.r = _ct16_rhs_r - ct16_rhs.c = _ct16_rhs_c - ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] - ct16_rhs.moduli_array = jnp.array( - ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) - ) - ct16.add(ct16_rhs) - _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) - ct16.polynomial = jnp.where( - ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial - ) - _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - _ct17_m_in = _ct17_data.shape[-1] - _ct17_m = _ct17_m_in - _ct17_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_r) - ) - _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) - if isinstance(_ct17_moduli, (int, np.integer)): - _ct17_moduli = [int(_ct17_moduli)] - ct17 = Polynomial( - { - "batch": _ct17_data.shape[0], - "num_elements": _ct17_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_m, - "precision": 32, - "degree_layout": (_ct17_r, _ct17_c), - }, - {"moduli": list(_ct17_moduli)[:_ct17_m]}, - ) - ct17.polynomial = _ct17_data.reshape( - _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in - )[..., :_ct17_m].copy() - ct17.batch = ct17.polynomial.shape[0] - ct17.num_elements = ct17.polynomial.shape[1] - ct17.num_moduli = _ct17_m - ct17.degree_layout = (_ct17_r, _ct17_c) - ct17.r = _ct17_r - ct17.c = _ct17_c - ct17.moduli = list(_ct17_moduli)[:_ct17_m] - ct17.moduli_array = jnp.array( - ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) - ) - _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] - _ct17_rhs_m = _ct17_rhs_m_in - _ct17_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_rhs_r) - ) - _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) - if isinstance(_ct17_rhs_moduli, (int, np.integer)): - _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] - ct17_rhs = Polynomial( - { - "batch": _ct17_rhs_data.shape[0], - "num_elements": _ct17_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_rhs_m, - "precision": 32, - "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), - }, - {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, - ) - ct17_rhs.polynomial = _ct17_rhs_data.reshape( - _ct17_rhs_data.shape[0], - _ct17_rhs_data.shape[1], - _ct17_rhs_r, - _ct17_rhs_c, - _ct17_rhs_m_in, - )[..., :_ct17_rhs_m].copy() - ct17_rhs.batch = ct17_rhs.polynomial.shape[0] - ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] - ct17_rhs.num_moduli = _ct17_rhs_m - ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) - ct17_rhs.r = _ct17_rhs_r - ct17_rhs.c = _ct17_rhs_c - ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] - ct17_rhs.moduli_array = jnp.array( - ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) - ) - ct17.add(ct17_rhs) - _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) - ct17.polynomial = jnp.where( - ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial - ) - _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 - _ct18_m_in = _ct18_data.shape[-1] - _ct18_m = _ct18_m_in - _ct18_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_r) - ) - _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) - if isinstance(_ct18_moduli, (int, np.integer)): - _ct18_moduli = [int(_ct18_moduli)] - ct18 = Polynomial( - { - "batch": _ct18_data.shape[0], - "num_elements": _ct18_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_m, - "precision": 32, - "degree_layout": (_ct18_r, _ct18_c), - }, - {"moduli": list(_ct18_moduli)[:_ct18_m]}, - ) - ct18.polynomial = _ct18_data.reshape( - _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in - )[..., :_ct18_m].copy() - ct18.batch = ct18.polynomial.shape[0] - ct18.num_elements = ct18.polynomial.shape[1] - ct18.num_moduli = _ct18_m - ct18.degree_layout = (_ct18_r, _ct18_c) - ct18.r = _ct18_r - ct18.c = _ct18_c - ct18.moduli = list(_ct18_moduli)[:_ct18_m] - ct18.moduli_array = jnp.array( - ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) - ) - _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 - _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] - _ct18_rhs_m = _ct18_rhs_m_in - _ct18_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_rhs_r) - ) - _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) - if isinstance(_ct18_rhs_moduli, (int, np.integer)): - _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] - ct18_rhs = Polynomial( - { - "batch": _ct18_rhs_data.shape[0], - "num_elements": _ct18_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_rhs_m, - "precision": 32, - "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), - }, - {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, - ) - ct18_rhs.polynomial = _ct18_rhs_data.reshape( - _ct18_rhs_data.shape[0], - _ct18_rhs_data.shape[1], - _ct18_rhs_r, - _ct18_rhs_c, - _ct18_rhs_m_in, - )[..., :_ct18_rhs_m].copy() - ct18_rhs.batch = ct18_rhs.polynomial.shape[0] - ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] - ct18_rhs.num_moduli = _ct18_rhs_m - ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) - ct18_rhs.r = _ct18_rhs_r - ct18_rhs.c = _ct18_rhs_c - ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] - ct18_rhs.moduli_array = jnp.array( - ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) - ) - ct18.add(ct18_rhs) - _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) - ct18.polynomial = jnp.where( - ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial - ) - _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 - _ct19_m_in = _ct19_data.shape[-1] - _ct19_m = _ct19_m_in - _ct19_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_r) - ) - _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) - if isinstance(_ct19_moduli, (int, np.integer)): - _ct19_moduli = [int(_ct19_moduli)] - ct19 = Polynomial( - { - "batch": _ct19_data.shape[0], - "num_elements": _ct19_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_m, - "precision": 32, - "degree_layout": (_ct19_r, _ct19_c), - }, - {"moduli": list(_ct19_moduli)[:_ct19_m]}, - ) - ct19.polynomial = _ct19_data.reshape( - _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in - )[..., :_ct19_m].copy() - ct19.batch = ct19.polynomial.shape[0] - ct19.num_elements = ct19.polynomial.shape[1] - ct19.num_moduli = _ct19_m - ct19.degree_layout = (_ct19_r, _ct19_c) - ct19.r = _ct19_r - ct19.c = _ct19_c - ct19.moduli = list(_ct19_moduli)[:_ct19_m] - ct19.moduli_array = jnp.array( - ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) - ) - _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 - _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] - _ct19_rhs_m = _ct19_rhs_m_in - _ct19_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_rhs_r) - ) - _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) - if isinstance(_ct19_rhs_moduli, (int, np.integer)): - _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] - ct19_rhs = Polynomial( - { - "batch": _ct19_rhs_data.shape[0], - "num_elements": _ct19_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_rhs_m, - "precision": 32, - "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), - }, - {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, - ) - ct19_rhs.polynomial = _ct19_rhs_data.reshape( - _ct19_rhs_data.shape[0], - _ct19_rhs_data.shape[1], - _ct19_rhs_r, - _ct19_rhs_c, - _ct19_rhs_m_in, - )[..., :_ct19_rhs_m].copy() - ct19_rhs.batch = ct19_rhs.polynomial.shape[0] - ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] - ct19_rhs.num_moduli = _ct19_rhs_m - ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) - ct19_rhs.r = _ct19_rhs_r - ct19_rhs.c = _ct19_rhs_c - ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] - ct19_rhs.moduli_array = jnp.array( - ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) - ) - ct19.add(ct19_rhs) - _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) - ct19.polynomial = jnp.where( - ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial - ) - v16 = [None] * 1 - _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 - _ct20_arg_m_in = _ct20_arg_data.shape[-1] - _ct20_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct20_arg_m_in - ) - _ct20_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_arg_r) - ) - _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) - if isinstance(_ct20_arg_moduli, (int, np.integer)): - _ct20_arg_moduli = [int(_ct20_arg_moduli)] - ct20_arg = Polynomial( - { - "batch": _ct20_arg_data.shape[0], - "num_elements": _ct20_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_arg_m, - "precision": 32, - "degree_layout": (_ct20_arg_r, _ct20_arg_c), - }, - {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, - ) - ct20_arg.polynomial = _ct20_arg_data.reshape( - _ct20_arg_data.shape[0], - _ct20_arg_data.shape[1], - _ct20_arg_r, - _ct20_arg_c, - _ct20_arg_m_in, - )[..., :_ct20_arg_m].copy() - ct20_arg.batch = ct20_arg.polynomial.shape[0] - ct20_arg.num_elements = ct20_arg.polynomial.shape[1] - ct20_arg.num_moduli = _ct20_arg_m - ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) - ct20_arg.r = _ct20_arg_r - ct20_arg.c = _ct20_arg_c - ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] - ct20_arg.moduli_array = jnp.array( - ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) - ) - ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) - _ct20_data = ( - ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw - ) - _ct20_m_in = _ct20_data.shape[-1] - _ct20_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct20_m_in - ) - _ct20_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_r) - ) - _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) - if isinstance(_ct20_moduli, (int, np.integer)): - _ct20_moduli = [int(_ct20_moduli)] - ct20 = Polynomial( - { - "batch": _ct20_data.shape[0], - "num_elements": _ct20_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_m, - "precision": 32, - "degree_layout": (_ct20_r, _ct20_c), - }, - {"moduli": list(_ct20_moduli)[:_ct20_m]}, - ) - ct20.polynomial = _ct20_data.reshape( - _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in - )[..., :_ct20_m].copy() - ct20.batch = ct20.polynomial.shape[0] - ct20.num_elements = ct20.polynomial.shape[1] - ct20.num_moduli = _ct20_m - ct20.degree_layout = (_ct20_r, _ct20_c) - ct20.r = _ct20_r - ct20.c = _ct20_c - ct20.moduli = list(_ct20_moduli)[:_ct20_m] - ct20.moduli_array = jnp.array( - ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) - ) - v16[0] = ct20 - v17 = v16 - return v17 - - -def matvec_identity( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_identity__preprocessing(v0, v1) - v11 = matvec_identity__preprocessed( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10 - ) - return v11 - - -def matvec_identity__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw - _ct_m_in = _ct_data.shape[-1] - _ct_m = _ct_m_in - _ct_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct_r) - ) - _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) - if isinstance(_ct_moduli, (int, np.integer)): - _ct_moduli = [int(_ct_moduli)] - ct = Polynomial( - { - "batch": _ct_data.shape[0], - "num_elements": _ct_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct_m, - "precision": 32, - "degree_layout": (_ct_r, _ct_c), - }, - {"moduli": list(_ct_moduli)[:_ct_m]}, - ) - ct.polynomial = _ct_data.reshape( - _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in - )[..., :_ct_m].copy() - ct.batch = ct.polynomial.shape[0] - ct.num_elements = ct.polynomial.shape[1] - ct.num_moduli = _ct_m - ct.degree_layout = (_ct_r, _ct_c) - ct.r = _ct_r - ct.c = _ct_c - ct.moduli = list(_ct_moduli)[:_ct_m] - ct.moduli_array = jnp.array( - ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) - ) - v16 = [ct] - return v16 - - -def matvec_identity__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 0 - v8 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - _num_moduli = ct.polynomial.shape[-1] - _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": ct.polynomial.shape[0], - "num_elements": ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - ct.polynomial.reshape( - ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli - ) - ) - pt = v0.decrypt(_ct_for_dec) - v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v10 = v8.copy() - for v11 in range(0, 8): - v13 = int(v11) - v14 = v9[0, v13] - v10[v13] = v14 - return v10 - - -def matvec_shift__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> ( - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -): - v2 = np.array( - [ - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 1.000000e00, - 1.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - 0.000000e00, - ], - dtype=np.float32, - ).reshape(8, 8) - v3 = _assign_layout_15335824159471298539(v2) - v4 = v3[3 : 3 + 1, 0 : 0 + 5] - v5 = v3[3 : 3 + 1, 5 : 5 + 3] - v6 = np.zeros( - ( - 1, - 8, - ), - dtype=np.float32, - ) - v7 = v6.copy() - v7[0 : 0 + 1, 3 : 3 + 5] = v4 - v8 = v7.copy() - v8[0 : 0 + 1, 0 : 0 + 3] = v5 - v9 = v3[4 : 4 + 1, 0 : 0 + 5] - v10 = v3[4 : 4 + 1, 5 : 5 + 3] - v11 = v6.copy() - v11[0 : 0 + 1, 3 : 3 + 5] = v9 - v12 = v11.copy() - v12[0 : 0 + 1, 0 : 0 + 3] = v10 - v13 = v3[5 : 5 + 1, 0 : 0 + 5] - v14 = v3[5 : 5 + 1, 5 : 5 + 3] - v15 = v6.copy() - v15[0 : 0 + 1, 3 : 3 + 5] = v13 - v16 = v15.copy() - v16[0 : 0 + 1, 0 : 0 + 3] = v14 - v17 = v3[6 : 6 + 1, 0 : 0 + 2] - v18 = v3[6 : 6 + 1, 2 : 2 + 6] - v19 = v6.copy() - v19[0 : 0 + 1, 6 : 6 + 2] = v17 - v20 = v19.copy() - v20[0 : 0 + 1, 0 : 0 + 6] = v18 - v21 = v3[7 : 7 + 1, 0 : 0 + 2] - v22 = v3[7 : 7 + 1, 2 : 2 + 6] - v23 = v6.copy() - v23[0 : 0 + 1, 6 : 6 + 2] = v21 - v24 = v23.copy() - v24[0 : 0 + 1, 0 : 0 + 6] = v22 - v25 = v3[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v25) - v26 = v3[1 : 1 + 1, 0 : 0 + 8].reshape(8) - pt1 = v0.encode(v26) - v27 = v3[2 : 2 + 1, 0 : 0 + 8].reshape(8) - pt2 = v0.encode(v27) - v28 = v8[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt3 = v0.encode(v28) - v29 = v12[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt4 = v0.encode(v29) - v30 = v16[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt5 = v0.encode(v30) - v31 = v20[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt6 = v0.encode(v31) - v32 = v24[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt7 = v0.encode(v32) - v33 = [pt] - v34 = [pt1] - v35 = [pt2] - v36 = [pt3] - v37 = [pt4] - v38 = [pt5] - v39 = [pt6] - v40 = [pt7] - return (v33, v34, v35, v36, v37, v38, v39, v40) - - -def matvec_shift__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, - v5: np.ndarray, - v6: np.ndarray, - v7: np.ndarray, - v8: np.ndarray, - v9: np.ndarray, - v10: np.ndarray, -) -> np.ndarray: - v11 = 1 - v12 = 2 - v13 = 3 - v14 = 6 - v15 = 0 - pt = v3[0] - pt1 = v4[0] - pt2 = v5[0] - pt3 = v6[0] - pt4 = v7[0] - pt5 = v8[0] - pt6 = v9[0] - pt7 = v10[0] - ct = v2[0] - _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct1_arg_m_in = _ct1_arg_data.shape[-1] - _ct1_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_arg_m_in - ) - _ct1_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_arg_r) - ) - _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct1_arg_moduli, (int, np.integer)): - _ct1_arg_moduli = [int(_ct1_arg_moduli)] - ct1_arg = Polynomial( - { - "batch": _ct1_arg_data.shape[0], - "num_elements": _ct1_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_arg_m, - "precision": 32, - "degree_layout": (_ct1_arg_r, _ct1_arg_c), - }, - {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, - ) - ct1_arg.polynomial = _ct1_arg_data.reshape( - _ct1_arg_data.shape[0], - _ct1_arg_data.shape[1], - _ct1_arg_r, - _ct1_arg_c, - _ct1_arg_m_in, - )[..., :_ct1_arg_m].copy() - ct1_arg.batch = ct1_arg.polynomial.shape[0] - ct1_arg.num_elements = ct1_arg.polynomial.shape[1] - ct1_arg.num_moduli = _ct1_arg_m - ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) - ct1_arg.r = _ct1_arg_r - ct1_arg.c = _ct1_arg_c - ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] - ct1_arg.moduli_array = jnp.array( - ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) - ) - ct1_pt_ntt = ( - pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] - .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct1_ptct = v0.ptct_mul[v0.max_level] - ct1_ptct.set_plaintext(ct1_pt_ntt) - ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) - _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw - _ct1_m_in = _ct1_data.shape[-1] - _ct1_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_m_in - ) - _ct1_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_r) - ) - _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) - if isinstance(_ct1_moduli, (int, np.integer)): - _ct1_moduli = [int(_ct1_moduli)] - ct1 = Polynomial( - { - "batch": _ct1_data.shape[0], - "num_elements": _ct1_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_m, - "precision": 32, - "degree_layout": (_ct1_r, _ct1_c), - }, - {"moduli": list(_ct1_moduli)[:_ct1_m]}, - ) - ct1.polynomial = _ct1_data.reshape( - _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in - )[..., :_ct1_m].copy() - ct1.batch = ct1.polynomial.shape[0] - ct1.num_elements = ct1.polynomial.shape[1] - ct1.num_moduli = _ct1_m - ct1.degree_layout = (_ct1_r, _ct1_c) - ct1.r = _ct1_r - ct1.c = _ct1_c - ct1.moduli = list(_ct1_moduli)[:_ct1_m] - ct1.moduli_array = jnp.array( - ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) - ) - _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct2_arg_m_in = _ct2_arg_data.shape[-1] - _ct2_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_arg_m_in - ) - _ct2_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_arg_r) - ) - _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct2_arg_moduli, (int, np.integer)): - _ct2_arg_moduli = [int(_ct2_arg_moduli)] - ct2_arg = Polynomial( - { - "batch": _ct2_arg_data.shape[0], - "num_elements": _ct2_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_arg_m, - "precision": 32, - "degree_layout": (_ct2_arg_r, _ct2_arg_c), - }, - {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, - ) - ct2_arg.polynomial = _ct2_arg_data.reshape( - _ct2_arg_data.shape[0], - _ct2_arg_data.shape[1], - _ct2_arg_r, - _ct2_arg_c, - _ct2_arg_m_in, - )[..., :_ct2_arg_m].copy() - ct2_arg.batch = ct2_arg.polynomial.shape[0] - ct2_arg.num_elements = ct2_arg.polynomial.shape[1] - ct2_arg.num_moduli = _ct2_arg_m - ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) - ct2_arg.r = _ct2_arg_r - ct2_arg.c = _ct2_arg_c - ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] - ct2_arg.moduli_array = jnp.array( - ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) - ) - ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) - _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw - _ct2_m_in = _ct2_data.shape[-1] - _ct2_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_m_in - ) - _ct2_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_r) - ) - _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) - if isinstance(_ct2_moduli, (int, np.integer)): - _ct2_moduli = [int(_ct2_moduli)] - ct2 = Polynomial( - { - "batch": _ct2_data.shape[0], - "num_elements": _ct2_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_m, - "precision": 32, - "degree_layout": (_ct2_r, _ct2_c), - }, - {"moduli": list(_ct2_moduli)[:_ct2_m]}, - ) - ct2.polynomial = _ct2_data.reshape( - _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in - )[..., :_ct2_m].copy() - ct2.batch = ct2.polynomial.shape[0] - ct2.num_elements = ct2.polynomial.shape[1] - ct2.num_moduli = _ct2_m - ct2.degree_layout = (_ct2_r, _ct2_c) - ct2.r = _ct2_r - ct2.c = _ct2_c - ct2.moduli = list(_ct2_moduli)[:_ct2_m] - ct2.moduli_array = jnp.array( - ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) - ) - _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct3_arg_m_in = _ct3_arg_data.shape[-1] - _ct3_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_arg_m_in - ) - _ct3_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_arg_r) - ) - _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct3_arg_moduli, (int, np.integer)): - _ct3_arg_moduli = [int(_ct3_arg_moduli)] - ct3_arg = Polynomial( - { - "batch": _ct3_arg_data.shape[0], - "num_elements": _ct3_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_arg_m, - "precision": 32, - "degree_layout": (_ct3_arg_r, _ct3_arg_c), - }, - {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, - ) - ct3_arg.polynomial = _ct3_arg_data.reshape( - _ct3_arg_data.shape[0], - _ct3_arg_data.shape[1], - _ct3_arg_r, - _ct3_arg_c, - _ct3_arg_m_in, - )[..., :_ct3_arg_m].copy() - ct3_arg.batch = ct3_arg.polynomial.shape[0] - ct3_arg.num_elements = ct3_arg.polynomial.shape[1] - ct3_arg.num_moduli = _ct3_arg_m - ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) - ct3_arg.r = _ct3_arg_r - ct3_arg.c = _ct3_arg_c - ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] - ct3_arg.moduli_array = jnp.array( - ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) - ) - ct3_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] - .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct3_ptct = v0.ptct_mul[v0.max_level] - ct3_ptct.set_plaintext(ct3_pt_ntt) - ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) - _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw - _ct3_m_in = _ct3_data.shape[-1] - _ct3_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_m_in - ) - _ct3_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_r) - ) - _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) - if isinstance(_ct3_moduli, (int, np.integer)): - _ct3_moduli = [int(_ct3_moduli)] - ct3 = Polynomial( - { - "batch": _ct3_data.shape[0], - "num_elements": _ct3_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_m, - "precision": 32, - "degree_layout": (_ct3_r, _ct3_c), - }, - {"moduli": list(_ct3_moduli)[:_ct3_m]}, - ) - ct3.polynomial = _ct3_data.reshape( - _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in - )[..., :_ct3_m].copy() - ct3.batch = ct3.polynomial.shape[0] - ct3.num_elements = ct3.polynomial.shape[1] - ct3.num_moduli = _ct3_m - ct3.degree_layout = (_ct3_r, _ct3_c) - ct3.r = _ct3_r - ct3.c = _ct3_c - ct3.moduli = list(_ct3_moduli)[:_ct3_m] - ct3.moduli_array = jnp.array( - ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) - ) - _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct4_arg_m_in = _ct4_arg_data.shape[-1] - _ct4_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_arg_m_in - ) - _ct4_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_arg_r) - ) - _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct4_arg_moduli, (int, np.integer)): - _ct4_arg_moduli = [int(_ct4_arg_moduli)] - ct4_arg = Polynomial( - { - "batch": _ct4_arg_data.shape[0], - "num_elements": _ct4_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_arg_m, - "precision": 32, - "degree_layout": (_ct4_arg_r, _ct4_arg_c), - }, - {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, - ) - ct4_arg.polynomial = _ct4_arg_data.reshape( - _ct4_arg_data.shape[0], - _ct4_arg_data.shape[1], - _ct4_arg_r, - _ct4_arg_c, - _ct4_arg_m_in, - )[..., :_ct4_arg_m].copy() - ct4_arg.batch = ct4_arg.polynomial.shape[0] - ct4_arg.num_elements = ct4_arg.polynomial.shape[1] - ct4_arg.num_moduli = _ct4_arg_m - ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) - ct4_arg.r = _ct4_arg_r - ct4_arg.c = _ct4_arg_c - ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] - ct4_arg.moduli_array = jnp.array( - ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) - ) - ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) - _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw - _ct4_m_in = _ct4_data.shape[-1] - _ct4_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_m_in - ) - _ct4_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_r) - ) - _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) - if isinstance(_ct4_moduli, (int, np.integer)): - _ct4_moduli = [int(_ct4_moduli)] - ct4 = Polynomial( - { - "batch": _ct4_data.shape[0], - "num_elements": _ct4_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_m, - "precision": 32, - "degree_layout": (_ct4_r, _ct4_c), - }, - {"moduli": list(_ct4_moduli)[:_ct4_m]}, - ) - ct4.polynomial = _ct4_data.reshape( - _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in - )[..., :_ct4_m].copy() - ct4.batch = ct4.polynomial.shape[0] - ct4.num_elements = ct4.polynomial.shape[1] - ct4.num_moduli = _ct4_m - ct4.degree_layout = (_ct4_r, _ct4_c) - ct4.r = _ct4_r - ct4.c = _ct4_c - ct4.moduli = list(_ct4_moduli)[:_ct4_m] - ct4.moduli_array = jnp.array( - ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) - ) - _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct5_arg_m_in = _ct5_arg_data.shape[-1] - _ct5_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_arg_m_in - ) - _ct5_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_arg_r) - ) - _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct5_arg_moduli, (int, np.integer)): - _ct5_arg_moduli = [int(_ct5_arg_moduli)] - ct5_arg = Polynomial( - { - "batch": _ct5_arg_data.shape[0], - "num_elements": _ct5_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_arg_m, - "precision": 32, - "degree_layout": (_ct5_arg_r, _ct5_arg_c), - }, - {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, - ) - ct5_arg.polynomial = _ct5_arg_data.reshape( - _ct5_arg_data.shape[0], - _ct5_arg_data.shape[1], - _ct5_arg_r, - _ct5_arg_c, - _ct5_arg_m_in, - )[..., :_ct5_arg_m].copy() - ct5_arg.batch = ct5_arg.polynomial.shape[0] - ct5_arg.num_elements = ct5_arg.polynomial.shape[1] - ct5_arg.num_moduli = _ct5_arg_m - ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) - ct5_arg.r = _ct5_arg_r - ct5_arg.c = _ct5_arg_c - ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] - ct5_arg.moduli_array = jnp.array( - ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) - ) - ct5_pt_ntt = ( - pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) - _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw - _ct5_m_in = _ct5_data.shape[-1] - _ct5_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_m_in - ) - _ct5_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_r) - ) - _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) - if isinstance(_ct5_moduli, (int, np.integer)): - _ct5_moduli = [int(_ct5_moduli)] - ct5 = Polynomial( - { - "batch": _ct5_data.shape[0], - "num_elements": _ct5_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_m, - "precision": 32, - "degree_layout": (_ct5_r, _ct5_c), - }, - {"moduli": list(_ct5_moduli)[:_ct5_m]}, - ) - ct5.polynomial = _ct5_data.reshape( - _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in - )[..., :_ct5_m].copy() - ct5.batch = ct5.polynomial.shape[0] - ct5.num_elements = ct5.polynomial.shape[1] - ct5.num_moduli = _ct5_m - ct5.degree_layout = (_ct5_r, _ct5_c) - ct5.r = _ct5_r - ct5.c = _ct5_c - ct5.moduli = list(_ct5_moduli)[:_ct5_m] - ct5.moduli_array = jnp.array( - ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) - ) - _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct6_arg_m_in = _ct6_arg_data.shape[-1] - _ct6_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_arg_m_in - ) - _ct6_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_arg_r) - ) - _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct6_arg_moduli, (int, np.integer)): - _ct6_arg_moduli = [int(_ct6_arg_moduli)] - ct6_arg = Polynomial( - { - "batch": _ct6_arg_data.shape[0], - "num_elements": _ct6_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_arg_m, - "precision": 32, - "degree_layout": (_ct6_arg_r, _ct6_arg_c), - }, - {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, - ) - ct6_arg.polynomial = _ct6_arg_data.reshape( - _ct6_arg_data.shape[0], - _ct6_arg_data.shape[1], - _ct6_arg_r, - _ct6_arg_c, - _ct6_arg_m_in, - )[..., :_ct6_arg_m].copy() - ct6_arg.batch = ct6_arg.polynomial.shape[0] - ct6_arg.num_elements = ct6_arg.polynomial.shape[1] - ct6_arg.num_moduli = _ct6_arg_m - ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) - ct6_arg.r = _ct6_arg_r - ct6_arg.c = _ct6_arg_c - ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] - ct6_arg.moduli_array = jnp.array( - ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) - ) - ct6_pt_ntt = ( - pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] - .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct6_ptct = v0.ptct_mul[v0.max_level] - ct6_ptct.set_plaintext(ct6_pt_ntt) - ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) - _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw - _ct6_m_in = _ct6_data.shape[-1] - _ct6_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_m_in - ) - _ct6_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_r) - ) - _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) - if isinstance(_ct6_moduli, (int, np.integer)): - _ct6_moduli = [int(_ct6_moduli)] - ct6 = Polynomial( - { - "batch": _ct6_data.shape[0], - "num_elements": _ct6_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_m, - "precision": 32, - "degree_layout": (_ct6_r, _ct6_c), - }, - {"moduli": list(_ct6_moduli)[:_ct6_m]}, - ) - ct6.polynomial = _ct6_data.reshape( - _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in - )[..., :_ct6_m].copy() - ct6.batch = ct6.polynomial.shape[0] - ct6.num_elements = ct6.polynomial.shape[1] - ct6.num_moduli = _ct6_m - ct6.degree_layout = (_ct6_r, _ct6_c) - ct6.r = _ct6_r - ct6.c = _ct6_c - ct6.moduli = list(_ct6_moduli)[:_ct6_m] - ct6.moduli_array = jnp.array( - ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) - ) - _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct7_arg_m_in = _ct7_arg_data.shape[-1] - _ct7_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_arg_m_in - ) - _ct7_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_arg_r) - ) - _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct7_arg_moduli, (int, np.integer)): - _ct7_arg_moduli = [int(_ct7_arg_moduli)] - ct7_arg = Polynomial( - { - "batch": _ct7_arg_data.shape[0], - "num_elements": _ct7_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_arg_m, - "precision": 32, - "degree_layout": (_ct7_arg_r, _ct7_arg_c), - }, - {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, - ) - ct7_arg.polynomial = _ct7_arg_data.reshape( - _ct7_arg_data.shape[0], - _ct7_arg_data.shape[1], - _ct7_arg_r, - _ct7_arg_c, - _ct7_arg_m_in, - )[..., :_ct7_arg_m].copy() - ct7_arg.batch = ct7_arg.polynomial.shape[0] - ct7_arg.num_elements = ct7_arg.polynomial.shape[1] - ct7_arg.num_moduli = _ct7_arg_m - ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) - ct7_arg.r = _ct7_arg_r - ct7_arg.c = _ct7_arg_c - ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] - ct7_arg.moduli_array = jnp.array( - ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) - ) - ct7_pt_ntt = ( - pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] - .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct7_ptct = v0.ptct_mul[v0.max_level] - ct7_ptct.set_plaintext(ct7_pt_ntt) - ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) - _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw - _ct7_m_in = _ct7_data.shape[-1] - _ct7_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_m_in - ) - _ct7_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_r) - ) - _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) - if isinstance(_ct7_moduli, (int, np.integer)): - _ct7_moduli = [int(_ct7_moduli)] - ct7 = Polynomial( - { - "batch": _ct7_data.shape[0], - "num_elements": _ct7_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_m, - "precision": 32, - "degree_layout": (_ct7_r, _ct7_c), - }, - {"moduli": list(_ct7_moduli)[:_ct7_m]}, - ) - ct7.polynomial = _ct7_data.reshape( - _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in - )[..., :_ct7_m].copy() - ct7.batch = ct7.polynomial.shape[0] - ct7.num_elements = ct7.polynomial.shape[1] - ct7.num_moduli = _ct7_m - ct7.degree_layout = (_ct7_r, _ct7_c) - ct7.r = _ct7_r - ct7.c = _ct7_c - ct7.moduli = list(_ct7_moduli)[:_ct7_m] - ct7.moduli_array = jnp.array( - ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) - ) - _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct8_arg_m_in = _ct8_arg_data.shape[-1] - _ct8_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_arg_m_in - ) - _ct8_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_arg_r) - ) - _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct8_arg_moduli, (int, np.integer)): - _ct8_arg_moduli = [int(_ct8_arg_moduli)] - ct8_arg = Polynomial( - { - "batch": _ct8_arg_data.shape[0], - "num_elements": _ct8_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_arg_m, - "precision": 32, - "degree_layout": (_ct8_arg_r, _ct8_arg_c), - }, - {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, - ) - ct8_arg.polynomial = _ct8_arg_data.reshape( - _ct8_arg_data.shape[0], - _ct8_arg_data.shape[1], - _ct8_arg_r, - _ct8_arg_c, - _ct8_arg_m_in, - )[..., :_ct8_arg_m].copy() - ct8_arg.batch = ct8_arg.polynomial.shape[0] - ct8_arg.num_elements = ct8_arg.polynomial.shape[1] - ct8_arg.num_moduli = _ct8_arg_m - ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) - ct8_arg.r = _ct8_arg_r - ct8_arg.c = _ct8_arg_c - ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] - ct8_arg.moduli_array = jnp.array( - ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) - ) - ct8_pt_ntt = ( - pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] - .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct8_ptct = v0.ptct_mul[v0.max_level] - ct8_ptct.set_plaintext(ct8_pt_ntt) - ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) - _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw - _ct8_m_in = _ct8_data.shape[-1] - _ct8_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_m_in - ) - _ct8_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_r) - ) - _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) - if isinstance(_ct8_moduli, (int, np.integer)): - _ct8_moduli = [int(_ct8_moduli)] - ct8 = Polynomial( - { - "batch": _ct8_data.shape[0], - "num_elements": _ct8_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_m, - "precision": 32, - "degree_layout": (_ct8_r, _ct8_c), - }, - {"moduli": list(_ct8_moduli)[:_ct8_m]}, - ) - ct8.polynomial = _ct8_data.reshape( - _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in - )[..., :_ct8_m].copy() - ct8.batch = ct8.polynomial.shape[0] - ct8.num_elements = ct8.polynomial.shape[1] - ct8.num_moduli = _ct8_m - ct8.degree_layout = (_ct8_r, _ct8_c) - ct8.r = _ct8_r - ct8.c = _ct8_c - ct8.moduli = list(_ct8_moduli)[:_ct8_m] - ct8.moduli_array = jnp.array( - ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) - ) - _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - _ct9_m_in = _ct9_data.shape[-1] - _ct9_m = _ct9_m_in - _ct9_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_r) - ) - _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) - if isinstance(_ct9_moduli, (int, np.integer)): - _ct9_moduli = [int(_ct9_moduli)] - ct9 = Polynomial( - { - "batch": _ct9_data.shape[0], - "num_elements": _ct9_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_m, - "precision": 32, - "degree_layout": (_ct9_r, _ct9_c), - }, - {"moduli": list(_ct9_moduli)[:_ct9_m]}, - ) - ct9.polynomial = _ct9_data.reshape( - _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in - )[..., :_ct9_m].copy() - ct9.batch = ct9.polynomial.shape[0] - ct9.num_elements = ct9.polynomial.shape[1] - ct9.num_moduli = _ct9_m - ct9.degree_layout = (_ct9_r, _ct9_c) - ct9.r = _ct9_r - ct9.c = _ct9_c - ct9.moduli = list(_ct9_moduli)[:_ct9_m] - ct9.moduli_array = jnp.array( - ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) - ) - _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 - _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] - _ct9_rhs_m = _ct9_rhs_m_in - _ct9_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_rhs_r) - ) - _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) - if isinstance(_ct9_rhs_moduli, (int, np.integer)): - _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] - ct9_rhs = Polynomial( - { - "batch": _ct9_rhs_data.shape[0], - "num_elements": _ct9_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_rhs_m, - "precision": 32, - "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), - }, - {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, - ) - ct9_rhs.polynomial = _ct9_rhs_data.reshape( - _ct9_rhs_data.shape[0], - _ct9_rhs_data.shape[1], - _ct9_rhs_r, - _ct9_rhs_c, - _ct9_rhs_m_in, - )[..., :_ct9_rhs_m].copy() - ct9_rhs.batch = ct9_rhs.polynomial.shape[0] - ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] - ct9_rhs.num_moduli = _ct9_rhs_m - ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) - ct9_rhs.r = _ct9_rhs_r - ct9_rhs.c = _ct9_rhs_c - ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] - ct9_rhs.moduli_array = jnp.array( - ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) - ) - ct9.add(ct9_rhs) - _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) - ct9.polynomial = jnp.where( - ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial - ) - _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - _ct10_m_in = _ct10_data.shape[-1] - _ct10_m = _ct10_m_in - _ct10_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_r) - ) - _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) - if isinstance(_ct10_moduli, (int, np.integer)): - _ct10_moduli = [int(_ct10_moduli)] - ct10 = Polynomial( - { - "batch": _ct10_data.shape[0], - "num_elements": _ct10_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_m, - "precision": 32, - "degree_layout": (_ct10_r, _ct10_c), - }, - {"moduli": list(_ct10_moduli)[:_ct10_m]}, - ) - ct10.polynomial = _ct10_data.reshape( - _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in - )[..., :_ct10_m].copy() - ct10.batch = ct10.polynomial.shape[0] - ct10.num_elements = ct10.polynomial.shape[1] - ct10.num_moduli = _ct10_m - ct10.degree_layout = (_ct10_r, _ct10_c) - ct10.r = _ct10_r - ct10.c = _ct10_c - ct10.moduli = list(_ct10_moduli)[:_ct10_m] - ct10.moduli_array = jnp.array( - ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) - ) - _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] - _ct10_rhs_m = _ct10_rhs_m_in - _ct10_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_rhs_r) - ) - _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) - if isinstance(_ct10_rhs_moduli, (int, np.integer)): - _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] - ct10_rhs = Polynomial( - { - "batch": _ct10_rhs_data.shape[0], - "num_elements": _ct10_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_rhs_m, - "precision": 32, - "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), - }, - {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, - ) - ct10_rhs.polynomial = _ct10_rhs_data.reshape( - _ct10_rhs_data.shape[0], - _ct10_rhs_data.shape[1], - _ct10_rhs_r, - _ct10_rhs_c, - _ct10_rhs_m_in, - )[..., :_ct10_rhs_m].copy() - ct10_rhs.batch = ct10_rhs.polynomial.shape[0] - ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] - ct10_rhs.num_moduli = _ct10_rhs_m - ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) - ct10_rhs.r = _ct10_rhs_r - ct10_rhs.c = _ct10_rhs_c - ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] - ct10_rhs.moduli_array = jnp.array( - ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) - ) - ct10.add(ct10_rhs) - _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) - ct10.polynomial = jnp.where( - ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial - ) - _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 - _ct11_arg_m_in = _ct11_arg_data.shape[-1] - _ct11_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_arg_m_in - ) - _ct11_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_arg_r) - ) - _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) - if isinstance(_ct11_arg_moduli, (int, np.integer)): - _ct11_arg_moduli = [int(_ct11_arg_moduli)] - ct11_arg = Polynomial( - { - "batch": _ct11_arg_data.shape[0], - "num_elements": _ct11_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_arg_m, - "precision": 32, - "degree_layout": (_ct11_arg_r, _ct11_arg_c), - }, - {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, - ) - ct11_arg.polynomial = _ct11_arg_data.reshape( - _ct11_arg_data.shape[0], - _ct11_arg_data.shape[1], - _ct11_arg_r, - _ct11_arg_c, - _ct11_arg_m_in, - )[..., :_ct11_arg_m].copy() - ct11_arg.batch = ct11_arg.polynomial.shape[0] - ct11_arg.num_elements = ct11_arg.polynomial.shape[1] - ct11_arg.num_moduli = _ct11_arg_m - ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) - ct11_arg.r = _ct11_arg_r - ct11_arg.c = _ct11_arg_c - ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] - ct11_arg.moduli_array = jnp.array( - ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) - ) - ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) - _ct11_data = ( - ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw - ) - _ct11_m_in = _ct11_data.shape[-1] - _ct11_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_m_in - ) - _ct11_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_r) - ) - _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) - if isinstance(_ct11_moduli, (int, np.integer)): - _ct11_moduli = [int(_ct11_moduli)] - ct11 = Polynomial( - { - "batch": _ct11_data.shape[0], - "num_elements": _ct11_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_m, - "precision": 32, - "degree_layout": (_ct11_r, _ct11_c), - }, - {"moduli": list(_ct11_moduli)[:_ct11_m]}, - ) - ct11.polynomial = _ct11_data.reshape( - _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in - )[..., :_ct11_m].copy() - ct11.batch = ct11.polynomial.shape[0] - ct11.num_elements = ct11.polynomial.shape[1] - ct11.num_moduli = _ct11_m - ct11.degree_layout = (_ct11_r, _ct11_c) - ct11.r = _ct11_r - ct11.c = _ct11_c - ct11.moduli = list(_ct11_moduli)[:_ct11_m] - ct11.moduli_array = jnp.array( - ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) - ) - _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct12_arg_m_in = _ct12_arg_data.shape[-1] - _ct12_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_arg_m_in - ) - _ct12_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_arg_r) - ) - _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct12_arg_moduli, (int, np.integer)): - _ct12_arg_moduli = [int(_ct12_arg_moduli)] - ct12_arg = Polynomial( - { - "batch": _ct12_arg_data.shape[0], - "num_elements": _ct12_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_arg_m, - "precision": 32, - "degree_layout": (_ct12_arg_r, _ct12_arg_c), - }, - {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, - ) - ct12_arg.polynomial = _ct12_arg_data.reshape( - _ct12_arg_data.shape[0], - _ct12_arg_data.shape[1], - _ct12_arg_r, - _ct12_arg_c, - _ct12_arg_m_in, - )[..., :_ct12_arg_m].copy() - ct12_arg.batch = ct12_arg.polynomial.shape[0] - ct12_arg.num_elements = ct12_arg.polynomial.shape[1] - ct12_arg.num_moduli = _ct12_arg_m - ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) - ct12_arg.r = _ct12_arg_r - ct12_arg.c = _ct12_arg_c - ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] - ct12_arg.moduli_array = jnp.array( - ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) - ) - ct12_pt_ntt = ( - pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] - .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct12_ptct = v0.ptct_mul[v0.max_level] - ct12_ptct.set_plaintext(ct12_pt_ntt) - ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) - _ct12_data = ( - ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw - ) - _ct12_m_in = _ct12_data.shape[-1] - _ct12_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_m_in - ) - _ct12_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_r) - ) - _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) - if isinstance(_ct12_moduli, (int, np.integer)): - _ct12_moduli = [int(_ct12_moduli)] - ct12 = Polynomial( - { - "batch": _ct12_data.shape[0], - "num_elements": _ct12_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_m, - "precision": 32, - "degree_layout": (_ct12_r, _ct12_c), - }, - {"moduli": list(_ct12_moduli)[:_ct12_m]}, - ) - ct12.polynomial = _ct12_data.reshape( - _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in - )[..., :_ct12_m].copy() - ct12.batch = ct12.polynomial.shape[0] - ct12.num_elements = ct12.polynomial.shape[1] - ct12.num_moduli = _ct12_m - ct12.degree_layout = (_ct12_r, _ct12_c) - ct12.r = _ct12_r - ct12.c = _ct12_c - ct12.moduli = list(_ct12_moduli)[:_ct12_m] - ct12.moduli_array = jnp.array( - ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) - ) - _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct13_arg_m_in = _ct13_arg_data.shape[-1] - _ct13_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_arg_m_in - ) - _ct13_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_arg_r) - ) - _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct13_arg_moduli, (int, np.integer)): - _ct13_arg_moduli = [int(_ct13_arg_moduli)] - ct13_arg = Polynomial( - { - "batch": _ct13_arg_data.shape[0], - "num_elements": _ct13_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_arg_m, - "precision": 32, - "degree_layout": (_ct13_arg_r, _ct13_arg_c), - }, - {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, - ) - ct13_arg.polynomial = _ct13_arg_data.reshape( - _ct13_arg_data.shape[0], - _ct13_arg_data.shape[1], - _ct13_arg_r, - _ct13_arg_c, - _ct13_arg_m_in, - )[..., :_ct13_arg_m].copy() - ct13_arg.batch = ct13_arg.polynomial.shape[0] - ct13_arg.num_elements = ct13_arg.polynomial.shape[1] - ct13_arg.num_moduli = _ct13_arg_m - ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) - ct13_arg.r = _ct13_arg_r - ct13_arg.c = _ct13_arg_c - ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] - ct13_arg.moduli_array = jnp.array( - ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) - ) - ct13_pt_ntt = ( - pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] - .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct13_ptct = v0.ptct_mul[v0.max_level] - ct13_ptct.set_plaintext(ct13_pt_ntt) - ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) - _ct13_data = ( - ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw - ) - _ct13_m_in = _ct13_data.shape[-1] - _ct13_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_m_in - ) - _ct13_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_r) - ) - _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) - if isinstance(_ct13_moduli, (int, np.integer)): - _ct13_moduli = [int(_ct13_moduli)] - ct13 = Polynomial( - { - "batch": _ct13_data.shape[0], - "num_elements": _ct13_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_m, - "precision": 32, - "degree_layout": (_ct13_r, _ct13_c), - }, - {"moduli": list(_ct13_moduli)[:_ct13_m]}, - ) - ct13.polynomial = _ct13_data.reshape( - _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in - )[..., :_ct13_m].copy() - ct13.batch = ct13.polynomial.shape[0] - ct13.num_elements = ct13.polynomial.shape[1] - ct13.num_moduli = _ct13_m - ct13.degree_layout = (_ct13_r, _ct13_c) - ct13.r = _ct13_r - ct13.c = _ct13_c - ct13.moduli = list(_ct13_moduli)[:_ct13_m] - ct13.moduli_array = jnp.array( - ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) - ) - _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - _ct14_m_in = _ct14_data.shape[-1] - _ct14_m = _ct14_m_in - _ct14_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_r) - ) - _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) - if isinstance(_ct14_moduli, (int, np.integer)): - _ct14_moduli = [int(_ct14_moduli)] - ct14 = Polynomial( - { - "batch": _ct14_data.shape[0], - "num_elements": _ct14_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_m, - "precision": 32, - "degree_layout": (_ct14_r, _ct14_c), - }, - {"moduli": list(_ct14_moduli)[:_ct14_m]}, - ) - ct14.polynomial = _ct14_data.reshape( - _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in - )[..., :_ct14_m].copy() - ct14.batch = ct14.polynomial.shape[0] - ct14.num_elements = ct14.polynomial.shape[1] - ct14.num_moduli = _ct14_m - ct14.degree_layout = (_ct14_r, _ct14_c) - ct14.r = _ct14_r - ct14.c = _ct14_c - ct14.moduli = list(_ct14_moduli)[:_ct14_m] - ct14.moduli_array = jnp.array( - ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) - ) - _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] - _ct14_rhs_m = _ct14_rhs_m_in - _ct14_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_rhs_r) - ) - _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) - if isinstance(_ct14_rhs_moduli, (int, np.integer)): - _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] - ct14_rhs = Polynomial( - { - "batch": _ct14_rhs_data.shape[0], - "num_elements": _ct14_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_rhs_m, - "precision": 32, - "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), - }, - {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, - ) - ct14_rhs.polynomial = _ct14_rhs_data.reshape( - _ct14_rhs_data.shape[0], - _ct14_rhs_data.shape[1], - _ct14_rhs_r, - _ct14_rhs_c, - _ct14_rhs_m_in, - )[..., :_ct14_rhs_m].copy() - ct14_rhs.batch = ct14_rhs.polynomial.shape[0] - ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] - ct14_rhs.num_moduli = _ct14_rhs_m - ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) - ct14_rhs.r = _ct14_rhs_r - ct14_rhs.c = _ct14_rhs_c - ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] - ct14_rhs.moduli_array = jnp.array( - ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) - ) - ct14.add(ct14_rhs) - _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) - ct14.polynomial = jnp.where( - ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial - ) - _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 - _ct15_arg_m_in = _ct15_arg_data.shape[-1] - _ct15_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_arg_m_in - ) - _ct15_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_arg_r) - ) - _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) - if isinstance(_ct15_arg_moduli, (int, np.integer)): - _ct15_arg_moduli = [int(_ct15_arg_moduli)] - ct15_arg = Polynomial( - { - "batch": _ct15_arg_data.shape[0], - "num_elements": _ct15_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_arg_m, - "precision": 32, - "degree_layout": (_ct15_arg_r, _ct15_arg_c), - }, - {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, - ) - ct15_arg.polynomial = _ct15_arg_data.reshape( - _ct15_arg_data.shape[0], - _ct15_arg_data.shape[1], - _ct15_arg_r, - _ct15_arg_c, - _ct15_arg_m_in, - )[..., :_ct15_arg_m].copy() - ct15_arg.batch = ct15_arg.polynomial.shape[0] - ct15_arg.num_elements = ct15_arg.polynomial.shape[1] - ct15_arg.num_moduli = _ct15_arg_m - ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) - ct15_arg.r = _ct15_arg_r - ct15_arg.c = _ct15_arg_c - ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] - ct15_arg.moduli_array = jnp.array( - ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) - ) - ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) - _ct15_data = ( - ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw - ) - _ct15_m_in = _ct15_data.shape[-1] - _ct15_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_m_in - ) - _ct15_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_r) - ) - _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) - if isinstance(_ct15_moduli, (int, np.integer)): - _ct15_moduli = [int(_ct15_moduli)] - ct15 = Polynomial( - { - "batch": _ct15_data.shape[0], - "num_elements": _ct15_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_m, - "precision": 32, - "degree_layout": (_ct15_r, _ct15_c), - }, - {"moduli": list(_ct15_moduli)[:_ct15_m]}, - ) - ct15.polynomial = _ct15_data.reshape( - _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in - )[..., :_ct15_m].copy() - ct15.batch = ct15.polynomial.shape[0] - ct15.num_elements = ct15.polynomial.shape[1] - ct15.num_moduli = _ct15_m - ct15.degree_layout = (_ct15_r, _ct15_c) - ct15.r = _ct15_r - ct15.c = _ct15_c - ct15.moduli = list(_ct15_moduli)[:_ct15_m] - ct15.moduli_array = jnp.array( - ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) - ) - _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - _ct16_m_in = _ct16_data.shape[-1] - _ct16_m = _ct16_m_in - _ct16_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_r) - ) - _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) - if isinstance(_ct16_moduli, (int, np.integer)): - _ct16_moduli = [int(_ct16_moduli)] - ct16 = Polynomial( - { - "batch": _ct16_data.shape[0], - "num_elements": _ct16_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_m, - "precision": 32, - "degree_layout": (_ct16_r, _ct16_c), - }, - {"moduli": list(_ct16_moduli)[:_ct16_m]}, - ) - ct16.polynomial = _ct16_data.reshape( - _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in - )[..., :_ct16_m].copy() - ct16.batch = ct16.polynomial.shape[0] - ct16.num_elements = ct16.polynomial.shape[1] - ct16.num_moduli = _ct16_m - ct16.degree_layout = (_ct16_r, _ct16_c) - ct16.r = _ct16_r - ct16.c = _ct16_c - ct16.moduli = list(_ct16_moduli)[:_ct16_m] - ct16.moduli_array = jnp.array( - ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) - ) - _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 - _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] - _ct16_rhs_m = _ct16_rhs_m_in - _ct16_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_rhs_r) - ) - _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) - if isinstance(_ct16_rhs_moduli, (int, np.integer)): - _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] - ct16_rhs = Polynomial( - { - "batch": _ct16_rhs_data.shape[0], - "num_elements": _ct16_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_rhs_m, - "precision": 32, - "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), - }, - {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, - ) - ct16_rhs.polynomial = _ct16_rhs_data.reshape( - _ct16_rhs_data.shape[0], - _ct16_rhs_data.shape[1], - _ct16_rhs_r, - _ct16_rhs_c, - _ct16_rhs_m_in, - )[..., :_ct16_rhs_m].copy() - ct16_rhs.batch = ct16_rhs.polynomial.shape[0] - ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] - ct16_rhs.num_moduli = _ct16_rhs_m - ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) - ct16_rhs.r = _ct16_rhs_r - ct16_rhs.c = _ct16_rhs_c - ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] - ct16_rhs.moduli_array = jnp.array( - ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) - ) - ct16.add(ct16_rhs) - _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) - ct16.polynomial = jnp.where( - ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial - ) - _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - _ct17_m_in = _ct17_data.shape[-1] - _ct17_m = _ct17_m_in - _ct17_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_r) - ) - _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) - if isinstance(_ct17_moduli, (int, np.integer)): - _ct17_moduli = [int(_ct17_moduli)] - ct17 = Polynomial( - { - "batch": _ct17_data.shape[0], - "num_elements": _ct17_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_m, - "precision": 32, - "degree_layout": (_ct17_r, _ct17_c), - }, - {"moduli": list(_ct17_moduli)[:_ct17_m]}, - ) - ct17.polynomial = _ct17_data.reshape( - _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in - )[..., :_ct17_m].copy() - ct17.batch = ct17.polynomial.shape[0] - ct17.num_elements = ct17.polynomial.shape[1] - ct17.num_moduli = _ct17_m - ct17.degree_layout = (_ct17_r, _ct17_c) - ct17.r = _ct17_r - ct17.c = _ct17_c - ct17.moduli = list(_ct17_moduli)[:_ct17_m] - ct17.moduli_array = jnp.array( - ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) - ) - _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] - _ct17_rhs_m = _ct17_rhs_m_in - _ct17_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_rhs_r) - ) - _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) - if isinstance(_ct17_rhs_moduli, (int, np.integer)): - _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] - ct17_rhs = Polynomial( - { - "batch": _ct17_rhs_data.shape[0], - "num_elements": _ct17_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_rhs_m, - "precision": 32, - "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), - }, - {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, - ) - ct17_rhs.polynomial = _ct17_rhs_data.reshape( - _ct17_rhs_data.shape[0], - _ct17_rhs_data.shape[1], - _ct17_rhs_r, - _ct17_rhs_c, - _ct17_rhs_m_in, - )[..., :_ct17_rhs_m].copy() - ct17_rhs.batch = ct17_rhs.polynomial.shape[0] - ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] - ct17_rhs.num_moduli = _ct17_rhs_m - ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) - ct17_rhs.r = _ct17_rhs_r - ct17_rhs.c = _ct17_rhs_c - ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] - ct17_rhs.moduli_array = jnp.array( - ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) - ) - ct17.add(ct17_rhs) - _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) - ct17.polynomial = jnp.where( - ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial - ) - _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 - _ct18_m_in = _ct18_data.shape[-1] - _ct18_m = _ct18_m_in - _ct18_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_r) - ) - _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) - if isinstance(_ct18_moduli, (int, np.integer)): - _ct18_moduli = [int(_ct18_moduli)] - ct18 = Polynomial( - { - "batch": _ct18_data.shape[0], - "num_elements": _ct18_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_m, - "precision": 32, - "degree_layout": (_ct18_r, _ct18_c), - }, - {"moduli": list(_ct18_moduli)[:_ct18_m]}, - ) - ct18.polynomial = _ct18_data.reshape( - _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in - )[..., :_ct18_m].copy() - ct18.batch = ct18.polynomial.shape[0] - ct18.num_elements = ct18.polynomial.shape[1] - ct18.num_moduli = _ct18_m - ct18.degree_layout = (_ct18_r, _ct18_c) - ct18.r = _ct18_r - ct18.c = _ct18_c - ct18.moduli = list(_ct18_moduli)[:_ct18_m] - ct18.moduli_array = jnp.array( - ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) - ) - _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 - _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] - _ct18_rhs_m = _ct18_rhs_m_in - _ct18_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_rhs_r) - ) - _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) - if isinstance(_ct18_rhs_moduli, (int, np.integer)): - _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] - ct18_rhs = Polynomial( - { - "batch": _ct18_rhs_data.shape[0], - "num_elements": _ct18_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_rhs_m, - "precision": 32, - "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), - }, - {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, - ) - ct18_rhs.polynomial = _ct18_rhs_data.reshape( - _ct18_rhs_data.shape[0], - _ct18_rhs_data.shape[1], - _ct18_rhs_r, - _ct18_rhs_c, - _ct18_rhs_m_in, - )[..., :_ct18_rhs_m].copy() - ct18_rhs.batch = ct18_rhs.polynomial.shape[0] - ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] - ct18_rhs.num_moduli = _ct18_rhs_m - ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) - ct18_rhs.r = _ct18_rhs_r - ct18_rhs.c = _ct18_rhs_c - ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] - ct18_rhs.moduli_array = jnp.array( - ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) - ) - ct18.add(ct18_rhs) - _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) - ct18.polynomial = jnp.where( - ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial - ) - _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 - _ct19_m_in = _ct19_data.shape[-1] - _ct19_m = _ct19_m_in - _ct19_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_r) - ) - _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) - if isinstance(_ct19_moduli, (int, np.integer)): - _ct19_moduli = [int(_ct19_moduli)] - ct19 = Polynomial( - { - "batch": _ct19_data.shape[0], - "num_elements": _ct19_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_m, - "precision": 32, - "degree_layout": (_ct19_r, _ct19_c), - }, - {"moduli": list(_ct19_moduli)[:_ct19_m]}, - ) - ct19.polynomial = _ct19_data.reshape( - _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in - )[..., :_ct19_m].copy() - ct19.batch = ct19.polynomial.shape[0] - ct19.num_elements = ct19.polynomial.shape[1] - ct19.num_moduli = _ct19_m - ct19.degree_layout = (_ct19_r, _ct19_c) - ct19.r = _ct19_r - ct19.c = _ct19_c - ct19.moduli = list(_ct19_moduli)[:_ct19_m] - ct19.moduli_array = jnp.array( - ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) - ) - _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 - _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] - _ct19_rhs_m = _ct19_rhs_m_in - _ct19_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_rhs_r) - ) - _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) - if isinstance(_ct19_rhs_moduli, (int, np.integer)): - _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] - ct19_rhs = Polynomial( - { - "batch": _ct19_rhs_data.shape[0], - "num_elements": _ct19_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_rhs_m, - "precision": 32, - "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), - }, - {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, - ) - ct19_rhs.polynomial = _ct19_rhs_data.reshape( - _ct19_rhs_data.shape[0], - _ct19_rhs_data.shape[1], - _ct19_rhs_r, - _ct19_rhs_c, - _ct19_rhs_m_in, - )[..., :_ct19_rhs_m].copy() - ct19_rhs.batch = ct19_rhs.polynomial.shape[0] - ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] - ct19_rhs.num_moduli = _ct19_rhs_m - ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) - ct19_rhs.r = _ct19_rhs_r - ct19_rhs.c = _ct19_rhs_c - ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] - ct19_rhs.moduli_array = jnp.array( - ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) - ) - ct19.add(ct19_rhs) - _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) - ct19.polynomial = jnp.where( - ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial - ) - v16 = [None] * 1 - _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 - _ct20_arg_m_in = _ct20_arg_data.shape[-1] - _ct20_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct20_arg_m_in - ) - _ct20_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_arg_r) - ) - _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) - if isinstance(_ct20_arg_moduli, (int, np.integer)): - _ct20_arg_moduli = [int(_ct20_arg_moduli)] - ct20_arg = Polynomial( - { - "batch": _ct20_arg_data.shape[0], - "num_elements": _ct20_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_arg_m, - "precision": 32, - "degree_layout": (_ct20_arg_r, _ct20_arg_c), - }, - {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, - ) - ct20_arg.polynomial = _ct20_arg_data.reshape( - _ct20_arg_data.shape[0], - _ct20_arg_data.shape[1], - _ct20_arg_r, - _ct20_arg_c, - _ct20_arg_m_in, - )[..., :_ct20_arg_m].copy() - ct20_arg.batch = ct20_arg.polynomial.shape[0] - ct20_arg.num_elements = ct20_arg.polynomial.shape[1] - ct20_arg.num_moduli = _ct20_arg_m - ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) - ct20_arg.r = _ct20_arg_r - ct20_arg.c = _ct20_arg_c - ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] - ct20_arg.moduli_array = jnp.array( - ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) - ) - ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) - _ct20_data = ( - ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw - ) - _ct20_m_in = _ct20_data.shape[-1] - _ct20_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct20_m_in - ) - _ct20_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_r) - ) - _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) - if isinstance(_ct20_moduli, (int, np.integer)): - _ct20_moduli = [int(_ct20_moduli)] - ct20 = Polynomial( - { - "batch": _ct20_data.shape[0], - "num_elements": _ct20_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_m, - "precision": 32, - "degree_layout": (_ct20_r, _ct20_c), - }, - {"moduli": list(_ct20_moduli)[:_ct20_m]}, - ) - ct20.polynomial = _ct20_data.reshape( - _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in - )[..., :_ct20_m].copy() - ct20.batch = ct20.polynomial.shape[0] - ct20.num_elements = ct20.polynomial.shape[1] - ct20.num_moduli = _ct20_m - ct20.degree_layout = (_ct20_r, _ct20_c) - ct20.r = _ct20_r - ct20.c = _ct20_c - ct20.moduli = list(_ct20_moduli)[:_ct20_m] - ct20.moduli_array = jnp.array( - ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) - ) - v16[0] = ct20 - v17 = v16 - return v17 - - -def matvec_shift( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_shift__preprocessing(v0, v1) - v11 = matvec_shift__preprocessed(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) - return v11 - - -def matvec_shift__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw - _ct_m_in = _ct_data.shape[-1] - _ct_m = _ct_m_in - _ct_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct_r) - ) - _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) - if isinstance(_ct_moduli, (int, np.integer)): - _ct_moduli = [int(_ct_moduli)] - ct = Polynomial( - { - "batch": _ct_data.shape[0], - "num_elements": _ct_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct_m, - "precision": 32, - "degree_layout": (_ct_r, _ct_c), - }, - {"moduli": list(_ct_moduli)[:_ct_m]}, - ) - ct.polynomial = _ct_data.reshape( - _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in - )[..., :_ct_m].copy() - ct.batch = ct.polynomial.shape[0] - ct.num_elements = ct.polynomial.shape[1] - ct.num_moduli = _ct_m - ct.degree_layout = (_ct_r, _ct_c) - ct.r = _ct_r - ct.c = _ct_c - ct.moduli = list(_ct_moduli)[:_ct_m] - ct.moduli_array = jnp.array( - ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) - ) - v16 = [ct] - return v16 - - -def matvec_shift__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 0 - v8 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - _num_moduli = ct.polynomial.shape[-1] - _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": ct.polynomial.shape[0], - "num_elements": ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - ct.polynomial.reshape( - ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli - ) - ) - pt = v0.decrypt(_ct_for_dec) - v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v10 = v8.copy() - for v11 in range(0, 8): - v13 = int(v11) - v14 = v9[0, v13] - v10[v13] = v14 - return v10 - - -def matvec_random__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> ( - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -): - v2 = np.array( - [ - 8.116263e-01, - 1.906357e00, - 1.490788e00, - 1.237451e00, - 3.964354e-01, - 3.963896e-01, - 2.103589e-01, - 1.745735e00, - 1.242118e00, - 1.445338e00, - 1.391105e-01, - 1.942829e00, - 1.681641e00, - 5.034443e-01, - 4.454674e-01, - 4.484686e-01, - 6.780602e-01, - 1.097037e00, - 9.206955e-01, - 6.533354e-01, - 1.262521e00, - 3.650383e-01, - 6.550748e-01, - 7.960875e-01, - 9.665329e-01, - 1.591834e00, - 4.793802e-01, - 1.077045e00, - 1.225588e00, - 1.882558e-01, - 1.254335e00, - 4.239958e-01, - 2.235980e-01, - 1.902883e00, - 1.934701e00, - 1.635955e00, - 6.787661e-01, - 2.855770e-01, - 1.400043e00, - 9.362897e-01, - 3.318726e-01, - 1.040836e00, - 1.653382e-01, - 1.827709e00, - 5.916820e-01, - 1.358792e00, - 6.922510e-01, - 1.088129e00, - 1.138749e00, - 4.512235e-01, - 1.942211e00, - 1.572752e00, - 1.885048e00, - 1.800172e00, - 1.236010e00, - 1.851561e00, - 2.681358e-01, - 4.723674e-01, - 1.859318e-01, - 7.181276e-01, - 8.384869e-01, - 6.155632e-01, - 1.674601e00, - 7.778313e-01, - ], - dtype=np.float32, - ).reshape(8, 8) - v3 = _assign_layout_15335824159471298539(v2) - v4 = v3[3 : 3 + 1, 0 : 0 + 5] - v5 = v3[3 : 3 + 1, 5 : 5 + 3] - v6 = np.zeros( - ( - 1, - 8, - ), - dtype=np.float32, - ) - v7 = v6.copy() - v7[0 : 0 + 1, 3 : 3 + 5] = v4 - v8 = v7.copy() - v8[0 : 0 + 1, 0 : 0 + 3] = v5 - v9 = v3[4 : 4 + 1, 0 : 0 + 5] - v10 = v3[4 : 4 + 1, 5 : 5 + 3] - v11 = v6.copy() - v11[0 : 0 + 1, 3 : 3 + 5] = v9 - v12 = v11.copy() - v12[0 : 0 + 1, 0 : 0 + 3] = v10 - v13 = v3[5 : 5 + 1, 0 : 0 + 5] - v14 = v3[5 : 5 + 1, 5 : 5 + 3] - v15 = v6.copy() - v15[0 : 0 + 1, 3 : 3 + 5] = v13 - v16 = v15.copy() - v16[0 : 0 + 1, 0 : 0 + 3] = v14 - v17 = v3[6 : 6 + 1, 0 : 0 + 2] - v18 = v3[6 : 6 + 1, 2 : 2 + 6] - v19 = v6.copy() - v19[0 : 0 + 1, 6 : 6 + 2] = v17 - v20 = v19.copy() - v20[0 : 0 + 1, 0 : 0 + 6] = v18 - v21 = v3[7 : 7 + 1, 0 : 0 + 2] - v22 = v3[7 : 7 + 1, 2 : 2 + 6] - v23 = v6.copy() - v23[0 : 0 + 1, 6 : 6 + 2] = v21 - v24 = v23.copy() - v24[0 : 0 + 1, 0 : 0 + 6] = v22 - v25 = v3[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v25) - v26 = v3[1 : 1 + 1, 0 : 0 + 8].reshape(8) - pt1 = v0.encode(v26) - v27 = v3[2 : 2 + 1, 0 : 0 + 8].reshape(8) - pt2 = v0.encode(v27) - v28 = v8[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt3 = v0.encode(v28) - v29 = v12[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt4 = v0.encode(v29) - v30 = v16[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt5 = v0.encode(v30) - v31 = v20[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt6 = v0.encode(v31) - v32 = v24[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt7 = v0.encode(v32) - v33 = [pt] - v34 = [pt1] - v35 = [pt2] - v36 = [pt3] - v37 = [pt4] - v38 = [pt5] - v39 = [pt6] - v40 = [pt7] - return (v33, v34, v35, v36, v37, v38, v39, v40) - - -def matvec_random__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, - v5: np.ndarray, - v6: np.ndarray, - v7: np.ndarray, - v8: np.ndarray, - v9: np.ndarray, - v10: np.ndarray, -) -> np.ndarray: - v11 = 1 - v12 = 2 - v13 = 3 - v14 = 6 - v15 = 0 - pt = v3[0] - pt1 = v4[0] - pt2 = v5[0] - pt3 = v6[0] - pt4 = v7[0] - pt5 = v8[0] - pt6 = v9[0] - pt7 = v10[0] - ct = v2[0] - _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct1_arg_m_in = _ct1_arg_data.shape[-1] - _ct1_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_arg_m_in - ) - _ct1_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_arg_r) - ) - _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct1_arg_moduli, (int, np.integer)): - _ct1_arg_moduli = [int(_ct1_arg_moduli)] - ct1_arg = Polynomial( - { - "batch": _ct1_arg_data.shape[0], - "num_elements": _ct1_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_arg_m, - "precision": 32, - "degree_layout": (_ct1_arg_r, _ct1_arg_c), - }, - {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, - ) - ct1_arg.polynomial = _ct1_arg_data.reshape( - _ct1_arg_data.shape[0], - _ct1_arg_data.shape[1], - _ct1_arg_r, - _ct1_arg_c, - _ct1_arg_m_in, - )[..., :_ct1_arg_m].copy() - ct1_arg.batch = ct1_arg.polynomial.shape[0] - ct1_arg.num_elements = ct1_arg.polynomial.shape[1] - ct1_arg.num_moduli = _ct1_arg_m - ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) - ct1_arg.r = _ct1_arg_r - ct1_arg.c = _ct1_arg_c - ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] - ct1_arg.moduli_array = jnp.array( - ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) - ) - ct1_pt_ntt = ( - pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] - .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct1_ptct = v0.ptct_mul[v0.max_level] - ct1_ptct.set_plaintext(ct1_pt_ntt) - ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) - _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw - _ct1_m_in = _ct1_data.shape[-1] - _ct1_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_m_in - ) - _ct1_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_r) - ) - _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) - if isinstance(_ct1_moduli, (int, np.integer)): - _ct1_moduli = [int(_ct1_moduli)] - ct1 = Polynomial( - { - "batch": _ct1_data.shape[0], - "num_elements": _ct1_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_m, - "precision": 32, - "degree_layout": (_ct1_r, _ct1_c), - }, - {"moduli": list(_ct1_moduli)[:_ct1_m]}, - ) - ct1.polynomial = _ct1_data.reshape( - _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in - )[..., :_ct1_m].copy() - ct1.batch = ct1.polynomial.shape[0] - ct1.num_elements = ct1.polynomial.shape[1] - ct1.num_moduli = _ct1_m - ct1.degree_layout = (_ct1_r, _ct1_c) - ct1.r = _ct1_r - ct1.c = _ct1_c - ct1.moduli = list(_ct1_moduli)[:_ct1_m] - ct1.moduli_array = jnp.array( - ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) - ) - _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct2_arg_m_in = _ct2_arg_data.shape[-1] - _ct2_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_arg_m_in - ) - _ct2_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_arg_r) - ) - _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct2_arg_moduli, (int, np.integer)): - _ct2_arg_moduli = [int(_ct2_arg_moduli)] - ct2_arg = Polynomial( - { - "batch": _ct2_arg_data.shape[0], - "num_elements": _ct2_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_arg_m, - "precision": 32, - "degree_layout": (_ct2_arg_r, _ct2_arg_c), - }, - {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, - ) - ct2_arg.polynomial = _ct2_arg_data.reshape( - _ct2_arg_data.shape[0], - _ct2_arg_data.shape[1], - _ct2_arg_r, - _ct2_arg_c, - _ct2_arg_m_in, - )[..., :_ct2_arg_m].copy() - ct2_arg.batch = ct2_arg.polynomial.shape[0] - ct2_arg.num_elements = ct2_arg.polynomial.shape[1] - ct2_arg.num_moduli = _ct2_arg_m - ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) - ct2_arg.r = _ct2_arg_r - ct2_arg.c = _ct2_arg_c - ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] - ct2_arg.moduli_array = jnp.array( - ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) - ) - ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) - _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw - _ct2_m_in = _ct2_data.shape[-1] - _ct2_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_m_in - ) - _ct2_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_r) - ) - _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) - if isinstance(_ct2_moduli, (int, np.integer)): - _ct2_moduli = [int(_ct2_moduli)] - ct2 = Polynomial( - { - "batch": _ct2_data.shape[0], - "num_elements": _ct2_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_m, - "precision": 32, - "degree_layout": (_ct2_r, _ct2_c), - }, - {"moduli": list(_ct2_moduli)[:_ct2_m]}, - ) - ct2.polynomial = _ct2_data.reshape( - _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in - )[..., :_ct2_m].copy() - ct2.batch = ct2.polynomial.shape[0] - ct2.num_elements = ct2.polynomial.shape[1] - ct2.num_moduli = _ct2_m - ct2.degree_layout = (_ct2_r, _ct2_c) - ct2.r = _ct2_r - ct2.c = _ct2_c - ct2.moduli = list(_ct2_moduli)[:_ct2_m] - ct2.moduli_array = jnp.array( - ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) - ) - _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct3_arg_m_in = _ct3_arg_data.shape[-1] - _ct3_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_arg_m_in - ) - _ct3_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_arg_r) - ) - _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct3_arg_moduli, (int, np.integer)): - _ct3_arg_moduli = [int(_ct3_arg_moduli)] - ct3_arg = Polynomial( - { - "batch": _ct3_arg_data.shape[0], - "num_elements": _ct3_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_arg_m, - "precision": 32, - "degree_layout": (_ct3_arg_r, _ct3_arg_c), - }, - {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, - ) - ct3_arg.polynomial = _ct3_arg_data.reshape( - _ct3_arg_data.shape[0], - _ct3_arg_data.shape[1], - _ct3_arg_r, - _ct3_arg_c, - _ct3_arg_m_in, - )[..., :_ct3_arg_m].copy() - ct3_arg.batch = ct3_arg.polynomial.shape[0] - ct3_arg.num_elements = ct3_arg.polynomial.shape[1] - ct3_arg.num_moduli = _ct3_arg_m - ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) - ct3_arg.r = _ct3_arg_r - ct3_arg.c = _ct3_arg_c - ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] - ct3_arg.moduli_array = jnp.array( - ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) - ) - ct3_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] - .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct3_ptct = v0.ptct_mul[v0.max_level] - ct3_ptct.set_plaintext(ct3_pt_ntt) - ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) - _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw - _ct3_m_in = _ct3_data.shape[-1] - _ct3_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_m_in - ) - _ct3_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_r) - ) - _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) - if isinstance(_ct3_moduli, (int, np.integer)): - _ct3_moduli = [int(_ct3_moduli)] - ct3 = Polynomial( - { - "batch": _ct3_data.shape[0], - "num_elements": _ct3_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_m, - "precision": 32, - "degree_layout": (_ct3_r, _ct3_c), - }, - {"moduli": list(_ct3_moduli)[:_ct3_m]}, - ) - ct3.polynomial = _ct3_data.reshape( - _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in - )[..., :_ct3_m].copy() - ct3.batch = ct3.polynomial.shape[0] - ct3.num_elements = ct3.polynomial.shape[1] - ct3.num_moduli = _ct3_m - ct3.degree_layout = (_ct3_r, _ct3_c) - ct3.r = _ct3_r - ct3.c = _ct3_c - ct3.moduli = list(_ct3_moduli)[:_ct3_m] - ct3.moduli_array = jnp.array( - ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) - ) - _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct4_arg_m_in = _ct4_arg_data.shape[-1] - _ct4_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_arg_m_in - ) - _ct4_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_arg_r) - ) - _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct4_arg_moduli, (int, np.integer)): - _ct4_arg_moduli = [int(_ct4_arg_moduli)] - ct4_arg = Polynomial( - { - "batch": _ct4_arg_data.shape[0], - "num_elements": _ct4_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_arg_m, - "precision": 32, - "degree_layout": (_ct4_arg_r, _ct4_arg_c), - }, - {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, - ) - ct4_arg.polynomial = _ct4_arg_data.reshape( - _ct4_arg_data.shape[0], - _ct4_arg_data.shape[1], - _ct4_arg_r, - _ct4_arg_c, - _ct4_arg_m_in, - )[..., :_ct4_arg_m].copy() - ct4_arg.batch = ct4_arg.polynomial.shape[0] - ct4_arg.num_elements = ct4_arg.polynomial.shape[1] - ct4_arg.num_moduli = _ct4_arg_m - ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) - ct4_arg.r = _ct4_arg_r - ct4_arg.c = _ct4_arg_c - ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] - ct4_arg.moduli_array = jnp.array( - ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) - ) - ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) - _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw - _ct4_m_in = _ct4_data.shape[-1] - _ct4_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_m_in - ) - _ct4_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_r) - ) - _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) - if isinstance(_ct4_moduli, (int, np.integer)): - _ct4_moduli = [int(_ct4_moduli)] - ct4 = Polynomial( - { - "batch": _ct4_data.shape[0], - "num_elements": _ct4_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_m, - "precision": 32, - "degree_layout": (_ct4_r, _ct4_c), - }, - {"moduli": list(_ct4_moduli)[:_ct4_m]}, - ) - ct4.polynomial = _ct4_data.reshape( - _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in - )[..., :_ct4_m].copy() - ct4.batch = ct4.polynomial.shape[0] - ct4.num_elements = ct4.polynomial.shape[1] - ct4.num_moduli = _ct4_m - ct4.degree_layout = (_ct4_r, _ct4_c) - ct4.r = _ct4_r - ct4.c = _ct4_c - ct4.moduli = list(_ct4_moduli)[:_ct4_m] - ct4.moduli_array = jnp.array( - ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) - ) - _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct5_arg_m_in = _ct5_arg_data.shape[-1] - _ct5_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_arg_m_in - ) - _ct5_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_arg_r) - ) - _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct5_arg_moduli, (int, np.integer)): - _ct5_arg_moduli = [int(_ct5_arg_moduli)] - ct5_arg = Polynomial( - { - "batch": _ct5_arg_data.shape[0], - "num_elements": _ct5_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_arg_m, - "precision": 32, - "degree_layout": (_ct5_arg_r, _ct5_arg_c), - }, - {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, - ) - ct5_arg.polynomial = _ct5_arg_data.reshape( - _ct5_arg_data.shape[0], - _ct5_arg_data.shape[1], - _ct5_arg_r, - _ct5_arg_c, - _ct5_arg_m_in, - )[..., :_ct5_arg_m].copy() - ct5_arg.batch = ct5_arg.polynomial.shape[0] - ct5_arg.num_elements = ct5_arg.polynomial.shape[1] - ct5_arg.num_moduli = _ct5_arg_m - ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) - ct5_arg.r = _ct5_arg_r - ct5_arg.c = _ct5_arg_c - ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] - ct5_arg.moduli_array = jnp.array( - ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) - ) - ct5_pt_ntt = ( - pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) - _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw - _ct5_m_in = _ct5_data.shape[-1] - _ct5_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_m_in - ) - _ct5_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_r) - ) - _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) - if isinstance(_ct5_moduli, (int, np.integer)): - _ct5_moduli = [int(_ct5_moduli)] - ct5 = Polynomial( - { - "batch": _ct5_data.shape[0], - "num_elements": _ct5_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_m, - "precision": 32, - "degree_layout": (_ct5_r, _ct5_c), - }, - {"moduli": list(_ct5_moduli)[:_ct5_m]}, - ) - ct5.polynomial = _ct5_data.reshape( - _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in - )[..., :_ct5_m].copy() - ct5.batch = ct5.polynomial.shape[0] - ct5.num_elements = ct5.polynomial.shape[1] - ct5.num_moduli = _ct5_m - ct5.degree_layout = (_ct5_r, _ct5_c) - ct5.r = _ct5_r - ct5.c = _ct5_c - ct5.moduli = list(_ct5_moduli)[:_ct5_m] - ct5.moduli_array = jnp.array( - ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) - ) - _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct6_arg_m_in = _ct6_arg_data.shape[-1] - _ct6_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_arg_m_in - ) - _ct6_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_arg_r) - ) - _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct6_arg_moduli, (int, np.integer)): - _ct6_arg_moduli = [int(_ct6_arg_moduli)] - ct6_arg = Polynomial( - { - "batch": _ct6_arg_data.shape[0], - "num_elements": _ct6_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_arg_m, - "precision": 32, - "degree_layout": (_ct6_arg_r, _ct6_arg_c), - }, - {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, - ) - ct6_arg.polynomial = _ct6_arg_data.reshape( - _ct6_arg_data.shape[0], - _ct6_arg_data.shape[1], - _ct6_arg_r, - _ct6_arg_c, - _ct6_arg_m_in, - )[..., :_ct6_arg_m].copy() - ct6_arg.batch = ct6_arg.polynomial.shape[0] - ct6_arg.num_elements = ct6_arg.polynomial.shape[1] - ct6_arg.num_moduli = _ct6_arg_m - ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) - ct6_arg.r = _ct6_arg_r - ct6_arg.c = _ct6_arg_c - ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] - ct6_arg.moduli_array = jnp.array( - ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) - ) - ct6_pt_ntt = ( - pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] - .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct6_ptct = v0.ptct_mul[v0.max_level] - ct6_ptct.set_plaintext(ct6_pt_ntt) - ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) - _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw - _ct6_m_in = _ct6_data.shape[-1] - _ct6_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_m_in - ) - _ct6_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_r) - ) - _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) - if isinstance(_ct6_moduli, (int, np.integer)): - _ct6_moduli = [int(_ct6_moduli)] - ct6 = Polynomial( - { - "batch": _ct6_data.shape[0], - "num_elements": _ct6_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_m, - "precision": 32, - "degree_layout": (_ct6_r, _ct6_c), - }, - {"moduli": list(_ct6_moduli)[:_ct6_m]}, - ) - ct6.polynomial = _ct6_data.reshape( - _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in - )[..., :_ct6_m].copy() - ct6.batch = ct6.polynomial.shape[0] - ct6.num_elements = ct6.polynomial.shape[1] - ct6.num_moduli = _ct6_m - ct6.degree_layout = (_ct6_r, _ct6_c) - ct6.r = _ct6_r - ct6.c = _ct6_c - ct6.moduli = list(_ct6_moduli)[:_ct6_m] - ct6.moduli_array = jnp.array( - ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) - ) - _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct7_arg_m_in = _ct7_arg_data.shape[-1] - _ct7_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_arg_m_in - ) - _ct7_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_arg_r) - ) - _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct7_arg_moduli, (int, np.integer)): - _ct7_arg_moduli = [int(_ct7_arg_moduli)] - ct7_arg = Polynomial( - { - "batch": _ct7_arg_data.shape[0], - "num_elements": _ct7_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_arg_m, - "precision": 32, - "degree_layout": (_ct7_arg_r, _ct7_arg_c), - }, - {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, - ) - ct7_arg.polynomial = _ct7_arg_data.reshape( - _ct7_arg_data.shape[0], - _ct7_arg_data.shape[1], - _ct7_arg_r, - _ct7_arg_c, - _ct7_arg_m_in, - )[..., :_ct7_arg_m].copy() - ct7_arg.batch = ct7_arg.polynomial.shape[0] - ct7_arg.num_elements = ct7_arg.polynomial.shape[1] - ct7_arg.num_moduli = _ct7_arg_m - ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) - ct7_arg.r = _ct7_arg_r - ct7_arg.c = _ct7_arg_c - ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] - ct7_arg.moduli_array = jnp.array( - ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) - ) - ct7_pt_ntt = ( - pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] - .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct7_ptct = v0.ptct_mul[v0.max_level] - ct7_ptct.set_plaintext(ct7_pt_ntt) - ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) - _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw - _ct7_m_in = _ct7_data.shape[-1] - _ct7_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_m_in - ) - _ct7_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_r) - ) - _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) - if isinstance(_ct7_moduli, (int, np.integer)): - _ct7_moduli = [int(_ct7_moduli)] - ct7 = Polynomial( - { - "batch": _ct7_data.shape[0], - "num_elements": _ct7_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_m, - "precision": 32, - "degree_layout": (_ct7_r, _ct7_c), - }, - {"moduli": list(_ct7_moduli)[:_ct7_m]}, - ) - ct7.polynomial = _ct7_data.reshape( - _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in - )[..., :_ct7_m].copy() - ct7.batch = ct7.polynomial.shape[0] - ct7.num_elements = ct7.polynomial.shape[1] - ct7.num_moduli = _ct7_m - ct7.degree_layout = (_ct7_r, _ct7_c) - ct7.r = _ct7_r - ct7.c = _ct7_c - ct7.moduli = list(_ct7_moduli)[:_ct7_m] - ct7.moduli_array = jnp.array( - ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) - ) - _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct8_arg_m_in = _ct8_arg_data.shape[-1] - _ct8_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_arg_m_in - ) - _ct8_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_arg_r) - ) - _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct8_arg_moduli, (int, np.integer)): - _ct8_arg_moduli = [int(_ct8_arg_moduli)] - ct8_arg = Polynomial( - { - "batch": _ct8_arg_data.shape[0], - "num_elements": _ct8_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_arg_m, - "precision": 32, - "degree_layout": (_ct8_arg_r, _ct8_arg_c), - }, - {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, - ) - ct8_arg.polynomial = _ct8_arg_data.reshape( - _ct8_arg_data.shape[0], - _ct8_arg_data.shape[1], - _ct8_arg_r, - _ct8_arg_c, - _ct8_arg_m_in, - )[..., :_ct8_arg_m].copy() - ct8_arg.batch = ct8_arg.polynomial.shape[0] - ct8_arg.num_elements = ct8_arg.polynomial.shape[1] - ct8_arg.num_moduli = _ct8_arg_m - ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) - ct8_arg.r = _ct8_arg_r - ct8_arg.c = _ct8_arg_c - ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] - ct8_arg.moduli_array = jnp.array( - ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) - ) - ct8_pt_ntt = ( - pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] - .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct8_ptct = v0.ptct_mul[v0.max_level] - ct8_ptct.set_plaintext(ct8_pt_ntt) - ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) - _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw - _ct8_m_in = _ct8_data.shape[-1] - _ct8_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_m_in - ) - _ct8_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_r) - ) - _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) - if isinstance(_ct8_moduli, (int, np.integer)): - _ct8_moduli = [int(_ct8_moduli)] - ct8 = Polynomial( - { - "batch": _ct8_data.shape[0], - "num_elements": _ct8_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_m, - "precision": 32, - "degree_layout": (_ct8_r, _ct8_c), - }, - {"moduli": list(_ct8_moduli)[:_ct8_m]}, - ) - ct8.polynomial = _ct8_data.reshape( - _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in - )[..., :_ct8_m].copy() - ct8.batch = ct8.polynomial.shape[0] - ct8.num_elements = ct8.polynomial.shape[1] - ct8.num_moduli = _ct8_m - ct8.degree_layout = (_ct8_r, _ct8_c) - ct8.r = _ct8_r - ct8.c = _ct8_c - ct8.moduli = list(_ct8_moduli)[:_ct8_m] - ct8.moduli_array = jnp.array( - ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) - ) - _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - _ct9_m_in = _ct9_data.shape[-1] - _ct9_m = _ct9_m_in - _ct9_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_r) - ) - _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) - if isinstance(_ct9_moduli, (int, np.integer)): - _ct9_moduli = [int(_ct9_moduli)] - ct9 = Polynomial( - { - "batch": _ct9_data.shape[0], - "num_elements": _ct9_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_m, - "precision": 32, - "degree_layout": (_ct9_r, _ct9_c), - }, - {"moduli": list(_ct9_moduli)[:_ct9_m]}, - ) - ct9.polynomial = _ct9_data.reshape( - _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in - )[..., :_ct9_m].copy() - ct9.batch = ct9.polynomial.shape[0] - ct9.num_elements = ct9.polynomial.shape[1] - ct9.num_moduli = _ct9_m - ct9.degree_layout = (_ct9_r, _ct9_c) - ct9.r = _ct9_r - ct9.c = _ct9_c - ct9.moduli = list(_ct9_moduli)[:_ct9_m] - ct9.moduli_array = jnp.array( - ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) - ) - _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 - _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] - _ct9_rhs_m = _ct9_rhs_m_in - _ct9_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_rhs_r) - ) - _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) - if isinstance(_ct9_rhs_moduli, (int, np.integer)): - _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] - ct9_rhs = Polynomial( - { - "batch": _ct9_rhs_data.shape[0], - "num_elements": _ct9_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_rhs_m, - "precision": 32, - "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), - }, - {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, - ) - ct9_rhs.polynomial = _ct9_rhs_data.reshape( - _ct9_rhs_data.shape[0], - _ct9_rhs_data.shape[1], - _ct9_rhs_r, - _ct9_rhs_c, - _ct9_rhs_m_in, - )[..., :_ct9_rhs_m].copy() - ct9_rhs.batch = ct9_rhs.polynomial.shape[0] - ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] - ct9_rhs.num_moduli = _ct9_rhs_m - ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) - ct9_rhs.r = _ct9_rhs_r - ct9_rhs.c = _ct9_rhs_c - ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] - ct9_rhs.moduli_array = jnp.array( - ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) - ) - ct9.add(ct9_rhs) - _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) - ct9.polynomial = jnp.where( - ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial - ) - _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - _ct10_m_in = _ct10_data.shape[-1] - _ct10_m = _ct10_m_in - _ct10_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_r) - ) - _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) - if isinstance(_ct10_moduli, (int, np.integer)): - _ct10_moduli = [int(_ct10_moduli)] - ct10 = Polynomial( - { - "batch": _ct10_data.shape[0], - "num_elements": _ct10_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_m, - "precision": 32, - "degree_layout": (_ct10_r, _ct10_c), - }, - {"moduli": list(_ct10_moduli)[:_ct10_m]}, - ) - ct10.polynomial = _ct10_data.reshape( - _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in - )[..., :_ct10_m].copy() - ct10.batch = ct10.polynomial.shape[0] - ct10.num_elements = ct10.polynomial.shape[1] - ct10.num_moduli = _ct10_m - ct10.degree_layout = (_ct10_r, _ct10_c) - ct10.r = _ct10_r - ct10.c = _ct10_c - ct10.moduli = list(_ct10_moduli)[:_ct10_m] - ct10.moduli_array = jnp.array( - ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) - ) - _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] - _ct10_rhs_m = _ct10_rhs_m_in - _ct10_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_rhs_r) - ) - _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) - if isinstance(_ct10_rhs_moduli, (int, np.integer)): - _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] - ct10_rhs = Polynomial( - { - "batch": _ct10_rhs_data.shape[0], - "num_elements": _ct10_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_rhs_m, - "precision": 32, - "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), - }, - {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, - ) - ct10_rhs.polynomial = _ct10_rhs_data.reshape( - _ct10_rhs_data.shape[0], - _ct10_rhs_data.shape[1], - _ct10_rhs_r, - _ct10_rhs_c, - _ct10_rhs_m_in, - )[..., :_ct10_rhs_m].copy() - ct10_rhs.batch = ct10_rhs.polynomial.shape[0] - ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] - ct10_rhs.num_moduli = _ct10_rhs_m - ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) - ct10_rhs.r = _ct10_rhs_r - ct10_rhs.c = _ct10_rhs_c - ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] - ct10_rhs.moduli_array = jnp.array( - ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) - ) - ct10.add(ct10_rhs) - _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) - ct10.polynomial = jnp.where( - ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial - ) - _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 - _ct11_arg_m_in = _ct11_arg_data.shape[-1] - _ct11_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_arg_m_in - ) - _ct11_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_arg_r) - ) - _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) - if isinstance(_ct11_arg_moduli, (int, np.integer)): - _ct11_arg_moduli = [int(_ct11_arg_moduli)] - ct11_arg = Polynomial( - { - "batch": _ct11_arg_data.shape[0], - "num_elements": _ct11_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_arg_m, - "precision": 32, - "degree_layout": (_ct11_arg_r, _ct11_arg_c), - }, - {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, - ) - ct11_arg.polynomial = _ct11_arg_data.reshape( - _ct11_arg_data.shape[0], - _ct11_arg_data.shape[1], - _ct11_arg_r, - _ct11_arg_c, - _ct11_arg_m_in, - )[..., :_ct11_arg_m].copy() - ct11_arg.batch = ct11_arg.polynomial.shape[0] - ct11_arg.num_elements = ct11_arg.polynomial.shape[1] - ct11_arg.num_moduli = _ct11_arg_m - ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) - ct11_arg.r = _ct11_arg_r - ct11_arg.c = _ct11_arg_c - ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] - ct11_arg.moduli_array = jnp.array( - ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) - ) - ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) - _ct11_data = ( - ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw - ) - _ct11_m_in = _ct11_data.shape[-1] - _ct11_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_m_in - ) - _ct11_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_r) - ) - _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) - if isinstance(_ct11_moduli, (int, np.integer)): - _ct11_moduli = [int(_ct11_moduli)] - ct11 = Polynomial( - { - "batch": _ct11_data.shape[0], - "num_elements": _ct11_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_m, - "precision": 32, - "degree_layout": (_ct11_r, _ct11_c), - }, - {"moduli": list(_ct11_moduli)[:_ct11_m]}, - ) - ct11.polynomial = _ct11_data.reshape( - _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in - )[..., :_ct11_m].copy() - ct11.batch = ct11.polynomial.shape[0] - ct11.num_elements = ct11.polynomial.shape[1] - ct11.num_moduli = _ct11_m - ct11.degree_layout = (_ct11_r, _ct11_c) - ct11.r = _ct11_r - ct11.c = _ct11_c - ct11.moduli = list(_ct11_moduli)[:_ct11_m] - ct11.moduli_array = jnp.array( - ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) - ) - _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct12_arg_m_in = _ct12_arg_data.shape[-1] - _ct12_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_arg_m_in - ) - _ct12_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_arg_r) - ) - _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct12_arg_moduli, (int, np.integer)): - _ct12_arg_moduli = [int(_ct12_arg_moduli)] - ct12_arg = Polynomial( - { - "batch": _ct12_arg_data.shape[0], - "num_elements": _ct12_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_arg_m, - "precision": 32, - "degree_layout": (_ct12_arg_r, _ct12_arg_c), - }, - {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, - ) - ct12_arg.polynomial = _ct12_arg_data.reshape( - _ct12_arg_data.shape[0], - _ct12_arg_data.shape[1], - _ct12_arg_r, - _ct12_arg_c, - _ct12_arg_m_in, - )[..., :_ct12_arg_m].copy() - ct12_arg.batch = ct12_arg.polynomial.shape[0] - ct12_arg.num_elements = ct12_arg.polynomial.shape[1] - ct12_arg.num_moduli = _ct12_arg_m - ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) - ct12_arg.r = _ct12_arg_r - ct12_arg.c = _ct12_arg_c - ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] - ct12_arg.moduli_array = jnp.array( - ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) - ) - ct12_pt_ntt = ( - pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] - .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct12_ptct = v0.ptct_mul[v0.max_level] - ct12_ptct.set_plaintext(ct12_pt_ntt) - ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) - _ct12_data = ( - ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw - ) - _ct12_m_in = _ct12_data.shape[-1] - _ct12_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_m_in - ) - _ct12_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_r) - ) - _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) - if isinstance(_ct12_moduli, (int, np.integer)): - _ct12_moduli = [int(_ct12_moduli)] - ct12 = Polynomial( - { - "batch": _ct12_data.shape[0], - "num_elements": _ct12_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_m, - "precision": 32, - "degree_layout": (_ct12_r, _ct12_c), - }, - {"moduli": list(_ct12_moduli)[:_ct12_m]}, - ) - ct12.polynomial = _ct12_data.reshape( - _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in - )[..., :_ct12_m].copy() - ct12.batch = ct12.polynomial.shape[0] - ct12.num_elements = ct12.polynomial.shape[1] - ct12.num_moduli = _ct12_m - ct12.degree_layout = (_ct12_r, _ct12_c) - ct12.r = _ct12_r - ct12.c = _ct12_c - ct12.moduli = list(_ct12_moduli)[:_ct12_m] - ct12.moduli_array = jnp.array( - ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) - ) - _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct13_arg_m_in = _ct13_arg_data.shape[-1] - _ct13_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_arg_m_in - ) - _ct13_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_arg_r) - ) - _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct13_arg_moduli, (int, np.integer)): - _ct13_arg_moduli = [int(_ct13_arg_moduli)] - ct13_arg = Polynomial( - { - "batch": _ct13_arg_data.shape[0], - "num_elements": _ct13_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_arg_m, - "precision": 32, - "degree_layout": (_ct13_arg_r, _ct13_arg_c), - }, - {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, - ) - ct13_arg.polynomial = _ct13_arg_data.reshape( - _ct13_arg_data.shape[0], - _ct13_arg_data.shape[1], - _ct13_arg_r, - _ct13_arg_c, - _ct13_arg_m_in, - )[..., :_ct13_arg_m].copy() - ct13_arg.batch = ct13_arg.polynomial.shape[0] - ct13_arg.num_elements = ct13_arg.polynomial.shape[1] - ct13_arg.num_moduli = _ct13_arg_m - ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) - ct13_arg.r = _ct13_arg_r - ct13_arg.c = _ct13_arg_c - ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] - ct13_arg.moduli_array = jnp.array( - ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) - ) - ct13_pt_ntt = ( - pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] - .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct13_ptct = v0.ptct_mul[v0.max_level] - ct13_ptct.set_plaintext(ct13_pt_ntt) - ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) - _ct13_data = ( - ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw - ) - _ct13_m_in = _ct13_data.shape[-1] - _ct13_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_m_in - ) - _ct13_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_r) - ) - _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) - if isinstance(_ct13_moduli, (int, np.integer)): - _ct13_moduli = [int(_ct13_moduli)] - ct13 = Polynomial( - { - "batch": _ct13_data.shape[0], - "num_elements": _ct13_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_m, - "precision": 32, - "degree_layout": (_ct13_r, _ct13_c), - }, - {"moduli": list(_ct13_moduli)[:_ct13_m]}, - ) - ct13.polynomial = _ct13_data.reshape( - _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in - )[..., :_ct13_m].copy() - ct13.batch = ct13.polynomial.shape[0] - ct13.num_elements = ct13.polynomial.shape[1] - ct13.num_moduli = _ct13_m - ct13.degree_layout = (_ct13_r, _ct13_c) - ct13.r = _ct13_r - ct13.c = _ct13_c - ct13.moduli = list(_ct13_moduli)[:_ct13_m] - ct13.moduli_array = jnp.array( - ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) - ) - _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - _ct14_m_in = _ct14_data.shape[-1] - _ct14_m = _ct14_m_in - _ct14_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_r) - ) - _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) - if isinstance(_ct14_moduli, (int, np.integer)): - _ct14_moduli = [int(_ct14_moduli)] - ct14 = Polynomial( - { - "batch": _ct14_data.shape[0], - "num_elements": _ct14_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_m, - "precision": 32, - "degree_layout": (_ct14_r, _ct14_c), - }, - {"moduli": list(_ct14_moduli)[:_ct14_m]}, - ) - ct14.polynomial = _ct14_data.reshape( - _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in - )[..., :_ct14_m].copy() - ct14.batch = ct14.polynomial.shape[0] - ct14.num_elements = ct14.polynomial.shape[1] - ct14.num_moduli = _ct14_m - ct14.degree_layout = (_ct14_r, _ct14_c) - ct14.r = _ct14_r - ct14.c = _ct14_c - ct14.moduli = list(_ct14_moduli)[:_ct14_m] - ct14.moduli_array = jnp.array( - ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) - ) - _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] - _ct14_rhs_m = _ct14_rhs_m_in - _ct14_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_rhs_r) - ) - _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) - if isinstance(_ct14_rhs_moduli, (int, np.integer)): - _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] - ct14_rhs = Polynomial( - { - "batch": _ct14_rhs_data.shape[0], - "num_elements": _ct14_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_rhs_m, - "precision": 32, - "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), - }, - {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, - ) - ct14_rhs.polynomial = _ct14_rhs_data.reshape( - _ct14_rhs_data.shape[0], - _ct14_rhs_data.shape[1], - _ct14_rhs_r, - _ct14_rhs_c, - _ct14_rhs_m_in, - )[..., :_ct14_rhs_m].copy() - ct14_rhs.batch = ct14_rhs.polynomial.shape[0] - ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] - ct14_rhs.num_moduli = _ct14_rhs_m - ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) - ct14_rhs.r = _ct14_rhs_r - ct14_rhs.c = _ct14_rhs_c - ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] - ct14_rhs.moduli_array = jnp.array( - ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) - ) - ct14.add(ct14_rhs) - _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) - ct14.polynomial = jnp.where( - ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial - ) - _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 - _ct15_arg_m_in = _ct15_arg_data.shape[-1] - _ct15_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_arg_m_in - ) - _ct15_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_arg_r) - ) - _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) - if isinstance(_ct15_arg_moduli, (int, np.integer)): - _ct15_arg_moduli = [int(_ct15_arg_moduli)] - ct15_arg = Polynomial( - { - "batch": _ct15_arg_data.shape[0], - "num_elements": _ct15_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_arg_m, - "precision": 32, - "degree_layout": (_ct15_arg_r, _ct15_arg_c), - }, - {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, - ) - ct15_arg.polynomial = _ct15_arg_data.reshape( - _ct15_arg_data.shape[0], - _ct15_arg_data.shape[1], - _ct15_arg_r, - _ct15_arg_c, - _ct15_arg_m_in, - )[..., :_ct15_arg_m].copy() - ct15_arg.batch = ct15_arg.polynomial.shape[0] - ct15_arg.num_elements = ct15_arg.polynomial.shape[1] - ct15_arg.num_moduli = _ct15_arg_m - ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) - ct15_arg.r = _ct15_arg_r - ct15_arg.c = _ct15_arg_c - ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] - ct15_arg.moduli_array = jnp.array( - ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) - ) - ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) - _ct15_data = ( - ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw - ) - _ct15_m_in = _ct15_data.shape[-1] - _ct15_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_m_in - ) - _ct15_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_r) - ) - _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) - if isinstance(_ct15_moduli, (int, np.integer)): - _ct15_moduli = [int(_ct15_moduli)] - ct15 = Polynomial( - { - "batch": _ct15_data.shape[0], - "num_elements": _ct15_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_m, - "precision": 32, - "degree_layout": (_ct15_r, _ct15_c), - }, - {"moduli": list(_ct15_moduli)[:_ct15_m]}, - ) - ct15.polynomial = _ct15_data.reshape( - _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in - )[..., :_ct15_m].copy() - ct15.batch = ct15.polynomial.shape[0] - ct15.num_elements = ct15.polynomial.shape[1] - ct15.num_moduli = _ct15_m - ct15.degree_layout = (_ct15_r, _ct15_c) - ct15.r = _ct15_r - ct15.c = _ct15_c - ct15.moduli = list(_ct15_moduli)[:_ct15_m] - ct15.moduli_array = jnp.array( - ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) - ) - _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - _ct16_m_in = _ct16_data.shape[-1] - _ct16_m = _ct16_m_in - _ct16_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_r) - ) - _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) - if isinstance(_ct16_moduli, (int, np.integer)): - _ct16_moduli = [int(_ct16_moduli)] - ct16 = Polynomial( - { - "batch": _ct16_data.shape[0], - "num_elements": _ct16_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_m, - "precision": 32, - "degree_layout": (_ct16_r, _ct16_c), - }, - {"moduli": list(_ct16_moduli)[:_ct16_m]}, - ) - ct16.polynomial = _ct16_data.reshape( - _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in - )[..., :_ct16_m].copy() - ct16.batch = ct16.polynomial.shape[0] - ct16.num_elements = ct16.polynomial.shape[1] - ct16.num_moduli = _ct16_m - ct16.degree_layout = (_ct16_r, _ct16_c) - ct16.r = _ct16_r - ct16.c = _ct16_c - ct16.moduli = list(_ct16_moduli)[:_ct16_m] - ct16.moduli_array = jnp.array( - ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) - ) - _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 - _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] - _ct16_rhs_m = _ct16_rhs_m_in - _ct16_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_rhs_r) - ) - _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) - if isinstance(_ct16_rhs_moduli, (int, np.integer)): - _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] - ct16_rhs = Polynomial( - { - "batch": _ct16_rhs_data.shape[0], - "num_elements": _ct16_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_rhs_m, - "precision": 32, - "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), - }, - {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, - ) - ct16_rhs.polynomial = _ct16_rhs_data.reshape( - _ct16_rhs_data.shape[0], - _ct16_rhs_data.shape[1], - _ct16_rhs_r, - _ct16_rhs_c, - _ct16_rhs_m_in, - )[..., :_ct16_rhs_m].copy() - ct16_rhs.batch = ct16_rhs.polynomial.shape[0] - ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] - ct16_rhs.num_moduli = _ct16_rhs_m - ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) - ct16_rhs.r = _ct16_rhs_r - ct16_rhs.c = _ct16_rhs_c - ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] - ct16_rhs.moduli_array = jnp.array( - ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) - ) - ct16.add(ct16_rhs) - _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) - ct16.polynomial = jnp.where( - ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial - ) - _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - _ct17_m_in = _ct17_data.shape[-1] - _ct17_m = _ct17_m_in - _ct17_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_r) - ) - _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) - if isinstance(_ct17_moduli, (int, np.integer)): - _ct17_moduli = [int(_ct17_moduli)] - ct17 = Polynomial( - { - "batch": _ct17_data.shape[0], - "num_elements": _ct17_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_m, - "precision": 32, - "degree_layout": (_ct17_r, _ct17_c), - }, - {"moduli": list(_ct17_moduli)[:_ct17_m]}, - ) - ct17.polynomial = _ct17_data.reshape( - _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in - )[..., :_ct17_m].copy() - ct17.batch = ct17.polynomial.shape[0] - ct17.num_elements = ct17.polynomial.shape[1] - ct17.num_moduli = _ct17_m - ct17.degree_layout = (_ct17_r, _ct17_c) - ct17.r = _ct17_r - ct17.c = _ct17_c - ct17.moduli = list(_ct17_moduli)[:_ct17_m] - ct17.moduli_array = jnp.array( - ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) - ) - _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] - _ct17_rhs_m = _ct17_rhs_m_in - _ct17_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_rhs_r) - ) - _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) - if isinstance(_ct17_rhs_moduli, (int, np.integer)): - _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] - ct17_rhs = Polynomial( - { - "batch": _ct17_rhs_data.shape[0], - "num_elements": _ct17_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_rhs_m, - "precision": 32, - "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), - }, - {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, - ) - ct17_rhs.polynomial = _ct17_rhs_data.reshape( - _ct17_rhs_data.shape[0], - _ct17_rhs_data.shape[1], - _ct17_rhs_r, - _ct17_rhs_c, - _ct17_rhs_m_in, - )[..., :_ct17_rhs_m].copy() - ct17_rhs.batch = ct17_rhs.polynomial.shape[0] - ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] - ct17_rhs.num_moduli = _ct17_rhs_m - ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) - ct17_rhs.r = _ct17_rhs_r - ct17_rhs.c = _ct17_rhs_c - ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] - ct17_rhs.moduli_array = jnp.array( - ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) - ) - ct17.add(ct17_rhs) - _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) - ct17.polynomial = jnp.where( - ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial - ) - _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 - _ct18_m_in = _ct18_data.shape[-1] - _ct18_m = _ct18_m_in - _ct18_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_r) - ) - _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) - if isinstance(_ct18_moduli, (int, np.integer)): - _ct18_moduli = [int(_ct18_moduli)] - ct18 = Polynomial( - { - "batch": _ct18_data.shape[0], - "num_elements": _ct18_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_m, - "precision": 32, - "degree_layout": (_ct18_r, _ct18_c), - }, - {"moduli": list(_ct18_moduli)[:_ct18_m]}, - ) - ct18.polynomial = _ct18_data.reshape( - _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in - )[..., :_ct18_m].copy() - ct18.batch = ct18.polynomial.shape[0] - ct18.num_elements = ct18.polynomial.shape[1] - ct18.num_moduli = _ct18_m - ct18.degree_layout = (_ct18_r, _ct18_c) - ct18.r = _ct18_r - ct18.c = _ct18_c - ct18.moduli = list(_ct18_moduli)[:_ct18_m] - ct18.moduli_array = jnp.array( - ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) - ) - _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 - _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] - _ct18_rhs_m = _ct18_rhs_m_in - _ct18_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_rhs_r) - ) - _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) - if isinstance(_ct18_rhs_moduli, (int, np.integer)): - _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] - ct18_rhs = Polynomial( - { - "batch": _ct18_rhs_data.shape[0], - "num_elements": _ct18_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_rhs_m, - "precision": 32, - "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), - }, - {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, - ) - ct18_rhs.polynomial = _ct18_rhs_data.reshape( - _ct18_rhs_data.shape[0], - _ct18_rhs_data.shape[1], - _ct18_rhs_r, - _ct18_rhs_c, - _ct18_rhs_m_in, - )[..., :_ct18_rhs_m].copy() - ct18_rhs.batch = ct18_rhs.polynomial.shape[0] - ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] - ct18_rhs.num_moduli = _ct18_rhs_m - ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) - ct18_rhs.r = _ct18_rhs_r - ct18_rhs.c = _ct18_rhs_c - ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] - ct18_rhs.moduli_array = jnp.array( - ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) - ) - ct18.add(ct18_rhs) - _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) - ct18.polynomial = jnp.where( - ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial - ) - _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 - _ct19_m_in = _ct19_data.shape[-1] - _ct19_m = _ct19_m_in - _ct19_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_r) - ) - _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) - if isinstance(_ct19_moduli, (int, np.integer)): - _ct19_moduli = [int(_ct19_moduli)] - ct19 = Polynomial( - { - "batch": _ct19_data.shape[0], - "num_elements": _ct19_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_m, - "precision": 32, - "degree_layout": (_ct19_r, _ct19_c), - }, - {"moduli": list(_ct19_moduli)[:_ct19_m]}, - ) - ct19.polynomial = _ct19_data.reshape( - _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in - )[..., :_ct19_m].copy() - ct19.batch = ct19.polynomial.shape[0] - ct19.num_elements = ct19.polynomial.shape[1] - ct19.num_moduli = _ct19_m - ct19.degree_layout = (_ct19_r, _ct19_c) - ct19.r = _ct19_r - ct19.c = _ct19_c - ct19.moduli = list(_ct19_moduli)[:_ct19_m] - ct19.moduli_array = jnp.array( - ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) - ) - _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 - _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] - _ct19_rhs_m = _ct19_rhs_m_in - _ct19_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_rhs_r) - ) - _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) - if isinstance(_ct19_rhs_moduli, (int, np.integer)): - _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] - ct19_rhs = Polynomial( - { - "batch": _ct19_rhs_data.shape[0], - "num_elements": _ct19_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_rhs_m, - "precision": 32, - "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), - }, - {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, - ) - ct19_rhs.polynomial = _ct19_rhs_data.reshape( - _ct19_rhs_data.shape[0], - _ct19_rhs_data.shape[1], - _ct19_rhs_r, - _ct19_rhs_c, - _ct19_rhs_m_in, - )[..., :_ct19_rhs_m].copy() - ct19_rhs.batch = ct19_rhs.polynomial.shape[0] - ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] - ct19_rhs.num_moduli = _ct19_rhs_m - ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) - ct19_rhs.r = _ct19_rhs_r - ct19_rhs.c = _ct19_rhs_c - ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] - ct19_rhs.moduli_array = jnp.array( - ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) - ) - ct19.add(ct19_rhs) - _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) - ct19.polynomial = jnp.where( - ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial - ) - v16 = [None] * 1 - _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 - _ct20_arg_m_in = _ct20_arg_data.shape[-1] - _ct20_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct20_arg_m_in - ) - _ct20_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_arg_r) - ) - _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) - if isinstance(_ct20_arg_moduli, (int, np.integer)): - _ct20_arg_moduli = [int(_ct20_arg_moduli)] - ct20_arg = Polynomial( - { - "batch": _ct20_arg_data.shape[0], - "num_elements": _ct20_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_arg_m, - "precision": 32, - "degree_layout": (_ct20_arg_r, _ct20_arg_c), - }, - {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, - ) - ct20_arg.polynomial = _ct20_arg_data.reshape( - _ct20_arg_data.shape[0], - _ct20_arg_data.shape[1], - _ct20_arg_r, - _ct20_arg_c, - _ct20_arg_m_in, - )[..., :_ct20_arg_m].copy() - ct20_arg.batch = ct20_arg.polynomial.shape[0] - ct20_arg.num_elements = ct20_arg.polynomial.shape[1] - ct20_arg.num_moduli = _ct20_arg_m - ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) - ct20_arg.r = _ct20_arg_r - ct20_arg.c = _ct20_arg_c - ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] - ct20_arg.moduli_array = jnp.array( - ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) - ) - ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) - _ct20_data = ( - ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw - ) - _ct20_m_in = _ct20_data.shape[-1] - _ct20_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct20_m_in - ) - _ct20_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_r) - ) - _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) - if isinstance(_ct20_moduli, (int, np.integer)): - _ct20_moduli = [int(_ct20_moduli)] - ct20 = Polynomial( - { - "batch": _ct20_data.shape[0], - "num_elements": _ct20_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_m, - "precision": 32, - "degree_layout": (_ct20_r, _ct20_c), - }, - {"moduli": list(_ct20_moduli)[:_ct20_m]}, - ) - ct20.polynomial = _ct20_data.reshape( - _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in - )[..., :_ct20_m].copy() - ct20.batch = ct20.polynomial.shape[0] - ct20.num_elements = ct20.polynomial.shape[1] - ct20.num_moduli = _ct20_m - ct20.degree_layout = (_ct20_r, _ct20_c) - ct20.r = _ct20_r - ct20.c = _ct20_c - ct20.moduli = list(_ct20_moduli)[:_ct20_m] - ct20.moduli_array = jnp.array( - ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) - ) - v16[0] = ct20 - v17 = v16 - return v17 - - -def matvec_random( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4, v5, v6, v7, v8, v9, v10) = matvec_random__preprocessing(v0, v1) - v11 = matvec_random__preprocessed(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) - return v11 - - -def matvec_random__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw - _ct_m_in = _ct_data.shape[-1] - _ct_m = _ct_m_in - _ct_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct_r) - ) - _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) - if isinstance(_ct_moduli, (int, np.integer)): - _ct_moduli = [int(_ct_moduli)] - ct = Polynomial( - { - "batch": _ct_data.shape[0], - "num_elements": _ct_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct_m, - "precision": 32, - "degree_layout": (_ct_r, _ct_c), - }, - {"moduli": list(_ct_moduli)[:_ct_m]}, - ) - ct.polynomial = _ct_data.reshape( - _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in - )[..., :_ct_m].copy() - ct.batch = ct.polynomial.shape[0] - ct.num_elements = ct.polynomial.shape[1] - ct.num_moduli = _ct_m - ct.degree_layout = (_ct_r, _ct_c) - ct.r = _ct_r - ct.c = _ct_c - ct.moduli = list(_ct_moduli)[:_ct_m] - ct.moduli_array = jnp.array( - ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) - ) - v16 = [ct] - return v16 - - -def matvec_random__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 0 - v8 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - _num_moduli = ct.polynomial.shape[-1] - _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": ct.polynomial.shape[0], - "num_elements": ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - ct.polynomial.reshape( - ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli - ) - ) - pt = v0.decrypt(_ct_for_dec) - v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v10 = v8.copy() - for v11 in range(0, 8): - v13 = int(v11) - v14 = v9[0, v13] - v10[v13] = v14 - return v10 - - -def matvec_chain__preprocessing( - v0: ckks.CKKSContext, - v1: dict, -) -> ( - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, - np.ndarray, -): - v2 = np.array( - [ - 1.340000e00, - 5.800000e-01, - 1.260000e00, - 7.400000e-01, - 6.900000e-01, - 1.070000e00, - 6.000000e-01, - 1.390000e00, - 1.130000e00, - 1.220000e00, - 5.200000e-01, - 1.090000e00, - 1.060000e00, - 6.600000e-01, - 6.500000e-01, - 1.200000e00, - 8.200000e-01, - 1.190000e00, - 1.050000e00, - 8.900000e-01, - 1.430000e00, - 1.340000e00, - 8.600000e-01, - 5.400000e-01, - 8.000000e-01, - 9.000000e-01, - 1.200000e00, - 1.500000e00, - 8.600000e-01, - 1.260000e00, - 1.090000e00, - 1.190000e00, - 6.500000e-01, - 9.000000e-01, - 7.400000e-01, - 8.400000e-01, - 1.010000e00, - 1.170000e00, - 6.100000e-01, - 6.300000e-01, - 8.200000e-01, - 1.160000e00, - 1.350000e00, - 1.050000e00, - 1.350000e00, - 8.800000e-01, - 8.200000e-01, - 8.500000e-01, - 6.700000e-01, - 1.330000e00, - 8.400000e-01, - 1.050000e00, - 1.080000e00, - 1.020000e00, - 5.000000e-01, - 1.490000e00, - 1.410000e00, - 7.100000e-01, - 7.900000e-01, - 1.020000e00, - 1.400000e00, - 1.480000e00, - 7.600000e-01, - 1.060000e00, - ], - dtype=np.float32, - ).reshape(8, 8) - v3 = np.array( - [ - 1.200000e00, - 7.900000e-01, - 7.300000e-01, - 1.050000e00, - 1.220000e00, - 9.200000e-01, - 1.480000e00, - 1.180000e00, - 9.800000e-01, - 8.900000e-01, - 8.400000e-01, - 1.230000e00, - 9.400000e-01, - 5.600000e-01, - 9.000000e-01, - 1.240000e00, - 6.800000e-01, - 6.800000e-01, - 1.030000e00, - 1.030000e00, - 1.130000e00, - 1.350000e00, - 1.220000e00, - 1.110000e00, - 1.220000e00, - 8.200000e-01, - 8.600000e-01, - 7.300000e-01, - 7.900000e-01, - 1.130000e00, - 5.900000e-01, - 9.300000e-01, - 9.300000e-01, - 9.900000e-01, - 9.300000e-01, - 8.100000e-01, - 9.300000e-01, - 1.390000e00, - 1.440000e00, - 1.000000e00, - 1.120000e00, - 6.200000e-01, - 8.200000e-01, - 9.100000e-01, - 1.370000e00, - 7.500000e-01, - 9.800000e-01, - 1.490000e00, - 1.020000e00, - 1.110000e00, - 6.200000e-01, - 1.330000e00, - 1.100000e00, - 1.050000e00, - 8.400000e-01, - 8.000000e-01, - 9.200000e-01, - 1.180000e00, - 1.380000e00, - 1.010000e00, - 1.170000e00, - 1.090000e00, - 1.120000e00, - 1.170000e00, - ], - dtype=np.float32, - ).reshape(8, 8) - v4 = _assign_layout_15335824159471298539(v2) - v5 = _assign_layout_15335824159471298539(v3) - v6 = v5[3 : 3 + 1, 0 : 0 + 5] - v7 = v5[3 : 3 + 1, 5 : 5 + 3] - v8 = np.zeros( - ( - 1, - 8, - ), - dtype=np.float32, - ) - v9 = v8.copy() - v9[0 : 0 + 1, 3 : 3 + 5] = v6 - v10 = v9.copy() - v10[0 : 0 + 1, 0 : 0 + 3] = v7 - v11 = v5[4 : 4 + 1, 0 : 0 + 5] - v12 = v5[4 : 4 + 1, 5 : 5 + 3] - v13 = v8.copy() - v13[0 : 0 + 1, 3 : 3 + 5] = v11 - v14 = v13.copy() - v14[0 : 0 + 1, 0 : 0 + 3] = v12 - v15 = v5[5 : 5 + 1, 0 : 0 + 5] - v16 = v5[5 : 5 + 1, 5 : 5 + 3] - v17 = v8.copy() - v17[0 : 0 + 1, 3 : 3 + 5] = v15 - v18 = v17.copy() - v18[0 : 0 + 1, 0 : 0 + 3] = v16 - v19 = v5[6 : 6 + 1, 0 : 0 + 2] - v20 = v5[6 : 6 + 1, 2 : 2 + 6] - v21 = v8.copy() - v21[0 : 0 + 1, 6 : 6 + 2] = v19 - v22 = v21.copy() - v22[0 : 0 + 1, 0 : 0 + 6] = v20 - v23 = v5[7 : 7 + 1, 0 : 0 + 2] - v24 = v5[7 : 7 + 1, 2 : 2 + 6] - v25 = v8.copy() - v25[0 : 0 + 1, 6 : 6 + 2] = v23 - v26 = v25.copy() - v26[0 : 0 + 1, 0 : 0 + 6] = v24 - v27 = v4[3 : 3 + 1, 0 : 0 + 5] - v28 = v4[3 : 3 + 1, 5 : 5 + 3] - v29 = v8.copy() - v29[0 : 0 + 1, 3 : 3 + 5] = v27 - v30 = v29.copy() - v30[0 : 0 + 1, 0 : 0 + 3] = v28 - v31 = v4[4 : 4 + 1, 0 : 0 + 5] - v32 = v4[4 : 4 + 1, 5 : 5 + 3] - v33 = v8.copy() - v33[0 : 0 + 1, 3 : 3 + 5] = v31 - v34 = v33.copy() - v34[0 : 0 + 1, 0 : 0 + 3] = v32 - v35 = v4[5 : 5 + 1, 0 : 0 + 5] - v36 = v4[5 : 5 + 1, 5 : 5 + 3] - v37 = v8.copy() - v37[0 : 0 + 1, 3 : 3 + 5] = v35 - v38 = v37.copy() - v38[0 : 0 + 1, 0 : 0 + 3] = v36 - v39 = v4[6 : 6 + 1, 0 : 0 + 2] - v40 = v4[6 : 6 + 1, 2 : 2 + 6] - v41 = v8.copy() - v41[0 : 0 + 1, 6 : 6 + 2] = v39 - v42 = v41.copy() - v42[0 : 0 + 1, 0 : 0 + 6] = v40 - v43 = v4[7 : 7 + 1, 0 : 0 + 2] - v44 = v4[7 : 7 + 1, 2 : 2 + 6] - v45 = v8.copy() - v45[0 : 0 + 1, 6 : 6 + 2] = v43 - v46 = v45.copy() - v46[0 : 0 + 1, 0 : 0 + 6] = v44 - v47 = v4[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v47) - v48 = v4[1 : 1 + 1, 0 : 0 + 8].reshape(8) - pt1 = v0.encode(v48) - v49 = v4[2 : 2 + 1, 0 : 0 + 8].reshape(8) - pt2 = v0.encode(v49) - v50 = v30[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt3 = v0.encode(v50) - v51 = v34[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt4 = v0.encode(v51) - v52 = v38[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt5 = v0.encode(v52) - v53 = v42[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt6 = v0.encode(v53) - v54 = v46[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt7 = v0.encode(v54) - v55 = v5[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt8 = v0.encode(v55) - v56 = v5[1 : 1 + 1, 0 : 0 + 8].reshape(8) - pt9 = v0.encode(v56) - v57 = v5[2 : 2 + 1, 0 : 0 + 8].reshape(8) - pt10 = v0.encode(v57) - v58 = v10[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt11 = v0.encode(v58) - v59 = v14[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt12 = v0.encode(v59) - v60 = v18[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt13 = v0.encode(v60) - v61 = v22[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt14 = v0.encode(v61) - v62 = v26[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt15 = v0.encode(v62) - v63 = [pt] - v64 = [pt1] - v65 = [pt2] - v66 = [pt3] - v67 = [pt4] - v68 = [pt5] - v69 = [pt6] - v70 = [pt7] - v71 = [pt8, pt9] - v72 = [pt10, pt11] - v73 = [pt12, pt13] - v74 = [pt14, pt15] - return (v63, v64, v65, v66, v67, v68, v69, v70, v71, v72, v73, v74) - - -def matvec_chain__preprocessed( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, - v4: np.ndarray, - v5: np.ndarray, - v6: np.ndarray, - v7: np.ndarray, - v8: np.ndarray, - v9: np.ndarray, - v10: np.ndarray, - v11: np.ndarray, - v12: np.ndarray, - v13: np.ndarray, - v14: np.ndarray, -) -> np.ndarray: - v15 = 1 - v16 = 2 - v17 = 3 - v18 = 6 - v19 = 0 - pt = v3[0] - pt1 = v4[0] - pt2 = v5[0] - pt3 = v6[0] - pt4 = v7[0] - pt5 = v8[0] - pt6 = v9[0] - pt7 = v10[0] - pt8 = v11[0] - pt9 = v11[1] - pt10 = v12[0] - pt11 = v12[1] - pt12 = v13[0] - pt13 = v13[1] - pt14 = v14[0] - pt15 = v14[1] - ct = v2[0] - _ct1_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct1_arg_m_in = _ct1_arg_data.shape[-1] - _ct1_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_arg_m_in - ) - _ct1_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_arg_r) - ) - _ct1_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct1_arg_moduli, (int, np.integer)): - _ct1_arg_moduli = [int(_ct1_arg_moduli)] - ct1_arg = Polynomial( - { - "batch": _ct1_arg_data.shape[0], - "num_elements": _ct1_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_arg_m, - "precision": 32, - "degree_layout": (_ct1_arg_r, _ct1_arg_c), - }, - {"moduli": list(_ct1_arg_moduli)[:_ct1_arg_m]}, - ) - ct1_arg.polynomial = _ct1_arg_data.reshape( - _ct1_arg_data.shape[0], - _ct1_arg_data.shape[1], - _ct1_arg_r, - _ct1_arg_c, - _ct1_arg_m_in, - )[..., :_ct1_arg_m].copy() - ct1_arg.batch = ct1_arg.polynomial.shape[0] - ct1_arg.num_elements = ct1_arg.polynomial.shape[1] - ct1_arg.num_moduli = _ct1_arg_m - ct1_arg.degree_layout = (_ct1_arg_r, _ct1_arg_c) - ct1_arg.r = _ct1_arg_r - ct1_arg.c = _ct1_arg_c - ct1_arg.moduli = list(_ct1_arg_moduli)[:_ct1_arg_m] - ct1_arg.moduli_array = jnp.array( - ct1_arg.moduli, dtype=getattr(ct1_arg, "modulus_dtype", jnp.uint32) - ) - ct1_pt_ntt = ( - pt.polynomial[0, 0, :, : ct1_arg.polynomial.shape[-1]] - .reshape(ct1_arg.r, ct1_arg.c, ct1_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct1_ptct = v0.ptct_mul[v0.max_level] - ct1_ptct.set_plaintext(ct1_pt_ntt) - ct1_raw = ct1_ptct.mul(ct1_arg, use_bat=False) - _ct1_data = ct1_raw.polynomial if hasattr(ct1_raw, "polynomial") else ct1_raw - _ct1_m_in = _ct1_data.shape[-1] - _ct1_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct1_m_in - ) - _ct1_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct1_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct1_r) - ) - _ct1_moduli = getattr(ct1_raw, "moduli", v0.q_towers) - if isinstance(_ct1_moduli, (int, np.integer)): - _ct1_moduli = [int(_ct1_moduli)] - ct1 = Polynomial( - { - "batch": _ct1_data.shape[0], - "num_elements": _ct1_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct1_m, - "precision": 32, - "degree_layout": (_ct1_r, _ct1_c), - }, - {"moduli": list(_ct1_moduli)[:_ct1_m]}, - ) - ct1.polynomial = _ct1_data.reshape( - _ct1_data.shape[0], _ct1_data.shape[1], _ct1_r, _ct1_c, _ct1_m_in - )[..., :_ct1_m].copy() - ct1.batch = ct1.polynomial.shape[0] - ct1.num_elements = ct1.polynomial.shape[1] - ct1.num_moduli = _ct1_m - ct1.degree_layout = (_ct1_r, _ct1_c) - ct1.r = _ct1_r - ct1.c = _ct1_c - ct1.moduli = list(_ct1_moduli)[:_ct1_m] - ct1.moduli_array = jnp.array( - ct1.moduli, dtype=getattr(ct1, "modulus_dtype", jnp.uint32) - ) - _ct2_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct2_arg_m_in = _ct2_arg_data.shape[-1] - _ct2_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_arg_m_in - ) - _ct2_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_arg_r) - ) - _ct2_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct2_arg_moduli, (int, np.integer)): - _ct2_arg_moduli = [int(_ct2_arg_moduli)] - ct2_arg = Polynomial( - { - "batch": _ct2_arg_data.shape[0], - "num_elements": _ct2_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_arg_m, - "precision": 32, - "degree_layout": (_ct2_arg_r, _ct2_arg_c), - }, - {"moduli": list(_ct2_arg_moduli)[:_ct2_arg_m]}, - ) - ct2_arg.polynomial = _ct2_arg_data.reshape( - _ct2_arg_data.shape[0], - _ct2_arg_data.shape[1], - _ct2_arg_r, - _ct2_arg_c, - _ct2_arg_m_in, - )[..., :_ct2_arg_m].copy() - ct2_arg.batch = ct2_arg.polynomial.shape[0] - ct2_arg.num_elements = ct2_arg.polynomial.shape[1] - ct2_arg.num_moduli = _ct2_arg_m - ct2_arg.degree_layout = (_ct2_arg_r, _ct2_arg_c) - ct2_arg.r = _ct2_arg_r - ct2_arg.c = _ct2_arg_c - ct2_arg.moduli = list(_ct2_arg_moduli)[:_ct2_arg_m] - ct2_arg.moduli_array = jnp.array( - ct2_arg.moduli, dtype=getattr(ct2_arg, "modulus_dtype", jnp.uint32) - ) - ct2_raw = v0.he_rot[v0.max_level, 1].rotate(ct2_arg) - _ct2_data = ct2_raw.polynomial if hasattr(ct2_raw, "polynomial") else ct2_raw - _ct2_m_in = _ct2_data.shape[-1] - _ct2_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct2_m_in - ) - _ct2_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct2_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct2_r) - ) - _ct2_moduli = getattr(ct2_raw, "moduli", v0.q_towers) - if isinstance(_ct2_moduli, (int, np.integer)): - _ct2_moduli = [int(_ct2_moduli)] - ct2 = Polynomial( - { - "batch": _ct2_data.shape[0], - "num_elements": _ct2_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct2_m, - "precision": 32, - "degree_layout": (_ct2_r, _ct2_c), - }, - {"moduli": list(_ct2_moduli)[:_ct2_m]}, - ) - ct2.polynomial = _ct2_data.reshape( - _ct2_data.shape[0], _ct2_data.shape[1], _ct2_r, _ct2_c, _ct2_m_in - )[..., :_ct2_m].copy() - ct2.batch = ct2.polynomial.shape[0] - ct2.num_elements = ct2.polynomial.shape[1] - ct2.num_moduli = _ct2_m - ct2.degree_layout = (_ct2_r, _ct2_c) - ct2.r = _ct2_r - ct2.c = _ct2_c - ct2.moduli = list(_ct2_moduli)[:_ct2_m] - ct2.moduli_array = jnp.array( - ct2.moduli, dtype=getattr(ct2, "modulus_dtype", jnp.uint32) - ) - _ct3_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct3_arg_m_in = _ct3_arg_data.shape[-1] - _ct3_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_arg_m_in - ) - _ct3_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_arg_r) - ) - _ct3_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct3_arg_moduli, (int, np.integer)): - _ct3_arg_moduli = [int(_ct3_arg_moduli)] - ct3_arg = Polynomial( - { - "batch": _ct3_arg_data.shape[0], - "num_elements": _ct3_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_arg_m, - "precision": 32, - "degree_layout": (_ct3_arg_r, _ct3_arg_c), - }, - {"moduli": list(_ct3_arg_moduli)[:_ct3_arg_m]}, - ) - ct3_arg.polynomial = _ct3_arg_data.reshape( - _ct3_arg_data.shape[0], - _ct3_arg_data.shape[1], - _ct3_arg_r, - _ct3_arg_c, - _ct3_arg_m_in, - )[..., :_ct3_arg_m].copy() - ct3_arg.batch = ct3_arg.polynomial.shape[0] - ct3_arg.num_elements = ct3_arg.polynomial.shape[1] - ct3_arg.num_moduli = _ct3_arg_m - ct3_arg.degree_layout = (_ct3_arg_r, _ct3_arg_c) - ct3_arg.r = _ct3_arg_r - ct3_arg.c = _ct3_arg_c - ct3_arg.moduli = list(_ct3_arg_moduli)[:_ct3_arg_m] - ct3_arg.moduli_array = jnp.array( - ct3_arg.moduli, dtype=getattr(ct3_arg, "modulus_dtype", jnp.uint32) - ) - ct3_pt_ntt = ( - pt1.polynomial[0, 0, :, : ct3_arg.polynomial.shape[-1]] - .reshape(ct3_arg.r, ct3_arg.c, ct3_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct3_ptct = v0.ptct_mul[v0.max_level] - ct3_ptct.set_plaintext(ct3_pt_ntt) - ct3_raw = ct3_ptct.mul(ct3_arg, use_bat=False) - _ct3_data = ct3_raw.polynomial if hasattr(ct3_raw, "polynomial") else ct3_raw - _ct3_m_in = _ct3_data.shape[-1] - _ct3_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct3_m_in - ) - _ct3_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct3_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct3_r) - ) - _ct3_moduli = getattr(ct3_raw, "moduli", v0.q_towers) - if isinstance(_ct3_moduli, (int, np.integer)): - _ct3_moduli = [int(_ct3_moduli)] - ct3 = Polynomial( - { - "batch": _ct3_data.shape[0], - "num_elements": _ct3_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct3_m, - "precision": 32, - "degree_layout": (_ct3_r, _ct3_c), - }, - {"moduli": list(_ct3_moduli)[:_ct3_m]}, - ) - ct3.polynomial = _ct3_data.reshape( - _ct3_data.shape[0], _ct3_data.shape[1], _ct3_r, _ct3_c, _ct3_m_in - )[..., :_ct3_m].copy() - ct3.batch = ct3.polynomial.shape[0] - ct3.num_elements = ct3.polynomial.shape[1] - ct3.num_moduli = _ct3_m - ct3.degree_layout = (_ct3_r, _ct3_c) - ct3.r = _ct3_r - ct3.c = _ct3_c - ct3.moduli = list(_ct3_moduli)[:_ct3_m] - ct3.moduli_array = jnp.array( - ct3.moduli, dtype=getattr(ct3, "modulus_dtype", jnp.uint32) - ) - _ct4_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct4_arg_m_in = _ct4_arg_data.shape[-1] - _ct4_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_arg_m_in - ) - _ct4_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_arg_r) - ) - _ct4_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct4_arg_moduli, (int, np.integer)): - _ct4_arg_moduli = [int(_ct4_arg_moduli)] - ct4_arg = Polynomial( - { - "batch": _ct4_arg_data.shape[0], - "num_elements": _ct4_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_arg_m, - "precision": 32, - "degree_layout": (_ct4_arg_r, _ct4_arg_c), - }, - {"moduli": list(_ct4_arg_moduli)[:_ct4_arg_m]}, - ) - ct4_arg.polynomial = _ct4_arg_data.reshape( - _ct4_arg_data.shape[0], - _ct4_arg_data.shape[1], - _ct4_arg_r, - _ct4_arg_c, - _ct4_arg_m_in, - )[..., :_ct4_arg_m].copy() - ct4_arg.batch = ct4_arg.polynomial.shape[0] - ct4_arg.num_elements = ct4_arg.polynomial.shape[1] - ct4_arg.num_moduli = _ct4_arg_m - ct4_arg.degree_layout = (_ct4_arg_r, _ct4_arg_c) - ct4_arg.r = _ct4_arg_r - ct4_arg.c = _ct4_arg_c - ct4_arg.moduli = list(_ct4_arg_moduli)[:_ct4_arg_m] - ct4_arg.moduli_array = jnp.array( - ct4_arg.moduli, dtype=getattr(ct4_arg, "modulus_dtype", jnp.uint32) - ) - ct4_raw = v0.he_rot[v0.max_level, 2].rotate(ct4_arg) - _ct4_data = ct4_raw.polynomial if hasattr(ct4_raw, "polynomial") else ct4_raw - _ct4_m_in = _ct4_data.shape[-1] - _ct4_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct4_m_in - ) - _ct4_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct4_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct4_r) - ) - _ct4_moduli = getattr(ct4_raw, "moduli", v0.q_towers) - if isinstance(_ct4_moduli, (int, np.integer)): - _ct4_moduli = [int(_ct4_moduli)] - ct4 = Polynomial( - { - "batch": _ct4_data.shape[0], - "num_elements": _ct4_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct4_m, - "precision": 32, - "degree_layout": (_ct4_r, _ct4_c), - }, - {"moduli": list(_ct4_moduli)[:_ct4_m]}, - ) - ct4.polynomial = _ct4_data.reshape( - _ct4_data.shape[0], _ct4_data.shape[1], _ct4_r, _ct4_c, _ct4_m_in - )[..., :_ct4_m].copy() - ct4.batch = ct4.polynomial.shape[0] - ct4.num_elements = ct4.polynomial.shape[1] - ct4.num_moduli = _ct4_m - ct4.degree_layout = (_ct4_r, _ct4_c) - ct4.r = _ct4_r - ct4.c = _ct4_c - ct4.moduli = list(_ct4_moduli)[:_ct4_m] - ct4.moduli_array = jnp.array( - ct4.moduli, dtype=getattr(ct4, "modulus_dtype", jnp.uint32) - ) - _ct5_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct5_arg_m_in = _ct5_arg_data.shape[-1] - _ct5_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_arg_m_in - ) - _ct5_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_arg_r) - ) - _ct5_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct5_arg_moduli, (int, np.integer)): - _ct5_arg_moduli = [int(_ct5_arg_moduli)] - ct5_arg = Polynomial( - { - "batch": _ct5_arg_data.shape[0], - "num_elements": _ct5_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_arg_m, - "precision": 32, - "degree_layout": (_ct5_arg_r, _ct5_arg_c), - }, - {"moduli": list(_ct5_arg_moduli)[:_ct5_arg_m]}, - ) - ct5_arg.polynomial = _ct5_arg_data.reshape( - _ct5_arg_data.shape[0], - _ct5_arg_data.shape[1], - _ct5_arg_r, - _ct5_arg_c, - _ct5_arg_m_in, - )[..., :_ct5_arg_m].copy() - ct5_arg.batch = ct5_arg.polynomial.shape[0] - ct5_arg.num_elements = ct5_arg.polynomial.shape[1] - ct5_arg.num_moduli = _ct5_arg_m - ct5_arg.degree_layout = (_ct5_arg_r, _ct5_arg_c) - ct5_arg.r = _ct5_arg_r - ct5_arg.c = _ct5_arg_c - ct5_arg.moduli = list(_ct5_arg_moduli)[:_ct5_arg_m] - ct5_arg.moduli_array = jnp.array( - ct5_arg.moduli, dtype=getattr(ct5_arg, "modulus_dtype", jnp.uint32) - ) - ct5_pt_ntt = ( - pt2.polynomial[0, 0, :, : ct5_arg.polynomial.shape[-1]] - .reshape(ct5_arg.r, ct5_arg.c, ct5_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct5_ptct = v0.ptct_mul[v0.max_level] - ct5_ptct.set_plaintext(ct5_pt_ntt) - ct5_raw = ct5_ptct.mul(ct5_arg, use_bat=False) - _ct5_data = ct5_raw.polynomial if hasattr(ct5_raw, "polynomial") else ct5_raw - _ct5_m_in = _ct5_data.shape[-1] - _ct5_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct5_m_in - ) - _ct5_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct5_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct5_r) - ) - _ct5_moduli = getattr(ct5_raw, "moduli", v0.q_towers) - if isinstance(_ct5_moduli, (int, np.integer)): - _ct5_moduli = [int(_ct5_moduli)] - ct5 = Polynomial( - { - "batch": _ct5_data.shape[0], - "num_elements": _ct5_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct5_m, - "precision": 32, - "degree_layout": (_ct5_r, _ct5_c), - }, - {"moduli": list(_ct5_moduli)[:_ct5_m]}, - ) - ct5.polynomial = _ct5_data.reshape( - _ct5_data.shape[0], _ct5_data.shape[1], _ct5_r, _ct5_c, _ct5_m_in - )[..., :_ct5_m].copy() - ct5.batch = ct5.polynomial.shape[0] - ct5.num_elements = ct5.polynomial.shape[1] - ct5.num_moduli = _ct5_m - ct5.degree_layout = (_ct5_r, _ct5_c) - ct5.r = _ct5_r - ct5.c = _ct5_c - ct5.moduli = list(_ct5_moduli)[:_ct5_m] - ct5.moduli_array = jnp.array( - ct5.moduli, dtype=getattr(ct5, "modulus_dtype", jnp.uint32) - ) - _ct6_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct6_arg_m_in = _ct6_arg_data.shape[-1] - _ct6_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_arg_m_in - ) - _ct6_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_arg_r) - ) - _ct6_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct6_arg_moduli, (int, np.integer)): - _ct6_arg_moduli = [int(_ct6_arg_moduli)] - ct6_arg = Polynomial( - { - "batch": _ct6_arg_data.shape[0], - "num_elements": _ct6_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_arg_m, - "precision": 32, - "degree_layout": (_ct6_arg_r, _ct6_arg_c), - }, - {"moduli": list(_ct6_arg_moduli)[:_ct6_arg_m]}, - ) - ct6_arg.polynomial = _ct6_arg_data.reshape( - _ct6_arg_data.shape[0], - _ct6_arg_data.shape[1], - _ct6_arg_r, - _ct6_arg_c, - _ct6_arg_m_in, - )[..., :_ct6_arg_m].copy() - ct6_arg.batch = ct6_arg.polynomial.shape[0] - ct6_arg.num_elements = ct6_arg.polynomial.shape[1] - ct6_arg.num_moduli = _ct6_arg_m - ct6_arg.degree_layout = (_ct6_arg_r, _ct6_arg_c) - ct6_arg.r = _ct6_arg_r - ct6_arg.c = _ct6_arg_c - ct6_arg.moduli = list(_ct6_arg_moduli)[:_ct6_arg_m] - ct6_arg.moduli_array = jnp.array( - ct6_arg.moduli, dtype=getattr(ct6_arg, "modulus_dtype", jnp.uint32) - ) - ct6_pt_ntt = ( - pt3.polynomial[0, 0, :, : ct6_arg.polynomial.shape[-1]] - .reshape(ct6_arg.r, ct6_arg.c, ct6_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct6_ptct = v0.ptct_mul[v0.max_level] - ct6_ptct.set_plaintext(ct6_pt_ntt) - ct6_raw = ct6_ptct.mul(ct6_arg, use_bat=False) - _ct6_data = ct6_raw.polynomial if hasattr(ct6_raw, "polynomial") else ct6_raw - _ct6_m_in = _ct6_data.shape[-1] - _ct6_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct6_m_in - ) - _ct6_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct6_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct6_r) - ) - _ct6_moduli = getattr(ct6_raw, "moduli", v0.q_towers) - if isinstance(_ct6_moduli, (int, np.integer)): - _ct6_moduli = [int(_ct6_moduli)] - ct6 = Polynomial( - { - "batch": _ct6_data.shape[0], - "num_elements": _ct6_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct6_m, - "precision": 32, - "degree_layout": (_ct6_r, _ct6_c), - }, - {"moduli": list(_ct6_moduli)[:_ct6_m]}, - ) - ct6.polynomial = _ct6_data.reshape( - _ct6_data.shape[0], _ct6_data.shape[1], _ct6_r, _ct6_c, _ct6_m_in - )[..., :_ct6_m].copy() - ct6.batch = ct6.polynomial.shape[0] - ct6.num_elements = ct6.polynomial.shape[1] - ct6.num_moduli = _ct6_m - ct6.degree_layout = (_ct6_r, _ct6_c) - ct6.r = _ct6_r - ct6.c = _ct6_c - ct6.moduli = list(_ct6_moduli)[:_ct6_m] - ct6.moduli_array = jnp.array( - ct6.moduli, dtype=getattr(ct6, "modulus_dtype", jnp.uint32) - ) - _ct7_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct7_arg_m_in = _ct7_arg_data.shape[-1] - _ct7_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_arg_m_in - ) - _ct7_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_arg_r) - ) - _ct7_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct7_arg_moduli, (int, np.integer)): - _ct7_arg_moduli = [int(_ct7_arg_moduli)] - ct7_arg = Polynomial( - { - "batch": _ct7_arg_data.shape[0], - "num_elements": _ct7_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_arg_m, - "precision": 32, - "degree_layout": (_ct7_arg_r, _ct7_arg_c), - }, - {"moduli": list(_ct7_arg_moduli)[:_ct7_arg_m]}, - ) - ct7_arg.polynomial = _ct7_arg_data.reshape( - _ct7_arg_data.shape[0], - _ct7_arg_data.shape[1], - _ct7_arg_r, - _ct7_arg_c, - _ct7_arg_m_in, - )[..., :_ct7_arg_m].copy() - ct7_arg.batch = ct7_arg.polynomial.shape[0] - ct7_arg.num_elements = ct7_arg.polynomial.shape[1] - ct7_arg.num_moduli = _ct7_arg_m - ct7_arg.degree_layout = (_ct7_arg_r, _ct7_arg_c) - ct7_arg.r = _ct7_arg_r - ct7_arg.c = _ct7_arg_c - ct7_arg.moduli = list(_ct7_arg_moduli)[:_ct7_arg_m] - ct7_arg.moduli_array = jnp.array( - ct7_arg.moduli, dtype=getattr(ct7_arg, "modulus_dtype", jnp.uint32) - ) - ct7_pt_ntt = ( - pt4.polynomial[0, 0, :, : ct7_arg.polynomial.shape[-1]] - .reshape(ct7_arg.r, ct7_arg.c, ct7_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct7_ptct = v0.ptct_mul[v0.max_level] - ct7_ptct.set_plaintext(ct7_pt_ntt) - ct7_raw = ct7_ptct.mul(ct7_arg, use_bat=False) - _ct7_data = ct7_raw.polynomial if hasattr(ct7_raw, "polynomial") else ct7_raw - _ct7_m_in = _ct7_data.shape[-1] - _ct7_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct7_m_in - ) - _ct7_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct7_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct7_r) - ) - _ct7_moduli = getattr(ct7_raw, "moduli", v0.q_towers) - if isinstance(_ct7_moduli, (int, np.integer)): - _ct7_moduli = [int(_ct7_moduli)] - ct7 = Polynomial( - { - "batch": _ct7_data.shape[0], - "num_elements": _ct7_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct7_m, - "precision": 32, - "degree_layout": (_ct7_r, _ct7_c), - }, - {"moduli": list(_ct7_moduli)[:_ct7_m]}, - ) - ct7.polynomial = _ct7_data.reshape( - _ct7_data.shape[0], _ct7_data.shape[1], _ct7_r, _ct7_c, _ct7_m_in - )[..., :_ct7_m].copy() - ct7.batch = ct7.polynomial.shape[0] - ct7.num_elements = ct7.polynomial.shape[1] - ct7.num_moduli = _ct7_m - ct7.degree_layout = (_ct7_r, _ct7_c) - ct7.r = _ct7_r - ct7.c = _ct7_c - ct7.moduli = list(_ct7_moduli)[:_ct7_m] - ct7.moduli_array = jnp.array( - ct7.moduli, dtype=getattr(ct7, "modulus_dtype", jnp.uint32) - ) - _ct8_arg_data = ct4.polynomial if hasattr(ct4, "polynomial") else ct4 - _ct8_arg_m_in = _ct8_arg_data.shape[-1] - _ct8_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_arg_m_in - ) - _ct8_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_arg_r) - ) - _ct8_arg_moduli = getattr(ct4, "moduli", v0.q_towers) - if isinstance(_ct8_arg_moduli, (int, np.integer)): - _ct8_arg_moduli = [int(_ct8_arg_moduli)] - ct8_arg = Polynomial( - { - "batch": _ct8_arg_data.shape[0], - "num_elements": _ct8_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_arg_m, - "precision": 32, - "degree_layout": (_ct8_arg_r, _ct8_arg_c), - }, - {"moduli": list(_ct8_arg_moduli)[:_ct8_arg_m]}, - ) - ct8_arg.polynomial = _ct8_arg_data.reshape( - _ct8_arg_data.shape[0], - _ct8_arg_data.shape[1], - _ct8_arg_r, - _ct8_arg_c, - _ct8_arg_m_in, - )[..., :_ct8_arg_m].copy() - ct8_arg.batch = ct8_arg.polynomial.shape[0] - ct8_arg.num_elements = ct8_arg.polynomial.shape[1] - ct8_arg.num_moduli = _ct8_arg_m - ct8_arg.degree_layout = (_ct8_arg_r, _ct8_arg_c) - ct8_arg.r = _ct8_arg_r - ct8_arg.c = _ct8_arg_c - ct8_arg.moduli = list(_ct8_arg_moduli)[:_ct8_arg_m] - ct8_arg.moduli_array = jnp.array( - ct8_arg.moduli, dtype=getattr(ct8_arg, "modulus_dtype", jnp.uint32) - ) - ct8_pt_ntt = ( - pt5.polynomial[0, 0, :, : ct8_arg.polynomial.shape[-1]] - .reshape(ct8_arg.r, ct8_arg.c, ct8_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct8_ptct = v0.ptct_mul[v0.max_level] - ct8_ptct.set_plaintext(ct8_pt_ntt) - ct8_raw = ct8_ptct.mul(ct8_arg, use_bat=False) - _ct8_data = ct8_raw.polynomial if hasattr(ct8_raw, "polynomial") else ct8_raw - _ct8_m_in = _ct8_data.shape[-1] - _ct8_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct8_m_in - ) - _ct8_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct8_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct8_r) - ) - _ct8_moduli = getattr(ct8_raw, "moduli", v0.q_towers) - if isinstance(_ct8_moduli, (int, np.integer)): - _ct8_moduli = [int(_ct8_moduli)] - ct8 = Polynomial( - { - "batch": _ct8_data.shape[0], - "num_elements": _ct8_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct8_m, - "precision": 32, - "degree_layout": (_ct8_r, _ct8_c), - }, - {"moduli": list(_ct8_moduli)[:_ct8_m]}, - ) - ct8.polynomial = _ct8_data.reshape( - _ct8_data.shape[0], _ct8_data.shape[1], _ct8_r, _ct8_c, _ct8_m_in - )[..., :_ct8_m].copy() - ct8.batch = ct8.polynomial.shape[0] - ct8.num_elements = ct8.polynomial.shape[1] - ct8.num_moduli = _ct8_m - ct8.degree_layout = (_ct8_r, _ct8_c) - ct8.r = _ct8_r - ct8.c = _ct8_c - ct8.moduli = list(_ct8_moduli)[:_ct8_m] - ct8.moduli_array = jnp.array( - ct8.moduli, dtype=getattr(ct8, "modulus_dtype", jnp.uint32) - ) - _ct9_data = ct6.polynomial if hasattr(ct6, "polynomial") else ct6 - _ct9_m_in = _ct9_data.shape[-1] - _ct9_m = _ct9_m_in - _ct9_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_r) - ) - _ct9_moduli = getattr(ct6, "moduli", v0.q_towers) - if isinstance(_ct9_moduli, (int, np.integer)): - _ct9_moduli = [int(_ct9_moduli)] - ct9 = Polynomial( - { - "batch": _ct9_data.shape[0], - "num_elements": _ct9_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_m, - "precision": 32, - "degree_layout": (_ct9_r, _ct9_c), - }, - {"moduli": list(_ct9_moduli)[:_ct9_m]}, - ) - ct9.polynomial = _ct9_data.reshape( - _ct9_data.shape[0], _ct9_data.shape[1], _ct9_r, _ct9_c, _ct9_m_in - )[..., :_ct9_m].copy() - ct9.batch = ct9.polynomial.shape[0] - ct9.num_elements = ct9.polynomial.shape[1] - ct9.num_moduli = _ct9_m - ct9.degree_layout = (_ct9_r, _ct9_c) - ct9.r = _ct9_r - ct9.c = _ct9_c - ct9.moduli = list(_ct9_moduli)[:_ct9_m] - ct9.moduli_array = jnp.array( - ct9.moduli, dtype=getattr(ct9, "modulus_dtype", jnp.uint32) - ) - _ct9_rhs_data = ct7.polynomial if hasattr(ct7, "polynomial") else ct7 - _ct9_rhs_m_in = _ct9_rhs_data.shape[-1] - _ct9_rhs_m = _ct9_rhs_m_in - _ct9_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct9_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct9_rhs_r) - ) - _ct9_rhs_moduli = getattr(ct7, "moduli", v0.q_towers) - if isinstance(_ct9_rhs_moduli, (int, np.integer)): - _ct9_rhs_moduli = [int(_ct9_rhs_moduli)] - ct9_rhs = Polynomial( - { - "batch": _ct9_rhs_data.shape[0], - "num_elements": _ct9_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct9_rhs_m, - "precision": 32, - "degree_layout": (_ct9_rhs_r, _ct9_rhs_c), - }, - {"moduli": list(_ct9_rhs_moduli)[:_ct9_rhs_m]}, - ) - ct9_rhs.polynomial = _ct9_rhs_data.reshape( - _ct9_rhs_data.shape[0], - _ct9_rhs_data.shape[1], - _ct9_rhs_r, - _ct9_rhs_c, - _ct9_rhs_m_in, - )[..., :_ct9_rhs_m].copy() - ct9_rhs.batch = ct9_rhs.polynomial.shape[0] - ct9_rhs.num_elements = ct9_rhs.polynomial.shape[1] - ct9_rhs.num_moduli = _ct9_rhs_m - ct9_rhs.degree_layout = (_ct9_rhs_r, _ct9_rhs_c) - ct9_rhs.r = _ct9_rhs_r - ct9_rhs.c = _ct9_rhs_c - ct9_rhs.moduli = list(_ct9_rhs_moduli)[:_ct9_rhs_m] - ct9_rhs.moduli_array = jnp.array( - ct9_rhs.moduli, dtype=getattr(ct9_rhs, "modulus_dtype", jnp.uint32) - ) - ct9.add(ct9_rhs) - _moduli = jnp.array(ct9.moduli, dtype=jnp.uint32) - ct9.polynomial = jnp.where( - ct9.polynomial >= _moduli, ct9.polynomial - _moduli, ct9.polynomial - ) - _ct10_data = ct9.polynomial if hasattr(ct9, "polynomial") else ct9 - _ct10_m_in = _ct10_data.shape[-1] - _ct10_m = _ct10_m_in - _ct10_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_r) - ) - _ct10_moduli = getattr(ct9, "moduli", v0.q_towers) - if isinstance(_ct10_moduli, (int, np.integer)): - _ct10_moduli = [int(_ct10_moduli)] - ct10 = Polynomial( - { - "batch": _ct10_data.shape[0], - "num_elements": _ct10_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_m, - "precision": 32, - "degree_layout": (_ct10_r, _ct10_c), - }, - {"moduli": list(_ct10_moduli)[:_ct10_m]}, - ) - ct10.polynomial = _ct10_data.reshape( - _ct10_data.shape[0], _ct10_data.shape[1], _ct10_r, _ct10_c, _ct10_m_in - )[..., :_ct10_m].copy() - ct10.batch = ct10.polynomial.shape[0] - ct10.num_elements = ct10.polynomial.shape[1] - ct10.num_moduli = _ct10_m - ct10.degree_layout = (_ct10_r, _ct10_c) - ct10.r = _ct10_r - ct10.c = _ct10_c - ct10.moduli = list(_ct10_moduli)[:_ct10_m] - ct10.moduli_array = jnp.array( - ct10.moduli, dtype=getattr(ct10, "modulus_dtype", jnp.uint32) - ) - _ct10_rhs_data = ct8.polynomial if hasattr(ct8, "polynomial") else ct8 - _ct10_rhs_m_in = _ct10_rhs_data.shape[-1] - _ct10_rhs_m = _ct10_rhs_m_in - _ct10_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct10_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct10_rhs_r) - ) - _ct10_rhs_moduli = getattr(ct8, "moduli", v0.q_towers) - if isinstance(_ct10_rhs_moduli, (int, np.integer)): - _ct10_rhs_moduli = [int(_ct10_rhs_moduli)] - ct10_rhs = Polynomial( - { - "batch": _ct10_rhs_data.shape[0], - "num_elements": _ct10_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct10_rhs_m, - "precision": 32, - "degree_layout": (_ct10_rhs_r, _ct10_rhs_c), - }, - {"moduli": list(_ct10_rhs_moduli)[:_ct10_rhs_m]}, - ) - ct10_rhs.polynomial = _ct10_rhs_data.reshape( - _ct10_rhs_data.shape[0], - _ct10_rhs_data.shape[1], - _ct10_rhs_r, - _ct10_rhs_c, - _ct10_rhs_m_in, - )[..., :_ct10_rhs_m].copy() - ct10_rhs.batch = ct10_rhs.polynomial.shape[0] - ct10_rhs.num_elements = ct10_rhs.polynomial.shape[1] - ct10_rhs.num_moduli = _ct10_rhs_m - ct10_rhs.degree_layout = (_ct10_rhs_r, _ct10_rhs_c) - ct10_rhs.r = _ct10_rhs_r - ct10_rhs.c = _ct10_rhs_c - ct10_rhs.moduli = list(_ct10_rhs_moduli)[:_ct10_rhs_m] - ct10_rhs.moduli_array = jnp.array( - ct10_rhs.moduli, dtype=getattr(ct10_rhs, "modulus_dtype", jnp.uint32) - ) - ct10.add(ct10_rhs) - _moduli = jnp.array(ct10.moduli, dtype=jnp.uint32) - ct10.polynomial = jnp.where( - ct10.polynomial >= _moduli, ct10.polynomial - _moduli, ct10.polynomial - ) - _ct11_arg_data = ct10.polynomial if hasattr(ct10, "polynomial") else ct10 - _ct11_arg_m_in = _ct11_arg_data.shape[-1] - _ct11_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_arg_m_in - ) - _ct11_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_arg_r) - ) - _ct11_arg_moduli = getattr(ct10, "moduli", v0.q_towers) - if isinstance(_ct11_arg_moduli, (int, np.integer)): - _ct11_arg_moduli = [int(_ct11_arg_moduli)] - ct11_arg = Polynomial( - { - "batch": _ct11_arg_data.shape[0], - "num_elements": _ct11_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_arg_m, - "precision": 32, - "degree_layout": (_ct11_arg_r, _ct11_arg_c), - }, - {"moduli": list(_ct11_arg_moduli)[:_ct11_arg_m]}, - ) - ct11_arg.polynomial = _ct11_arg_data.reshape( - _ct11_arg_data.shape[0], - _ct11_arg_data.shape[1], - _ct11_arg_r, - _ct11_arg_c, - _ct11_arg_m_in, - )[..., :_ct11_arg_m].copy() - ct11_arg.batch = ct11_arg.polynomial.shape[0] - ct11_arg.num_elements = ct11_arg.polynomial.shape[1] - ct11_arg.num_moduli = _ct11_arg_m - ct11_arg.degree_layout = (_ct11_arg_r, _ct11_arg_c) - ct11_arg.r = _ct11_arg_r - ct11_arg.c = _ct11_arg_c - ct11_arg.moduli = list(_ct11_arg_moduli)[:_ct11_arg_m] - ct11_arg.moduli_array = jnp.array( - ct11_arg.moduli, dtype=getattr(ct11_arg, "modulus_dtype", jnp.uint32) - ) - ct11_raw = v0.he_rot[v0.max_level, 3].rotate(ct11_arg) - _ct11_data = ( - ct11_raw.polynomial if hasattr(ct11_raw, "polynomial") else ct11_raw - ) - _ct11_m_in = _ct11_data.shape[-1] - _ct11_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct11_m_in - ) - _ct11_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct11_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct11_r) - ) - _ct11_moduli = getattr(ct11_raw, "moduli", v0.q_towers) - if isinstance(_ct11_moduli, (int, np.integer)): - _ct11_moduli = [int(_ct11_moduli)] - ct11 = Polynomial( - { - "batch": _ct11_data.shape[0], - "num_elements": _ct11_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct11_m, - "precision": 32, - "degree_layout": (_ct11_r, _ct11_c), - }, - {"moduli": list(_ct11_moduli)[:_ct11_m]}, - ) - ct11.polynomial = _ct11_data.reshape( - _ct11_data.shape[0], _ct11_data.shape[1], _ct11_r, _ct11_c, _ct11_m_in - )[..., :_ct11_m].copy() - ct11.batch = ct11.polynomial.shape[0] - ct11.num_elements = ct11.polynomial.shape[1] - ct11.num_moduli = _ct11_m - ct11.degree_layout = (_ct11_r, _ct11_c) - ct11.r = _ct11_r - ct11.c = _ct11_c - ct11.moduli = list(_ct11_moduli)[:_ct11_m] - ct11.moduli_array = jnp.array( - ct11.moduli, dtype=getattr(ct11, "modulus_dtype", jnp.uint32) - ) - _ct12_arg_data = ct.polynomial if hasattr(ct, "polynomial") else ct - _ct12_arg_m_in = _ct12_arg_data.shape[-1] - _ct12_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_arg_m_in - ) - _ct12_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_arg_r) - ) - _ct12_arg_moduli = getattr(ct, "moduli", v0.q_towers) - if isinstance(_ct12_arg_moduli, (int, np.integer)): - _ct12_arg_moduli = [int(_ct12_arg_moduli)] - ct12_arg = Polynomial( - { - "batch": _ct12_arg_data.shape[0], - "num_elements": _ct12_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_arg_m, - "precision": 32, - "degree_layout": (_ct12_arg_r, _ct12_arg_c), - }, - {"moduli": list(_ct12_arg_moduli)[:_ct12_arg_m]}, - ) - ct12_arg.polynomial = _ct12_arg_data.reshape( - _ct12_arg_data.shape[0], - _ct12_arg_data.shape[1], - _ct12_arg_r, - _ct12_arg_c, - _ct12_arg_m_in, - )[..., :_ct12_arg_m].copy() - ct12_arg.batch = ct12_arg.polynomial.shape[0] - ct12_arg.num_elements = ct12_arg.polynomial.shape[1] - ct12_arg.num_moduli = _ct12_arg_m - ct12_arg.degree_layout = (_ct12_arg_r, _ct12_arg_c) - ct12_arg.r = _ct12_arg_r - ct12_arg.c = _ct12_arg_c - ct12_arg.moduli = list(_ct12_arg_moduli)[:_ct12_arg_m] - ct12_arg.moduli_array = jnp.array( - ct12_arg.moduli, dtype=getattr(ct12_arg, "modulus_dtype", jnp.uint32) - ) - ct12_pt_ntt = ( - pt6.polynomial[0, 0, :, : ct12_arg.polynomial.shape[-1]] - .reshape(ct12_arg.r, ct12_arg.c, ct12_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct12_ptct = v0.ptct_mul[v0.max_level] - ct12_ptct.set_plaintext(ct12_pt_ntt) - ct12_raw = ct12_ptct.mul(ct12_arg, use_bat=False) - _ct12_data = ( - ct12_raw.polynomial if hasattr(ct12_raw, "polynomial") else ct12_raw - ) - _ct12_m_in = _ct12_data.shape[-1] - _ct12_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct12_m_in - ) - _ct12_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct12_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct12_r) - ) - _ct12_moduli = getattr(ct12_raw, "moduli", v0.q_towers) - if isinstance(_ct12_moduli, (int, np.integer)): - _ct12_moduli = [int(_ct12_moduli)] - ct12 = Polynomial( - { - "batch": _ct12_data.shape[0], - "num_elements": _ct12_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct12_m, - "precision": 32, - "degree_layout": (_ct12_r, _ct12_c), - }, - {"moduli": list(_ct12_moduli)[:_ct12_m]}, - ) - ct12.polynomial = _ct12_data.reshape( - _ct12_data.shape[0], _ct12_data.shape[1], _ct12_r, _ct12_c, _ct12_m_in - )[..., :_ct12_m].copy() - ct12.batch = ct12.polynomial.shape[0] - ct12.num_elements = ct12.polynomial.shape[1] - ct12.num_moduli = _ct12_m - ct12.degree_layout = (_ct12_r, _ct12_c) - ct12.r = _ct12_r - ct12.c = _ct12_c - ct12.moduli = list(_ct12_moduli)[:_ct12_m] - ct12.moduli_array = jnp.array( - ct12.moduli, dtype=getattr(ct12, "modulus_dtype", jnp.uint32) - ) - _ct13_arg_data = ct2.polynomial if hasattr(ct2, "polynomial") else ct2 - _ct13_arg_m_in = _ct13_arg_data.shape[-1] - _ct13_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_arg_m_in - ) - _ct13_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_arg_r) - ) - _ct13_arg_moduli = getattr(ct2, "moduli", v0.q_towers) - if isinstance(_ct13_arg_moduli, (int, np.integer)): - _ct13_arg_moduli = [int(_ct13_arg_moduli)] - ct13_arg = Polynomial( - { - "batch": _ct13_arg_data.shape[0], - "num_elements": _ct13_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_arg_m, - "precision": 32, - "degree_layout": (_ct13_arg_r, _ct13_arg_c), - }, - {"moduli": list(_ct13_arg_moduli)[:_ct13_arg_m]}, - ) - ct13_arg.polynomial = _ct13_arg_data.reshape( - _ct13_arg_data.shape[0], - _ct13_arg_data.shape[1], - _ct13_arg_r, - _ct13_arg_c, - _ct13_arg_m_in, - )[..., :_ct13_arg_m].copy() - ct13_arg.batch = ct13_arg.polynomial.shape[0] - ct13_arg.num_elements = ct13_arg.polynomial.shape[1] - ct13_arg.num_moduli = _ct13_arg_m - ct13_arg.degree_layout = (_ct13_arg_r, _ct13_arg_c) - ct13_arg.r = _ct13_arg_r - ct13_arg.c = _ct13_arg_c - ct13_arg.moduli = list(_ct13_arg_moduli)[:_ct13_arg_m] - ct13_arg.moduli_array = jnp.array( - ct13_arg.moduli, dtype=getattr(ct13_arg, "modulus_dtype", jnp.uint32) - ) - ct13_pt_ntt = ( - pt7.polynomial[0, 0, :, : ct13_arg.polynomial.shape[-1]] - .reshape(ct13_arg.r, ct13_arg.c, ct13_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct13_ptct = v0.ptct_mul[v0.max_level] - ct13_ptct.set_plaintext(ct13_pt_ntt) - ct13_raw = ct13_ptct.mul(ct13_arg, use_bat=False) - _ct13_data = ( - ct13_raw.polynomial if hasattr(ct13_raw, "polynomial") else ct13_raw - ) - _ct13_m_in = _ct13_data.shape[-1] - _ct13_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct13_m_in - ) - _ct13_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct13_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct13_r) - ) - _ct13_moduli = getattr(ct13_raw, "moduli", v0.q_towers) - if isinstance(_ct13_moduli, (int, np.integer)): - _ct13_moduli = [int(_ct13_moduli)] - ct13 = Polynomial( - { - "batch": _ct13_data.shape[0], - "num_elements": _ct13_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct13_m, - "precision": 32, - "degree_layout": (_ct13_r, _ct13_c), - }, - {"moduli": list(_ct13_moduli)[:_ct13_m]}, - ) - ct13.polynomial = _ct13_data.reshape( - _ct13_data.shape[0], _ct13_data.shape[1], _ct13_r, _ct13_c, _ct13_m_in - )[..., :_ct13_m].copy() - ct13.batch = ct13.polynomial.shape[0] - ct13.num_elements = ct13.polynomial.shape[1] - ct13.num_moduli = _ct13_m - ct13.degree_layout = (_ct13_r, _ct13_c) - ct13.r = _ct13_r - ct13.c = _ct13_c - ct13.moduli = list(_ct13_moduli)[:_ct13_m] - ct13.moduli_array = jnp.array( - ct13.moduli, dtype=getattr(ct13, "modulus_dtype", jnp.uint32) - ) - _ct14_data = ct12.polynomial if hasattr(ct12, "polynomial") else ct12 - _ct14_m_in = _ct14_data.shape[-1] - _ct14_m = _ct14_m_in - _ct14_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_r) - ) - _ct14_moduli = getattr(ct12, "moduli", v0.q_towers) - if isinstance(_ct14_moduli, (int, np.integer)): - _ct14_moduli = [int(_ct14_moduli)] - ct14 = Polynomial( - { - "batch": _ct14_data.shape[0], - "num_elements": _ct14_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_m, - "precision": 32, - "degree_layout": (_ct14_r, _ct14_c), - }, - {"moduli": list(_ct14_moduli)[:_ct14_m]}, - ) - ct14.polynomial = _ct14_data.reshape( - _ct14_data.shape[0], _ct14_data.shape[1], _ct14_r, _ct14_c, _ct14_m_in - )[..., :_ct14_m].copy() - ct14.batch = ct14.polynomial.shape[0] - ct14.num_elements = ct14.polynomial.shape[1] - ct14.num_moduli = _ct14_m - ct14.degree_layout = (_ct14_r, _ct14_c) - ct14.r = _ct14_r - ct14.c = _ct14_c - ct14.moduli = list(_ct14_moduli)[:_ct14_m] - ct14.moduli_array = jnp.array( - ct14.moduli, dtype=getattr(ct14, "modulus_dtype", jnp.uint32) - ) - _ct14_rhs_data = ct13.polynomial if hasattr(ct13, "polynomial") else ct13 - _ct14_rhs_m_in = _ct14_rhs_data.shape[-1] - _ct14_rhs_m = _ct14_rhs_m_in - _ct14_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct14_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct14_rhs_r) - ) - _ct14_rhs_moduli = getattr(ct13, "moduli", v0.q_towers) - if isinstance(_ct14_rhs_moduli, (int, np.integer)): - _ct14_rhs_moduli = [int(_ct14_rhs_moduli)] - ct14_rhs = Polynomial( - { - "batch": _ct14_rhs_data.shape[0], - "num_elements": _ct14_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct14_rhs_m, - "precision": 32, - "degree_layout": (_ct14_rhs_r, _ct14_rhs_c), - }, - {"moduli": list(_ct14_rhs_moduli)[:_ct14_rhs_m]}, - ) - ct14_rhs.polynomial = _ct14_rhs_data.reshape( - _ct14_rhs_data.shape[0], - _ct14_rhs_data.shape[1], - _ct14_rhs_r, - _ct14_rhs_c, - _ct14_rhs_m_in, - )[..., :_ct14_rhs_m].copy() - ct14_rhs.batch = ct14_rhs.polynomial.shape[0] - ct14_rhs.num_elements = ct14_rhs.polynomial.shape[1] - ct14_rhs.num_moduli = _ct14_rhs_m - ct14_rhs.degree_layout = (_ct14_rhs_r, _ct14_rhs_c) - ct14_rhs.r = _ct14_rhs_r - ct14_rhs.c = _ct14_rhs_c - ct14_rhs.moduli = list(_ct14_rhs_moduli)[:_ct14_rhs_m] - ct14_rhs.moduli_array = jnp.array( - ct14_rhs.moduli, dtype=getattr(ct14_rhs, "modulus_dtype", jnp.uint32) - ) - ct14.add(ct14_rhs) - _moduli = jnp.array(ct14.moduli, dtype=jnp.uint32) - ct14.polynomial = jnp.where( - ct14.polynomial >= _moduli, ct14.polynomial - _moduli, ct14.polynomial - ) - _ct15_arg_data = ct14.polynomial if hasattr(ct14, "polynomial") else ct14 - _ct15_arg_m_in = _ct15_arg_data.shape[-1] - _ct15_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_arg_m_in - ) - _ct15_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_arg_r) - ) - _ct15_arg_moduli = getattr(ct14, "moduli", v0.q_towers) - if isinstance(_ct15_arg_moduli, (int, np.integer)): - _ct15_arg_moduli = [int(_ct15_arg_moduli)] - ct15_arg = Polynomial( - { - "batch": _ct15_arg_data.shape[0], - "num_elements": _ct15_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_arg_m, - "precision": 32, - "degree_layout": (_ct15_arg_r, _ct15_arg_c), - }, - {"moduli": list(_ct15_arg_moduli)[:_ct15_arg_m]}, - ) - ct15_arg.polynomial = _ct15_arg_data.reshape( - _ct15_arg_data.shape[0], - _ct15_arg_data.shape[1], - _ct15_arg_r, - _ct15_arg_c, - _ct15_arg_m_in, - )[..., :_ct15_arg_m].copy() - ct15_arg.batch = ct15_arg.polynomial.shape[0] - ct15_arg.num_elements = ct15_arg.polynomial.shape[1] - ct15_arg.num_moduli = _ct15_arg_m - ct15_arg.degree_layout = (_ct15_arg_r, _ct15_arg_c) - ct15_arg.r = _ct15_arg_r - ct15_arg.c = _ct15_arg_c - ct15_arg.moduli = list(_ct15_arg_moduli)[:_ct15_arg_m] - ct15_arg.moduli_array = jnp.array( - ct15_arg.moduli, dtype=getattr(ct15_arg, "modulus_dtype", jnp.uint32) - ) - ct15_raw = v0.he_rot[v0.max_level, 6].rotate(ct15_arg) - _ct15_data = ( - ct15_raw.polynomial if hasattr(ct15_raw, "polynomial") else ct15_raw - ) - _ct15_m_in = _ct15_data.shape[-1] - _ct15_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct15_m_in - ) - _ct15_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct15_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct15_r) - ) - _ct15_moduli = getattr(ct15_raw, "moduli", v0.q_towers) - if isinstance(_ct15_moduli, (int, np.integer)): - _ct15_moduli = [int(_ct15_moduli)] - ct15 = Polynomial( - { - "batch": _ct15_data.shape[0], - "num_elements": _ct15_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct15_m, - "precision": 32, - "degree_layout": (_ct15_r, _ct15_c), - }, - {"moduli": list(_ct15_moduli)[:_ct15_m]}, - ) - ct15.polynomial = _ct15_data.reshape( - _ct15_data.shape[0], _ct15_data.shape[1], _ct15_r, _ct15_c, _ct15_m_in - )[..., :_ct15_m].copy() - ct15.batch = ct15.polynomial.shape[0] - ct15.num_elements = ct15.polynomial.shape[1] - ct15.num_moduli = _ct15_m - ct15.degree_layout = (_ct15_r, _ct15_c) - ct15.r = _ct15_r - ct15.c = _ct15_c - ct15.moduli = list(_ct15_moduli)[:_ct15_m] - ct15.moduli_array = jnp.array( - ct15.moduli, dtype=getattr(ct15, "modulus_dtype", jnp.uint32) - ) - _ct16_data = ct1.polynomial if hasattr(ct1, "polynomial") else ct1 - _ct16_m_in = _ct16_data.shape[-1] - _ct16_m = _ct16_m_in - _ct16_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_r) - ) - _ct16_moduli = getattr(ct1, "moduli", v0.q_towers) - if isinstance(_ct16_moduli, (int, np.integer)): - _ct16_moduli = [int(_ct16_moduli)] - ct16 = Polynomial( - { - "batch": _ct16_data.shape[0], - "num_elements": _ct16_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_m, - "precision": 32, - "degree_layout": (_ct16_r, _ct16_c), - }, - {"moduli": list(_ct16_moduli)[:_ct16_m]}, - ) - ct16.polynomial = _ct16_data.reshape( - _ct16_data.shape[0], _ct16_data.shape[1], _ct16_r, _ct16_c, _ct16_m_in - )[..., :_ct16_m].copy() - ct16.batch = ct16.polynomial.shape[0] - ct16.num_elements = ct16.polynomial.shape[1] - ct16.num_moduli = _ct16_m - ct16.degree_layout = (_ct16_r, _ct16_c) - ct16.r = _ct16_r - ct16.c = _ct16_c - ct16.moduli = list(_ct16_moduli)[:_ct16_m] - ct16.moduli_array = jnp.array( - ct16.moduli, dtype=getattr(ct16, "modulus_dtype", jnp.uint32) - ) - _ct16_rhs_data = ct3.polynomial if hasattr(ct3, "polynomial") else ct3 - _ct16_rhs_m_in = _ct16_rhs_data.shape[-1] - _ct16_rhs_m = _ct16_rhs_m_in - _ct16_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct16_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct16_rhs_r) - ) - _ct16_rhs_moduli = getattr(ct3, "moduli", v0.q_towers) - if isinstance(_ct16_rhs_moduli, (int, np.integer)): - _ct16_rhs_moduli = [int(_ct16_rhs_moduli)] - ct16_rhs = Polynomial( - { - "batch": _ct16_rhs_data.shape[0], - "num_elements": _ct16_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct16_rhs_m, - "precision": 32, - "degree_layout": (_ct16_rhs_r, _ct16_rhs_c), - }, - {"moduli": list(_ct16_rhs_moduli)[:_ct16_rhs_m]}, - ) - ct16_rhs.polynomial = _ct16_rhs_data.reshape( - _ct16_rhs_data.shape[0], - _ct16_rhs_data.shape[1], - _ct16_rhs_r, - _ct16_rhs_c, - _ct16_rhs_m_in, - )[..., :_ct16_rhs_m].copy() - ct16_rhs.batch = ct16_rhs.polynomial.shape[0] - ct16_rhs.num_elements = ct16_rhs.polynomial.shape[1] - ct16_rhs.num_moduli = _ct16_rhs_m - ct16_rhs.degree_layout = (_ct16_rhs_r, _ct16_rhs_c) - ct16_rhs.r = _ct16_rhs_r - ct16_rhs.c = _ct16_rhs_c - ct16_rhs.moduli = list(_ct16_rhs_moduli)[:_ct16_rhs_m] - ct16_rhs.moduli_array = jnp.array( - ct16_rhs.moduli, dtype=getattr(ct16_rhs, "modulus_dtype", jnp.uint32) - ) - ct16.add(ct16_rhs) - _moduli = jnp.array(ct16.moduli, dtype=jnp.uint32) - ct16.polynomial = jnp.where( - ct16.polynomial >= _moduli, ct16.polynomial - _moduli, ct16.polynomial - ) - _ct17_data = ct5.polynomial if hasattr(ct5, "polynomial") else ct5 - _ct17_m_in = _ct17_data.shape[-1] - _ct17_m = _ct17_m_in - _ct17_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_r) - ) - _ct17_moduli = getattr(ct5, "moduli", v0.q_towers) - if isinstance(_ct17_moduli, (int, np.integer)): - _ct17_moduli = [int(_ct17_moduli)] - ct17 = Polynomial( - { - "batch": _ct17_data.shape[0], - "num_elements": _ct17_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_m, - "precision": 32, - "degree_layout": (_ct17_r, _ct17_c), - }, - {"moduli": list(_ct17_moduli)[:_ct17_m]}, - ) - ct17.polynomial = _ct17_data.reshape( - _ct17_data.shape[0], _ct17_data.shape[1], _ct17_r, _ct17_c, _ct17_m_in - )[..., :_ct17_m].copy() - ct17.batch = ct17.polynomial.shape[0] - ct17.num_elements = ct17.polynomial.shape[1] - ct17.num_moduli = _ct17_m - ct17.degree_layout = (_ct17_r, _ct17_c) - ct17.r = _ct17_r - ct17.c = _ct17_c - ct17.moduli = list(_ct17_moduli)[:_ct17_m] - ct17.moduli_array = jnp.array( - ct17.moduli, dtype=getattr(ct17, "modulus_dtype", jnp.uint32) - ) - _ct17_rhs_data = ct11.polynomial if hasattr(ct11, "polynomial") else ct11 - _ct17_rhs_m_in = _ct17_rhs_data.shape[-1] - _ct17_rhs_m = _ct17_rhs_m_in - _ct17_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct17_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct17_rhs_r) - ) - _ct17_rhs_moduli = getattr(ct11, "moduli", v0.q_towers) - if isinstance(_ct17_rhs_moduli, (int, np.integer)): - _ct17_rhs_moduli = [int(_ct17_rhs_moduli)] - ct17_rhs = Polynomial( - { - "batch": _ct17_rhs_data.shape[0], - "num_elements": _ct17_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct17_rhs_m, - "precision": 32, - "degree_layout": (_ct17_rhs_r, _ct17_rhs_c), - }, - {"moduli": list(_ct17_rhs_moduli)[:_ct17_rhs_m]}, - ) - ct17_rhs.polynomial = _ct17_rhs_data.reshape( - _ct17_rhs_data.shape[0], - _ct17_rhs_data.shape[1], - _ct17_rhs_r, - _ct17_rhs_c, - _ct17_rhs_m_in, - )[..., :_ct17_rhs_m].copy() - ct17_rhs.batch = ct17_rhs.polynomial.shape[0] - ct17_rhs.num_elements = ct17_rhs.polynomial.shape[1] - ct17_rhs.num_moduli = _ct17_rhs_m - ct17_rhs.degree_layout = (_ct17_rhs_r, _ct17_rhs_c) - ct17_rhs.r = _ct17_rhs_r - ct17_rhs.c = _ct17_rhs_c - ct17_rhs.moduli = list(_ct17_rhs_moduli)[:_ct17_rhs_m] - ct17_rhs.moduli_array = jnp.array( - ct17_rhs.moduli, dtype=getattr(ct17_rhs, "modulus_dtype", jnp.uint32) - ) - ct17.add(ct17_rhs) - _moduli = jnp.array(ct17.moduli, dtype=jnp.uint32) - ct17.polynomial = jnp.where( - ct17.polynomial >= _moduli, ct17.polynomial - _moduli, ct17.polynomial - ) - _ct18_data = ct17.polynomial if hasattr(ct17, "polynomial") else ct17 - _ct18_m_in = _ct18_data.shape[-1] - _ct18_m = _ct18_m_in - _ct18_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_r) - ) - _ct18_moduli = getattr(ct17, "moduli", v0.q_towers) - if isinstance(_ct18_moduli, (int, np.integer)): - _ct18_moduli = [int(_ct18_moduli)] - ct18 = Polynomial( - { - "batch": _ct18_data.shape[0], - "num_elements": _ct18_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_m, - "precision": 32, - "degree_layout": (_ct18_r, _ct18_c), - }, - {"moduli": list(_ct18_moduli)[:_ct18_m]}, - ) - ct18.polynomial = _ct18_data.reshape( - _ct18_data.shape[0], _ct18_data.shape[1], _ct18_r, _ct18_c, _ct18_m_in - )[..., :_ct18_m].copy() - ct18.batch = ct18.polynomial.shape[0] - ct18.num_elements = ct18.polynomial.shape[1] - ct18.num_moduli = _ct18_m - ct18.degree_layout = (_ct18_r, _ct18_c) - ct18.r = _ct18_r - ct18.c = _ct18_c - ct18.moduli = list(_ct18_moduli)[:_ct18_m] - ct18.moduli_array = jnp.array( - ct18.moduli, dtype=getattr(ct18, "modulus_dtype", jnp.uint32) - ) - _ct18_rhs_data = ct15.polynomial if hasattr(ct15, "polynomial") else ct15 - _ct18_rhs_m_in = _ct18_rhs_data.shape[-1] - _ct18_rhs_m = _ct18_rhs_m_in - _ct18_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct18_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct18_rhs_r) - ) - _ct18_rhs_moduli = getattr(ct15, "moduli", v0.q_towers) - if isinstance(_ct18_rhs_moduli, (int, np.integer)): - _ct18_rhs_moduli = [int(_ct18_rhs_moduli)] - ct18_rhs = Polynomial( - { - "batch": _ct18_rhs_data.shape[0], - "num_elements": _ct18_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct18_rhs_m, - "precision": 32, - "degree_layout": (_ct18_rhs_r, _ct18_rhs_c), - }, - {"moduli": list(_ct18_rhs_moduli)[:_ct18_rhs_m]}, - ) - ct18_rhs.polynomial = _ct18_rhs_data.reshape( - _ct18_rhs_data.shape[0], - _ct18_rhs_data.shape[1], - _ct18_rhs_r, - _ct18_rhs_c, - _ct18_rhs_m_in, - )[..., :_ct18_rhs_m].copy() - ct18_rhs.batch = ct18_rhs.polynomial.shape[0] - ct18_rhs.num_elements = ct18_rhs.polynomial.shape[1] - ct18_rhs.num_moduli = _ct18_rhs_m - ct18_rhs.degree_layout = (_ct18_rhs_r, _ct18_rhs_c) - ct18_rhs.r = _ct18_rhs_r - ct18_rhs.c = _ct18_rhs_c - ct18_rhs.moduli = list(_ct18_rhs_moduli)[:_ct18_rhs_m] - ct18_rhs.moduli_array = jnp.array( - ct18_rhs.moduli, dtype=getattr(ct18_rhs, "modulus_dtype", jnp.uint32) - ) - ct18.add(ct18_rhs) - _moduli = jnp.array(ct18.moduli, dtype=jnp.uint32) - ct18.polynomial = jnp.where( - ct18.polynomial >= _moduli, ct18.polynomial - _moduli, ct18.polynomial - ) - _ct19_data = ct16.polynomial if hasattr(ct16, "polynomial") else ct16 - _ct19_m_in = _ct19_data.shape[-1] - _ct19_m = _ct19_m_in - _ct19_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_r) - ) - _ct19_moduli = getattr(ct16, "moduli", v0.q_towers) - if isinstance(_ct19_moduli, (int, np.integer)): - _ct19_moduli = [int(_ct19_moduli)] - ct19 = Polynomial( - { - "batch": _ct19_data.shape[0], - "num_elements": _ct19_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_m, - "precision": 32, - "degree_layout": (_ct19_r, _ct19_c), - }, - {"moduli": list(_ct19_moduli)[:_ct19_m]}, - ) - ct19.polynomial = _ct19_data.reshape( - _ct19_data.shape[0], _ct19_data.shape[1], _ct19_r, _ct19_c, _ct19_m_in - )[..., :_ct19_m].copy() - ct19.batch = ct19.polynomial.shape[0] - ct19.num_elements = ct19.polynomial.shape[1] - ct19.num_moduli = _ct19_m - ct19.degree_layout = (_ct19_r, _ct19_c) - ct19.r = _ct19_r - ct19.c = _ct19_c - ct19.moduli = list(_ct19_moduli)[:_ct19_m] - ct19.moduli_array = jnp.array( - ct19.moduli, dtype=getattr(ct19, "modulus_dtype", jnp.uint32) - ) - _ct19_rhs_data = ct18.polynomial if hasattr(ct18, "polynomial") else ct18 - _ct19_rhs_m_in = _ct19_rhs_data.shape[-1] - _ct19_rhs_m = _ct19_rhs_m_in - _ct19_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct19_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct19_rhs_r) - ) - _ct19_rhs_moduli = getattr(ct18, "moduli", v0.q_towers) - if isinstance(_ct19_rhs_moduli, (int, np.integer)): - _ct19_rhs_moduli = [int(_ct19_rhs_moduli)] - ct19_rhs = Polynomial( - { - "batch": _ct19_rhs_data.shape[0], - "num_elements": _ct19_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct19_rhs_m, - "precision": 32, - "degree_layout": (_ct19_rhs_r, _ct19_rhs_c), - }, - {"moduli": list(_ct19_rhs_moduli)[:_ct19_rhs_m]}, - ) - ct19_rhs.polynomial = _ct19_rhs_data.reshape( - _ct19_rhs_data.shape[0], - _ct19_rhs_data.shape[1], - _ct19_rhs_r, - _ct19_rhs_c, - _ct19_rhs_m_in, - )[..., :_ct19_rhs_m].copy() - ct19_rhs.batch = ct19_rhs.polynomial.shape[0] - ct19_rhs.num_elements = ct19_rhs.polynomial.shape[1] - ct19_rhs.num_moduli = _ct19_rhs_m - ct19_rhs.degree_layout = (_ct19_rhs_r, _ct19_rhs_c) - ct19_rhs.r = _ct19_rhs_r - ct19_rhs.c = _ct19_rhs_c - ct19_rhs.moduli = list(_ct19_rhs_moduli)[:_ct19_rhs_m] - ct19_rhs.moduli_array = jnp.array( - ct19_rhs.moduli, dtype=getattr(ct19_rhs, "modulus_dtype", jnp.uint32) - ) - ct19.add(ct19_rhs) - _moduli = jnp.array(ct19.moduli, dtype=jnp.uint32) - ct19.polynomial = jnp.where( - ct19.polynomial >= _moduli, ct19.polynomial - _moduli, ct19.polynomial - ) - _ct20_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 - _ct20_arg_m_in = _ct20_arg_data.shape[-1] - _ct20_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct20_arg_m_in - ) - _ct20_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_arg_r) - ) - _ct20_arg_moduli = getattr(ct19, "moduli", v0.q_towers) - if isinstance(_ct20_arg_moduli, (int, np.integer)): - _ct20_arg_moduli = [int(_ct20_arg_moduli)] - ct20_arg = Polynomial( - { - "batch": _ct20_arg_data.shape[0], - "num_elements": _ct20_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_arg_m, - "precision": 32, - "degree_layout": (_ct20_arg_r, _ct20_arg_c), - }, - {"moduli": list(_ct20_arg_moduli)[:_ct20_arg_m]}, - ) - ct20_arg.polynomial = _ct20_arg_data.reshape( - _ct20_arg_data.shape[0], - _ct20_arg_data.shape[1], - _ct20_arg_r, - _ct20_arg_c, - _ct20_arg_m_in, - )[..., :_ct20_arg_m].copy() - ct20_arg.batch = ct20_arg.polynomial.shape[0] - ct20_arg.num_elements = ct20_arg.polynomial.shape[1] - ct20_arg.num_moduli = _ct20_arg_m - ct20_arg.degree_layout = (_ct20_arg_r, _ct20_arg_c) - ct20_arg.r = _ct20_arg_r - ct20_arg.c = _ct20_arg_c - ct20_arg.moduli = list(_ct20_arg_moduli)[:_ct20_arg_m] - ct20_arg.moduli_array = jnp.array( - ct20_arg.moduli, dtype=getattr(ct20_arg, "modulus_dtype", jnp.uint32) - ) - ct20_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct20_arg) - _ct20_data = ( - ct20_raw.polynomial if hasattr(ct20_raw, "polynomial") else ct20_raw - ) - _ct20_m_in = _ct20_data.shape[-1] - _ct20_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct20_m_in - ) - _ct20_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct20_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct20_r) - ) - _ct20_moduli = getattr(ct20_raw, "moduli", v0.q_towers) - if isinstance(_ct20_moduli, (int, np.integer)): - _ct20_moduli = [int(_ct20_moduli)] - ct20 = Polynomial( - { - "batch": _ct20_data.shape[0], - "num_elements": _ct20_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct20_m, - "precision": 32, - "degree_layout": (_ct20_r, _ct20_c), - }, - {"moduli": list(_ct20_moduli)[:_ct20_m]}, - ) - ct20.polynomial = _ct20_data.reshape( - _ct20_data.shape[0], _ct20_data.shape[1], _ct20_r, _ct20_c, _ct20_m_in - )[..., :_ct20_m].copy() - ct20.batch = ct20.polynomial.shape[0] - ct20.num_elements = ct20.polynomial.shape[1] - ct20.num_moduli = _ct20_m - ct20.degree_layout = (_ct20_r, _ct20_c) - ct20.r = _ct20_r - ct20.c = _ct20_c - ct20.moduli = list(_ct20_moduli)[:_ct20_m] - ct20.moduli_array = jnp.array( - ct20.moduli, dtype=getattr(ct20, "modulus_dtype", jnp.uint32) - ) - _ct21_arg_data = ct20.polynomial if hasattr(ct20, "polynomial") else ct20 - _ct21_arg_m_in = _ct21_arg_data.shape[-1] - _ct21_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct21_arg_m_in - ) - _ct21_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct21_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct21_arg_r) - ) - _ct21_arg_moduli = getattr(ct20, "moduli", v0.q_towers) - if isinstance(_ct21_arg_moduli, (int, np.integer)): - _ct21_arg_moduli = [int(_ct21_arg_moduli)] - ct21_arg = Polynomial( - { - "batch": _ct21_arg_data.shape[0], - "num_elements": _ct21_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct21_arg_m, - "precision": 32, - "degree_layout": (_ct21_arg_r, _ct21_arg_c), - }, - {"moduli": list(_ct21_arg_moduli)[:_ct21_arg_m]}, - ) - ct21_arg.polynomial = _ct21_arg_data.reshape( - _ct21_arg_data.shape[0], - _ct21_arg_data.shape[1], - _ct21_arg_r, - _ct21_arg_c, - _ct21_arg_m_in, - )[..., :_ct21_arg_m].copy() - ct21_arg.batch = ct21_arg.polynomial.shape[0] - ct21_arg.num_elements = ct21_arg.polynomial.shape[1] - ct21_arg.num_moduli = _ct21_arg_m - ct21_arg.degree_layout = (_ct21_arg_r, _ct21_arg_c) - ct21_arg.r = _ct21_arg_r - ct21_arg.c = _ct21_arg_c - ct21_arg.moduli = list(_ct21_arg_moduli)[:_ct21_arg_m] - ct21_arg.moduli_array = jnp.array( - ct21_arg.moduli, dtype=getattr(ct21_arg, "modulus_dtype", jnp.uint32) - ) - ct21_pt_ntt = ( - pt8.polynomial[0, 0, :, : ct21_arg.polynomial.shape[-1]] - .reshape(ct21_arg.r, ct21_arg.c, ct21_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct21_ptct = v0.ptct_mul[v0.max_level - 1] - ct21_ptct.set_plaintext(ct21_pt_ntt) - ct21_raw = ct21_ptct.mul(ct21_arg, use_bat=False) - _ct21_data = ( - ct21_raw.polynomial if hasattr(ct21_raw, "polynomial") else ct21_raw - ) - _ct21_m_in = _ct21_data.shape[-1] - _ct21_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct21_m_in - ) - _ct21_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct21_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct21_r) - ) - _ct21_moduli = getattr(ct21_raw, "moduli", v0.q_towers) - if isinstance(_ct21_moduli, (int, np.integer)): - _ct21_moduli = [int(_ct21_moduli)] - ct21 = Polynomial( - { - "batch": _ct21_data.shape[0], - "num_elements": _ct21_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct21_m, - "precision": 32, - "degree_layout": (_ct21_r, _ct21_c), - }, - {"moduli": list(_ct21_moduli)[:_ct21_m]}, - ) - ct21.polynomial = _ct21_data.reshape( - _ct21_data.shape[0], _ct21_data.shape[1], _ct21_r, _ct21_c, _ct21_m_in - )[..., :_ct21_m].copy() - ct21.batch = ct21.polynomial.shape[0] - ct21.num_elements = ct21.polynomial.shape[1] - ct21.num_moduli = _ct21_m - ct21.degree_layout = (_ct21_r, _ct21_c) - ct21.r = _ct21_r - ct21.c = _ct21_c - ct21.moduli = list(_ct21_moduli)[:_ct21_m] - ct21.moduli_array = jnp.array( - ct21.moduli, dtype=getattr(ct21, "modulus_dtype", jnp.uint32) - ) - _ct22_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 - _ct22_arg_m_in = _ct22_arg_data.shape[-1] - _ct22_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct22_arg_m_in - ) - _ct22_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct22_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct22_arg_r) - ) - _ct22_arg_moduli = getattr(ct19, "moduli", v0.q_towers) - if isinstance(_ct22_arg_moduli, (int, np.integer)): - _ct22_arg_moduli = [int(_ct22_arg_moduli)] - ct22_arg = Polynomial( - { - "batch": _ct22_arg_data.shape[0], - "num_elements": _ct22_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct22_arg_m, - "precision": 32, - "degree_layout": (_ct22_arg_r, _ct22_arg_c), - }, - {"moduli": list(_ct22_arg_moduli)[:_ct22_arg_m]}, - ) - ct22_arg.polynomial = _ct22_arg_data.reshape( - _ct22_arg_data.shape[0], - _ct22_arg_data.shape[1], - _ct22_arg_r, - _ct22_arg_c, - _ct22_arg_m_in, - )[..., :_ct22_arg_m].copy() - ct22_arg.batch = ct22_arg.polynomial.shape[0] - ct22_arg.num_elements = ct22_arg.polynomial.shape[1] - ct22_arg.num_moduli = _ct22_arg_m - ct22_arg.degree_layout = (_ct22_arg_r, _ct22_arg_c) - ct22_arg.r = _ct22_arg_r - ct22_arg.c = _ct22_arg_c - ct22_arg.moduli = list(_ct22_arg_moduli)[:_ct22_arg_m] - ct22_arg.moduli_array = jnp.array( - ct22_arg.moduli, dtype=getattr(ct22_arg, "modulus_dtype", jnp.uint32) - ) - ct22_raw = v0.he_rot[v0.max_level, 1].rotate(ct22_arg) - _ct22_data = ( - ct22_raw.polynomial if hasattr(ct22_raw, "polynomial") else ct22_raw - ) - _ct22_m_in = _ct22_data.shape[-1] - _ct22_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct22_m_in - ) - _ct22_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct22_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct22_r) - ) - _ct22_moduli = getattr(ct22_raw, "moduli", v0.q_towers) - if isinstance(_ct22_moduli, (int, np.integer)): - _ct22_moduli = [int(_ct22_moduli)] - ct22 = Polynomial( - { - "batch": _ct22_data.shape[0], - "num_elements": _ct22_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct22_m, - "precision": 32, - "degree_layout": (_ct22_r, _ct22_c), - }, - {"moduli": list(_ct22_moduli)[:_ct22_m]}, - ) - ct22.polynomial = _ct22_data.reshape( - _ct22_data.shape[0], _ct22_data.shape[1], _ct22_r, _ct22_c, _ct22_m_in - )[..., :_ct22_m].copy() - ct22.batch = ct22.polynomial.shape[0] - ct22.num_elements = ct22.polynomial.shape[1] - ct22.num_moduli = _ct22_m - ct22.degree_layout = (_ct22_r, _ct22_c) - ct22.r = _ct22_r - ct22.c = _ct22_c - ct22.moduli = list(_ct22_moduli)[:_ct22_m] - ct22.moduli_array = jnp.array( - ct22.moduli, dtype=getattr(ct22, "modulus_dtype", jnp.uint32) - ) - _ct23_arg_data = ct22.polynomial if hasattr(ct22, "polynomial") else ct22 - _ct23_arg_m_in = _ct23_arg_data.shape[-1] - _ct23_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct23_arg_m_in - ) - _ct23_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct23_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct23_arg_r) - ) - _ct23_arg_moduli = getattr(ct22, "moduli", v0.q_towers) - if isinstance(_ct23_arg_moduli, (int, np.integer)): - _ct23_arg_moduli = [int(_ct23_arg_moduli)] - ct23_arg = Polynomial( - { - "batch": _ct23_arg_data.shape[0], - "num_elements": _ct23_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct23_arg_m, - "precision": 32, - "degree_layout": (_ct23_arg_r, _ct23_arg_c), - }, - {"moduli": list(_ct23_arg_moduli)[:_ct23_arg_m]}, - ) - ct23_arg.polynomial = _ct23_arg_data.reshape( - _ct23_arg_data.shape[0], - _ct23_arg_data.shape[1], - _ct23_arg_r, - _ct23_arg_c, - _ct23_arg_m_in, - )[..., :_ct23_arg_m].copy() - ct23_arg.batch = ct23_arg.polynomial.shape[0] - ct23_arg.num_elements = ct23_arg.polynomial.shape[1] - ct23_arg.num_moduli = _ct23_arg_m - ct23_arg.degree_layout = (_ct23_arg_r, _ct23_arg_c) - ct23_arg.r = _ct23_arg_r - ct23_arg.c = _ct23_arg_c - ct23_arg.moduli = list(_ct23_arg_moduli)[:_ct23_arg_m] - ct23_arg.moduli_array = jnp.array( - ct23_arg.moduli, dtype=getattr(ct23_arg, "modulus_dtype", jnp.uint32) - ) - ct23_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct23_arg) - _ct23_data = ( - ct23_raw.polynomial if hasattr(ct23_raw, "polynomial") else ct23_raw - ) - _ct23_m_in = _ct23_data.shape[-1] - _ct23_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct23_m_in - ) - _ct23_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct23_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct23_r) - ) - _ct23_moduli = getattr(ct23_raw, "moduli", v0.q_towers) - if isinstance(_ct23_moduli, (int, np.integer)): - _ct23_moduli = [int(_ct23_moduli)] - ct23 = Polynomial( - { - "batch": _ct23_data.shape[0], - "num_elements": _ct23_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct23_m, - "precision": 32, - "degree_layout": (_ct23_r, _ct23_c), - }, - {"moduli": list(_ct23_moduli)[:_ct23_m]}, - ) - ct23.polynomial = _ct23_data.reshape( - _ct23_data.shape[0], _ct23_data.shape[1], _ct23_r, _ct23_c, _ct23_m_in - )[..., :_ct23_m].copy() - ct23.batch = ct23.polynomial.shape[0] - ct23.num_elements = ct23.polynomial.shape[1] - ct23.num_moduli = _ct23_m - ct23.degree_layout = (_ct23_r, _ct23_c) - ct23.r = _ct23_r - ct23.c = _ct23_c - ct23.moduli = list(_ct23_moduli)[:_ct23_m] - ct23.moduli_array = jnp.array( - ct23.moduli, dtype=getattr(ct23, "modulus_dtype", jnp.uint32) - ) - _ct24_arg_data = ct23.polynomial if hasattr(ct23, "polynomial") else ct23 - _ct24_arg_m_in = _ct24_arg_data.shape[-1] - _ct24_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct24_arg_m_in - ) - _ct24_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct24_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct24_arg_r) - ) - _ct24_arg_moduli = getattr(ct23, "moduli", v0.q_towers) - if isinstance(_ct24_arg_moduli, (int, np.integer)): - _ct24_arg_moduli = [int(_ct24_arg_moduli)] - ct24_arg = Polynomial( - { - "batch": _ct24_arg_data.shape[0], - "num_elements": _ct24_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct24_arg_m, - "precision": 32, - "degree_layout": (_ct24_arg_r, _ct24_arg_c), - }, - {"moduli": list(_ct24_arg_moduli)[:_ct24_arg_m]}, - ) - ct24_arg.polynomial = _ct24_arg_data.reshape( - _ct24_arg_data.shape[0], - _ct24_arg_data.shape[1], - _ct24_arg_r, - _ct24_arg_c, - _ct24_arg_m_in, - )[..., :_ct24_arg_m].copy() - ct24_arg.batch = ct24_arg.polynomial.shape[0] - ct24_arg.num_elements = ct24_arg.polynomial.shape[1] - ct24_arg.num_moduli = _ct24_arg_m - ct24_arg.degree_layout = (_ct24_arg_r, _ct24_arg_c) - ct24_arg.r = _ct24_arg_r - ct24_arg.c = _ct24_arg_c - ct24_arg.moduli = list(_ct24_arg_moduli)[:_ct24_arg_m] - ct24_arg.moduli_array = jnp.array( - ct24_arg.moduli, dtype=getattr(ct24_arg, "modulus_dtype", jnp.uint32) - ) - ct24_pt_ntt = ( - pt9.polynomial[0, 0, :, : ct24_arg.polynomial.shape[-1]] - .reshape(ct24_arg.r, ct24_arg.c, ct24_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct24_ptct = v0.ptct_mul[v0.max_level - 1] - ct24_ptct.set_plaintext(ct24_pt_ntt) - ct24_raw = ct24_ptct.mul(ct24_arg, use_bat=False) - _ct24_data = ( - ct24_raw.polynomial if hasattr(ct24_raw, "polynomial") else ct24_raw - ) - _ct24_m_in = _ct24_data.shape[-1] - _ct24_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct24_m_in - ) - _ct24_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct24_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct24_r) - ) - _ct24_moduli = getattr(ct24_raw, "moduli", v0.q_towers) - if isinstance(_ct24_moduli, (int, np.integer)): - _ct24_moduli = [int(_ct24_moduli)] - ct24 = Polynomial( - { - "batch": _ct24_data.shape[0], - "num_elements": _ct24_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct24_m, - "precision": 32, - "degree_layout": (_ct24_r, _ct24_c), - }, - {"moduli": list(_ct24_moduli)[:_ct24_m]}, - ) - ct24.polynomial = _ct24_data.reshape( - _ct24_data.shape[0], _ct24_data.shape[1], _ct24_r, _ct24_c, _ct24_m_in - )[..., :_ct24_m].copy() - ct24.batch = ct24.polynomial.shape[0] - ct24.num_elements = ct24.polynomial.shape[1] - ct24.num_moduli = _ct24_m - ct24.degree_layout = (_ct24_r, _ct24_c) - ct24.r = _ct24_r - ct24.c = _ct24_c - ct24.moduli = list(_ct24_moduli)[:_ct24_m] - ct24.moduli_array = jnp.array( - ct24.moduli, dtype=getattr(ct24, "modulus_dtype", jnp.uint32) - ) - _ct25_arg_data = ct19.polynomial if hasattr(ct19, "polynomial") else ct19 - _ct25_arg_m_in = _ct25_arg_data.shape[-1] - _ct25_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct25_arg_m_in - ) - _ct25_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct25_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct25_arg_r) - ) - _ct25_arg_moduli = getattr(ct19, "moduli", v0.q_towers) - if isinstance(_ct25_arg_moduli, (int, np.integer)): - _ct25_arg_moduli = [int(_ct25_arg_moduli)] - ct25_arg = Polynomial( - { - "batch": _ct25_arg_data.shape[0], - "num_elements": _ct25_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct25_arg_m, - "precision": 32, - "degree_layout": (_ct25_arg_r, _ct25_arg_c), - }, - {"moduli": list(_ct25_arg_moduli)[:_ct25_arg_m]}, - ) - ct25_arg.polynomial = _ct25_arg_data.reshape( - _ct25_arg_data.shape[0], - _ct25_arg_data.shape[1], - _ct25_arg_r, - _ct25_arg_c, - _ct25_arg_m_in, - )[..., :_ct25_arg_m].copy() - ct25_arg.batch = ct25_arg.polynomial.shape[0] - ct25_arg.num_elements = ct25_arg.polynomial.shape[1] - ct25_arg.num_moduli = _ct25_arg_m - ct25_arg.degree_layout = (_ct25_arg_r, _ct25_arg_c) - ct25_arg.r = _ct25_arg_r - ct25_arg.c = _ct25_arg_c - ct25_arg.moduli = list(_ct25_arg_moduli)[:_ct25_arg_m] - ct25_arg.moduli_array = jnp.array( - ct25_arg.moduli, dtype=getattr(ct25_arg, "modulus_dtype", jnp.uint32) - ) - ct25_raw = v0.he_rot[v0.max_level, 2].rotate(ct25_arg) - _ct25_data = ( - ct25_raw.polynomial if hasattr(ct25_raw, "polynomial") else ct25_raw - ) - _ct25_m_in = _ct25_data.shape[-1] - _ct25_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct25_m_in - ) - _ct25_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct25_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct25_r) - ) - _ct25_moduli = getattr(ct25_raw, "moduli", v0.q_towers) - if isinstance(_ct25_moduli, (int, np.integer)): - _ct25_moduli = [int(_ct25_moduli)] - ct25 = Polynomial( - { - "batch": _ct25_data.shape[0], - "num_elements": _ct25_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct25_m, - "precision": 32, - "degree_layout": (_ct25_r, _ct25_c), - }, - {"moduli": list(_ct25_moduli)[:_ct25_m]}, - ) - ct25.polynomial = _ct25_data.reshape( - _ct25_data.shape[0], _ct25_data.shape[1], _ct25_r, _ct25_c, _ct25_m_in - )[..., :_ct25_m].copy() - ct25.batch = ct25.polynomial.shape[0] - ct25.num_elements = ct25.polynomial.shape[1] - ct25.num_moduli = _ct25_m - ct25.degree_layout = (_ct25_r, _ct25_c) - ct25.r = _ct25_r - ct25.c = _ct25_c - ct25.moduli = list(_ct25_moduli)[:_ct25_m] - ct25.moduli_array = jnp.array( - ct25.moduli, dtype=getattr(ct25, "modulus_dtype", jnp.uint32) - ) - _ct26_arg_data = ct25.polynomial if hasattr(ct25, "polynomial") else ct25 - _ct26_arg_m_in = _ct26_arg_data.shape[-1] - _ct26_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level) - if hasattr(v0, "_param_cache") - else _ct26_arg_m_in - ) - _ct26_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct26_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct26_arg_r) - ) - _ct26_arg_moduli = getattr(ct25, "moduli", v0.q_towers) - if isinstance(_ct26_arg_moduli, (int, np.integer)): - _ct26_arg_moduli = [int(_ct26_arg_moduli)] - ct26_arg = Polynomial( - { - "batch": _ct26_arg_data.shape[0], - "num_elements": _ct26_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct26_arg_m, - "precision": 32, - "degree_layout": (_ct26_arg_r, _ct26_arg_c), - }, - {"moduli": list(_ct26_arg_moduli)[:_ct26_arg_m]}, - ) - ct26_arg.polynomial = _ct26_arg_data.reshape( - _ct26_arg_data.shape[0], - _ct26_arg_data.shape[1], - _ct26_arg_r, - _ct26_arg_c, - _ct26_arg_m_in, - )[..., :_ct26_arg_m].copy() - ct26_arg.batch = ct26_arg.polynomial.shape[0] - ct26_arg.num_elements = ct26_arg.polynomial.shape[1] - ct26_arg.num_moduli = _ct26_arg_m - ct26_arg.degree_layout = (_ct26_arg_r, _ct26_arg_c) - ct26_arg.r = _ct26_arg_r - ct26_arg.c = _ct26_arg_c - ct26_arg.moduli = list(_ct26_arg_moduli)[:_ct26_arg_m] - ct26_arg.moduli_array = jnp.array( - ct26_arg.moduli, dtype=getattr(ct26_arg, "modulus_dtype", jnp.uint32) - ) - ct26_raw = v0.he_rescale[v0.max_level, v0.max_level - 1](ct26_arg) - _ct26_data = ( - ct26_raw.polynomial if hasattr(ct26_raw, "polynomial") else ct26_raw - ) - _ct26_m_in = _ct26_data.shape[-1] - _ct26_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct26_m_in - ) - _ct26_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct26_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct26_r) - ) - _ct26_moduli = getattr(ct26_raw, "moduli", v0.q_towers) - if isinstance(_ct26_moduli, (int, np.integer)): - _ct26_moduli = [int(_ct26_moduli)] - ct26 = Polynomial( - { - "batch": _ct26_data.shape[0], - "num_elements": _ct26_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct26_m, - "precision": 32, - "degree_layout": (_ct26_r, _ct26_c), - }, - {"moduli": list(_ct26_moduli)[:_ct26_m]}, - ) - ct26.polynomial = _ct26_data.reshape( - _ct26_data.shape[0], _ct26_data.shape[1], _ct26_r, _ct26_c, _ct26_m_in - )[..., :_ct26_m].copy() - ct26.batch = ct26.polynomial.shape[0] - ct26.num_elements = ct26.polynomial.shape[1] - ct26.num_moduli = _ct26_m - ct26.degree_layout = (_ct26_r, _ct26_c) - ct26.r = _ct26_r - ct26.c = _ct26_c - ct26.moduli = list(_ct26_moduli)[:_ct26_m] - ct26.moduli_array = jnp.array( - ct26.moduli, dtype=getattr(ct26, "modulus_dtype", jnp.uint32) - ) - _ct27_arg_data = ct26.polynomial if hasattr(ct26, "polynomial") else ct26 - _ct27_arg_m_in = _ct27_arg_data.shape[-1] - _ct27_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct27_arg_m_in - ) - _ct27_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct27_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct27_arg_r) - ) - _ct27_arg_moduli = getattr(ct26, "moduli", v0.q_towers) - if isinstance(_ct27_arg_moduli, (int, np.integer)): - _ct27_arg_moduli = [int(_ct27_arg_moduli)] - ct27_arg = Polynomial( - { - "batch": _ct27_arg_data.shape[0], - "num_elements": _ct27_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct27_arg_m, - "precision": 32, - "degree_layout": (_ct27_arg_r, _ct27_arg_c), - }, - {"moduli": list(_ct27_arg_moduli)[:_ct27_arg_m]}, - ) - ct27_arg.polynomial = _ct27_arg_data.reshape( - _ct27_arg_data.shape[0], - _ct27_arg_data.shape[1], - _ct27_arg_r, - _ct27_arg_c, - _ct27_arg_m_in, - )[..., :_ct27_arg_m].copy() - ct27_arg.batch = ct27_arg.polynomial.shape[0] - ct27_arg.num_elements = ct27_arg.polynomial.shape[1] - ct27_arg.num_moduli = _ct27_arg_m - ct27_arg.degree_layout = (_ct27_arg_r, _ct27_arg_c) - ct27_arg.r = _ct27_arg_r - ct27_arg.c = _ct27_arg_c - ct27_arg.moduli = list(_ct27_arg_moduli)[:_ct27_arg_m] - ct27_arg.moduli_array = jnp.array( - ct27_arg.moduli, dtype=getattr(ct27_arg, "modulus_dtype", jnp.uint32) - ) - ct27_pt_ntt = ( - pt10.polynomial[0, 0, :, : ct27_arg.polynomial.shape[-1]] - .reshape(ct27_arg.r, ct27_arg.c, ct27_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct27_ptct = v0.ptct_mul[v0.max_level - 1] - ct27_ptct.set_plaintext(ct27_pt_ntt) - ct27_raw = ct27_ptct.mul(ct27_arg, use_bat=False) - _ct27_data = ( - ct27_raw.polynomial if hasattr(ct27_raw, "polynomial") else ct27_raw - ) - _ct27_m_in = _ct27_data.shape[-1] - _ct27_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct27_m_in - ) - _ct27_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct27_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct27_r) - ) - _ct27_moduli = getattr(ct27_raw, "moduli", v0.q_towers) - if isinstance(_ct27_moduli, (int, np.integer)): - _ct27_moduli = [int(_ct27_moduli)] - ct27 = Polynomial( - { - "batch": _ct27_data.shape[0], - "num_elements": _ct27_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct27_m, - "precision": 32, - "degree_layout": (_ct27_r, _ct27_c), - }, - {"moduli": list(_ct27_moduli)[:_ct27_m]}, - ) - ct27.polynomial = _ct27_data.reshape( - _ct27_data.shape[0], _ct27_data.shape[1], _ct27_r, _ct27_c, _ct27_m_in - )[..., :_ct27_m].copy() - ct27.batch = ct27.polynomial.shape[0] - ct27.num_elements = ct27.polynomial.shape[1] - ct27.num_moduli = _ct27_m - ct27.degree_layout = (_ct27_r, _ct27_c) - ct27.r = _ct27_r - ct27.c = _ct27_c - ct27.moduli = list(_ct27_moduli)[:_ct27_m] - ct27.moduli_array = jnp.array( - ct27.moduli, dtype=getattr(ct27, "modulus_dtype", jnp.uint32) - ) - _ct28_arg_data = ct20.polynomial if hasattr(ct20, "polynomial") else ct20 - _ct28_arg_m_in = _ct28_arg_data.shape[-1] - _ct28_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct28_arg_m_in - ) - _ct28_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct28_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct28_arg_r) - ) - _ct28_arg_moduli = getattr(ct20, "moduli", v0.q_towers) - if isinstance(_ct28_arg_moduli, (int, np.integer)): - _ct28_arg_moduli = [int(_ct28_arg_moduli)] - ct28_arg = Polynomial( - { - "batch": _ct28_arg_data.shape[0], - "num_elements": _ct28_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct28_arg_m, - "precision": 32, - "degree_layout": (_ct28_arg_r, _ct28_arg_c), - }, - {"moduli": list(_ct28_arg_moduli)[:_ct28_arg_m]}, - ) - ct28_arg.polynomial = _ct28_arg_data.reshape( - _ct28_arg_data.shape[0], - _ct28_arg_data.shape[1], - _ct28_arg_r, - _ct28_arg_c, - _ct28_arg_m_in, - )[..., :_ct28_arg_m].copy() - ct28_arg.batch = ct28_arg.polynomial.shape[0] - ct28_arg.num_elements = ct28_arg.polynomial.shape[1] - ct28_arg.num_moduli = _ct28_arg_m - ct28_arg.degree_layout = (_ct28_arg_r, _ct28_arg_c) - ct28_arg.r = _ct28_arg_r - ct28_arg.c = _ct28_arg_c - ct28_arg.moduli = list(_ct28_arg_moduli)[:_ct28_arg_m] - ct28_arg.moduli_array = jnp.array( - ct28_arg.moduli, dtype=getattr(ct28_arg, "modulus_dtype", jnp.uint32) - ) - ct28_pt_ntt = ( - pt11.polynomial[0, 0, :, : ct28_arg.polynomial.shape[-1]] - .reshape(ct28_arg.r, ct28_arg.c, ct28_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct28_ptct = v0.ptct_mul[v0.max_level - 1] - ct28_ptct.set_plaintext(ct28_pt_ntt) - ct28_raw = ct28_ptct.mul(ct28_arg, use_bat=False) - _ct28_data = ( - ct28_raw.polynomial if hasattr(ct28_raw, "polynomial") else ct28_raw - ) - _ct28_m_in = _ct28_data.shape[-1] - _ct28_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct28_m_in - ) - _ct28_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct28_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct28_r) - ) - _ct28_moduli = getattr(ct28_raw, "moduli", v0.q_towers) - if isinstance(_ct28_moduli, (int, np.integer)): - _ct28_moduli = [int(_ct28_moduli)] - ct28 = Polynomial( - { - "batch": _ct28_data.shape[0], - "num_elements": _ct28_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct28_m, - "precision": 32, - "degree_layout": (_ct28_r, _ct28_c), - }, - {"moduli": list(_ct28_moduli)[:_ct28_m]}, - ) - ct28.polynomial = _ct28_data.reshape( - _ct28_data.shape[0], _ct28_data.shape[1], _ct28_r, _ct28_c, _ct28_m_in - )[..., :_ct28_m].copy() - ct28.batch = ct28.polynomial.shape[0] - ct28.num_elements = ct28.polynomial.shape[1] - ct28.num_moduli = _ct28_m - ct28.degree_layout = (_ct28_r, _ct28_c) - ct28.r = _ct28_r - ct28.c = _ct28_c - ct28.moduli = list(_ct28_moduli)[:_ct28_m] - ct28.moduli_array = jnp.array( - ct28.moduli, dtype=getattr(ct28, "modulus_dtype", jnp.uint32) - ) - _ct29_arg_data = ct23.polynomial if hasattr(ct23, "polynomial") else ct23 - _ct29_arg_m_in = _ct29_arg_data.shape[-1] - _ct29_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct29_arg_m_in - ) - _ct29_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct29_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct29_arg_r) - ) - _ct29_arg_moduli = getattr(ct23, "moduli", v0.q_towers) - if isinstance(_ct29_arg_moduli, (int, np.integer)): - _ct29_arg_moduli = [int(_ct29_arg_moduli)] - ct29_arg = Polynomial( - { - "batch": _ct29_arg_data.shape[0], - "num_elements": _ct29_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct29_arg_m, - "precision": 32, - "degree_layout": (_ct29_arg_r, _ct29_arg_c), - }, - {"moduli": list(_ct29_arg_moduli)[:_ct29_arg_m]}, - ) - ct29_arg.polynomial = _ct29_arg_data.reshape( - _ct29_arg_data.shape[0], - _ct29_arg_data.shape[1], - _ct29_arg_r, - _ct29_arg_c, - _ct29_arg_m_in, - )[..., :_ct29_arg_m].copy() - ct29_arg.batch = ct29_arg.polynomial.shape[0] - ct29_arg.num_elements = ct29_arg.polynomial.shape[1] - ct29_arg.num_moduli = _ct29_arg_m - ct29_arg.degree_layout = (_ct29_arg_r, _ct29_arg_c) - ct29_arg.r = _ct29_arg_r - ct29_arg.c = _ct29_arg_c - ct29_arg.moduli = list(_ct29_arg_moduli)[:_ct29_arg_m] - ct29_arg.moduli_array = jnp.array( - ct29_arg.moduli, dtype=getattr(ct29_arg, "modulus_dtype", jnp.uint32) - ) - ct29_pt_ntt = ( - pt12.polynomial[0, 0, :, : ct29_arg.polynomial.shape[-1]] - .reshape(ct29_arg.r, ct29_arg.c, ct29_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct29_ptct = v0.ptct_mul[v0.max_level - 1] - ct29_ptct.set_plaintext(ct29_pt_ntt) - ct29_raw = ct29_ptct.mul(ct29_arg, use_bat=False) - _ct29_data = ( - ct29_raw.polynomial if hasattr(ct29_raw, "polynomial") else ct29_raw - ) - _ct29_m_in = _ct29_data.shape[-1] - _ct29_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct29_m_in - ) - _ct29_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct29_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct29_r) - ) - _ct29_moduli = getattr(ct29_raw, "moduli", v0.q_towers) - if isinstance(_ct29_moduli, (int, np.integer)): - _ct29_moduli = [int(_ct29_moduli)] - ct29 = Polynomial( - { - "batch": _ct29_data.shape[0], - "num_elements": _ct29_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct29_m, - "precision": 32, - "degree_layout": (_ct29_r, _ct29_c), - }, - {"moduli": list(_ct29_moduli)[:_ct29_m]}, - ) - ct29.polynomial = _ct29_data.reshape( - _ct29_data.shape[0], _ct29_data.shape[1], _ct29_r, _ct29_c, _ct29_m_in - )[..., :_ct29_m].copy() - ct29.batch = ct29.polynomial.shape[0] - ct29.num_elements = ct29.polynomial.shape[1] - ct29.num_moduli = _ct29_m - ct29.degree_layout = (_ct29_r, _ct29_c) - ct29.r = _ct29_r - ct29.c = _ct29_c - ct29.moduli = list(_ct29_moduli)[:_ct29_m] - ct29.moduli_array = jnp.array( - ct29.moduli, dtype=getattr(ct29, "modulus_dtype", jnp.uint32) - ) - _ct30_arg_data = ct26.polynomial if hasattr(ct26, "polynomial") else ct26 - _ct30_arg_m_in = _ct30_arg_data.shape[-1] - _ct30_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct30_arg_m_in - ) - _ct30_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct30_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct30_arg_r) - ) - _ct30_arg_moduli = getattr(ct26, "moduli", v0.q_towers) - if isinstance(_ct30_arg_moduli, (int, np.integer)): - _ct30_arg_moduli = [int(_ct30_arg_moduli)] - ct30_arg = Polynomial( - { - "batch": _ct30_arg_data.shape[0], - "num_elements": _ct30_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct30_arg_m, - "precision": 32, - "degree_layout": (_ct30_arg_r, _ct30_arg_c), - }, - {"moduli": list(_ct30_arg_moduli)[:_ct30_arg_m]}, - ) - ct30_arg.polynomial = _ct30_arg_data.reshape( - _ct30_arg_data.shape[0], - _ct30_arg_data.shape[1], - _ct30_arg_r, - _ct30_arg_c, - _ct30_arg_m_in, - )[..., :_ct30_arg_m].copy() - ct30_arg.batch = ct30_arg.polynomial.shape[0] - ct30_arg.num_elements = ct30_arg.polynomial.shape[1] - ct30_arg.num_moduli = _ct30_arg_m - ct30_arg.degree_layout = (_ct30_arg_r, _ct30_arg_c) - ct30_arg.r = _ct30_arg_r - ct30_arg.c = _ct30_arg_c - ct30_arg.moduli = list(_ct30_arg_moduli)[:_ct30_arg_m] - ct30_arg.moduli_array = jnp.array( - ct30_arg.moduli, dtype=getattr(ct30_arg, "modulus_dtype", jnp.uint32) - ) - ct30_pt_ntt = ( - pt13.polynomial[0, 0, :, : ct30_arg.polynomial.shape[-1]] - .reshape(ct30_arg.r, ct30_arg.c, ct30_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct30_ptct = v0.ptct_mul[v0.max_level - 1] - ct30_ptct.set_plaintext(ct30_pt_ntt) - ct30_raw = ct30_ptct.mul(ct30_arg, use_bat=False) - _ct30_data = ( - ct30_raw.polynomial if hasattr(ct30_raw, "polynomial") else ct30_raw - ) - _ct30_m_in = _ct30_data.shape[-1] - _ct30_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct30_m_in - ) - _ct30_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct30_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct30_r) - ) - _ct30_moduli = getattr(ct30_raw, "moduli", v0.q_towers) - if isinstance(_ct30_moduli, (int, np.integer)): - _ct30_moduli = [int(_ct30_moduli)] - ct30 = Polynomial( - { - "batch": _ct30_data.shape[0], - "num_elements": _ct30_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct30_m, - "precision": 32, - "degree_layout": (_ct30_r, _ct30_c), - }, - {"moduli": list(_ct30_moduli)[:_ct30_m]}, - ) - ct30.polynomial = _ct30_data.reshape( - _ct30_data.shape[0], _ct30_data.shape[1], _ct30_r, _ct30_c, _ct30_m_in - )[..., :_ct30_m].copy() - ct30.batch = ct30.polynomial.shape[0] - ct30.num_elements = ct30.polynomial.shape[1] - ct30.num_moduli = _ct30_m - ct30.degree_layout = (_ct30_r, _ct30_c) - ct30.r = _ct30_r - ct30.c = _ct30_c - ct30.moduli = list(_ct30_moduli)[:_ct30_m] - ct30.moduli_array = jnp.array( - ct30.moduli, dtype=getattr(ct30, "modulus_dtype", jnp.uint32) - ) - _ct31_data = ct28.polynomial if hasattr(ct28, "polynomial") else ct28 - _ct31_m_in = _ct31_data.shape[-1] - _ct31_m = _ct31_m_in - _ct31_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct31_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct31_r) - ) - _ct31_moduli = getattr(ct28, "moduli", v0.q_towers) - if isinstance(_ct31_moduli, (int, np.integer)): - _ct31_moduli = [int(_ct31_moduli)] - ct31 = Polynomial( - { - "batch": _ct31_data.shape[0], - "num_elements": _ct31_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct31_m, - "precision": 32, - "degree_layout": (_ct31_r, _ct31_c), - }, - {"moduli": list(_ct31_moduli)[:_ct31_m]}, - ) - ct31.polynomial = _ct31_data.reshape( - _ct31_data.shape[0], _ct31_data.shape[1], _ct31_r, _ct31_c, _ct31_m_in - )[..., :_ct31_m].copy() - ct31.batch = ct31.polynomial.shape[0] - ct31.num_elements = ct31.polynomial.shape[1] - ct31.num_moduli = _ct31_m - ct31.degree_layout = (_ct31_r, _ct31_c) - ct31.r = _ct31_r - ct31.c = _ct31_c - ct31.moduli = list(_ct31_moduli)[:_ct31_m] - ct31.moduli_array = jnp.array( - ct31.moduli, dtype=getattr(ct31, "modulus_dtype", jnp.uint32) - ) - _ct31_rhs_data = ct29.polynomial if hasattr(ct29, "polynomial") else ct29 - _ct31_rhs_m_in = _ct31_rhs_data.shape[-1] - _ct31_rhs_m = _ct31_rhs_m_in - _ct31_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct31_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct31_rhs_r) - ) - _ct31_rhs_moduli = getattr(ct29, "moduli", v0.q_towers) - if isinstance(_ct31_rhs_moduli, (int, np.integer)): - _ct31_rhs_moduli = [int(_ct31_rhs_moduli)] - ct31_rhs = Polynomial( - { - "batch": _ct31_rhs_data.shape[0], - "num_elements": _ct31_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct31_rhs_m, - "precision": 32, - "degree_layout": (_ct31_rhs_r, _ct31_rhs_c), - }, - {"moduli": list(_ct31_rhs_moduli)[:_ct31_rhs_m]}, - ) - ct31_rhs.polynomial = _ct31_rhs_data.reshape( - _ct31_rhs_data.shape[0], - _ct31_rhs_data.shape[1], - _ct31_rhs_r, - _ct31_rhs_c, - _ct31_rhs_m_in, - )[..., :_ct31_rhs_m].copy() - ct31_rhs.batch = ct31_rhs.polynomial.shape[0] - ct31_rhs.num_elements = ct31_rhs.polynomial.shape[1] - ct31_rhs.num_moduli = _ct31_rhs_m - ct31_rhs.degree_layout = (_ct31_rhs_r, _ct31_rhs_c) - ct31_rhs.r = _ct31_rhs_r - ct31_rhs.c = _ct31_rhs_c - ct31_rhs.moduli = list(_ct31_rhs_moduli)[:_ct31_rhs_m] - ct31_rhs.moduli_array = jnp.array( - ct31_rhs.moduli, dtype=getattr(ct31_rhs, "modulus_dtype", jnp.uint32) - ) - ct31.add(ct31_rhs) - _moduli = jnp.array(ct31.moduli, dtype=jnp.uint32) - ct31.polynomial = jnp.where( - ct31.polynomial >= _moduli, ct31.polynomial - _moduli, ct31.polynomial - ) - _ct32_data = ct31.polynomial if hasattr(ct31, "polynomial") else ct31 - _ct32_m_in = _ct32_data.shape[-1] - _ct32_m = _ct32_m_in - _ct32_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct32_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct32_r) - ) - _ct32_moduli = getattr(ct31, "moduli", v0.q_towers) - if isinstance(_ct32_moduli, (int, np.integer)): - _ct32_moduli = [int(_ct32_moduli)] - ct32 = Polynomial( - { - "batch": _ct32_data.shape[0], - "num_elements": _ct32_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct32_m, - "precision": 32, - "degree_layout": (_ct32_r, _ct32_c), - }, - {"moduli": list(_ct32_moduli)[:_ct32_m]}, - ) - ct32.polynomial = _ct32_data.reshape( - _ct32_data.shape[0], _ct32_data.shape[1], _ct32_r, _ct32_c, _ct32_m_in - )[..., :_ct32_m].copy() - ct32.batch = ct32.polynomial.shape[0] - ct32.num_elements = ct32.polynomial.shape[1] - ct32.num_moduli = _ct32_m - ct32.degree_layout = (_ct32_r, _ct32_c) - ct32.r = _ct32_r - ct32.c = _ct32_c - ct32.moduli = list(_ct32_moduli)[:_ct32_m] - ct32.moduli_array = jnp.array( - ct32.moduli, dtype=getattr(ct32, "modulus_dtype", jnp.uint32) - ) - _ct32_rhs_data = ct30.polynomial if hasattr(ct30, "polynomial") else ct30 - _ct32_rhs_m_in = _ct32_rhs_data.shape[-1] - _ct32_rhs_m = _ct32_rhs_m_in - _ct32_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct32_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct32_rhs_r) - ) - _ct32_rhs_moduli = getattr(ct30, "moduli", v0.q_towers) - if isinstance(_ct32_rhs_moduli, (int, np.integer)): - _ct32_rhs_moduli = [int(_ct32_rhs_moduli)] - ct32_rhs = Polynomial( - { - "batch": _ct32_rhs_data.shape[0], - "num_elements": _ct32_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct32_rhs_m, - "precision": 32, - "degree_layout": (_ct32_rhs_r, _ct32_rhs_c), - }, - {"moduli": list(_ct32_rhs_moduli)[:_ct32_rhs_m]}, - ) - ct32_rhs.polynomial = _ct32_rhs_data.reshape( - _ct32_rhs_data.shape[0], - _ct32_rhs_data.shape[1], - _ct32_rhs_r, - _ct32_rhs_c, - _ct32_rhs_m_in, - )[..., :_ct32_rhs_m].copy() - ct32_rhs.batch = ct32_rhs.polynomial.shape[0] - ct32_rhs.num_elements = ct32_rhs.polynomial.shape[1] - ct32_rhs.num_moduli = _ct32_rhs_m - ct32_rhs.degree_layout = (_ct32_rhs_r, _ct32_rhs_c) - ct32_rhs.r = _ct32_rhs_r - ct32_rhs.c = _ct32_rhs_c - ct32_rhs.moduli = list(_ct32_rhs_moduli)[:_ct32_rhs_m] - ct32_rhs.moduli_array = jnp.array( - ct32_rhs.moduli, dtype=getattr(ct32_rhs, "modulus_dtype", jnp.uint32) - ) - ct32.add(ct32_rhs) - _moduli = jnp.array(ct32.moduli, dtype=jnp.uint32) - ct32.polynomial = jnp.where( - ct32.polynomial >= _moduli, ct32.polynomial - _moduli, ct32.polynomial - ) - _ct33_arg_data = ct32.polynomial if hasattr(ct32, "polynomial") else ct32 - _ct33_arg_m_in = _ct33_arg_data.shape[-1] - _ct33_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct33_arg_m_in - ) - _ct33_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct33_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct33_arg_r) - ) - _ct33_arg_moduli = getattr(ct32, "moduli", v0.q_towers) - if isinstance(_ct33_arg_moduli, (int, np.integer)): - _ct33_arg_moduli = [int(_ct33_arg_moduli)] - ct33_arg = Polynomial( - { - "batch": _ct33_arg_data.shape[0], - "num_elements": _ct33_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct33_arg_m, - "precision": 32, - "degree_layout": (_ct33_arg_r, _ct33_arg_c), - }, - {"moduli": list(_ct33_arg_moduli)[:_ct33_arg_m]}, - ) - ct33_arg.polynomial = _ct33_arg_data.reshape( - _ct33_arg_data.shape[0], - _ct33_arg_data.shape[1], - _ct33_arg_r, - _ct33_arg_c, - _ct33_arg_m_in, - )[..., :_ct33_arg_m].copy() - ct33_arg.batch = ct33_arg.polynomial.shape[0] - ct33_arg.num_elements = ct33_arg.polynomial.shape[1] - ct33_arg.num_moduli = _ct33_arg_m - ct33_arg.degree_layout = (_ct33_arg_r, _ct33_arg_c) - ct33_arg.r = _ct33_arg_r - ct33_arg.c = _ct33_arg_c - ct33_arg.moduli = list(_ct33_arg_moduli)[:_ct33_arg_m] - ct33_arg.moduli_array = jnp.array( - ct33_arg.moduli, dtype=getattr(ct33_arg, "modulus_dtype", jnp.uint32) - ) - ct33_raw = v0.he_rot[v0.max_level - 1, 3].rotate(ct33_arg) - _ct33_data = ( - ct33_raw.polynomial if hasattr(ct33_raw, "polynomial") else ct33_raw - ) - _ct33_m_in = _ct33_data.shape[-1] - _ct33_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct33_m_in - ) - _ct33_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct33_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct33_r) - ) - _ct33_moduli = getattr(ct33_raw, "moduli", v0.q_towers) - if isinstance(_ct33_moduli, (int, np.integer)): - _ct33_moduli = [int(_ct33_moduli)] - ct33 = Polynomial( - { - "batch": _ct33_data.shape[0], - "num_elements": _ct33_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct33_m, - "precision": 32, - "degree_layout": (_ct33_r, _ct33_c), - }, - {"moduli": list(_ct33_moduli)[:_ct33_m]}, - ) - ct33.polynomial = _ct33_data.reshape( - _ct33_data.shape[0], _ct33_data.shape[1], _ct33_r, _ct33_c, _ct33_m_in - )[..., :_ct33_m].copy() - ct33.batch = ct33.polynomial.shape[0] - ct33.num_elements = ct33.polynomial.shape[1] - ct33.num_moduli = _ct33_m - ct33.degree_layout = (_ct33_r, _ct33_c) - ct33.r = _ct33_r - ct33.c = _ct33_c - ct33.moduli = list(_ct33_moduli)[:_ct33_m] - ct33.moduli_array = jnp.array( - ct33.moduli, dtype=getattr(ct33, "modulus_dtype", jnp.uint32) - ) - _ct34_arg_data = ct20.polynomial if hasattr(ct20, "polynomial") else ct20 - _ct34_arg_m_in = _ct34_arg_data.shape[-1] - _ct34_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct34_arg_m_in - ) - _ct34_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct34_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct34_arg_r) - ) - _ct34_arg_moduli = getattr(ct20, "moduli", v0.q_towers) - if isinstance(_ct34_arg_moduli, (int, np.integer)): - _ct34_arg_moduli = [int(_ct34_arg_moduli)] - ct34_arg = Polynomial( - { - "batch": _ct34_arg_data.shape[0], - "num_elements": _ct34_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct34_arg_m, - "precision": 32, - "degree_layout": (_ct34_arg_r, _ct34_arg_c), - }, - {"moduli": list(_ct34_arg_moduli)[:_ct34_arg_m]}, - ) - ct34_arg.polynomial = _ct34_arg_data.reshape( - _ct34_arg_data.shape[0], - _ct34_arg_data.shape[1], - _ct34_arg_r, - _ct34_arg_c, - _ct34_arg_m_in, - )[..., :_ct34_arg_m].copy() - ct34_arg.batch = ct34_arg.polynomial.shape[0] - ct34_arg.num_elements = ct34_arg.polynomial.shape[1] - ct34_arg.num_moduli = _ct34_arg_m - ct34_arg.degree_layout = (_ct34_arg_r, _ct34_arg_c) - ct34_arg.r = _ct34_arg_r - ct34_arg.c = _ct34_arg_c - ct34_arg.moduli = list(_ct34_arg_moduli)[:_ct34_arg_m] - ct34_arg.moduli_array = jnp.array( - ct34_arg.moduli, dtype=getattr(ct34_arg, "modulus_dtype", jnp.uint32) - ) - ct34_pt_ntt = ( - pt14.polynomial[0, 0, :, : ct34_arg.polynomial.shape[-1]] - .reshape(ct34_arg.r, ct34_arg.c, ct34_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct34_ptct = v0.ptct_mul[v0.max_level - 1] - ct34_ptct.set_plaintext(ct34_pt_ntt) - ct34_raw = ct34_ptct.mul(ct34_arg, use_bat=False) - _ct34_data = ( - ct34_raw.polynomial if hasattr(ct34_raw, "polynomial") else ct34_raw - ) - _ct34_m_in = _ct34_data.shape[-1] - _ct34_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct34_m_in - ) - _ct34_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct34_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct34_r) - ) - _ct34_moduli = getattr(ct34_raw, "moduli", v0.q_towers) - if isinstance(_ct34_moduli, (int, np.integer)): - _ct34_moduli = [int(_ct34_moduli)] - ct34 = Polynomial( - { - "batch": _ct34_data.shape[0], - "num_elements": _ct34_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct34_m, - "precision": 32, - "degree_layout": (_ct34_r, _ct34_c), - }, - {"moduli": list(_ct34_moduli)[:_ct34_m]}, - ) - ct34.polynomial = _ct34_data.reshape( - _ct34_data.shape[0], _ct34_data.shape[1], _ct34_r, _ct34_c, _ct34_m_in - )[..., :_ct34_m].copy() - ct34.batch = ct34.polynomial.shape[0] - ct34.num_elements = ct34.polynomial.shape[1] - ct34.num_moduli = _ct34_m - ct34.degree_layout = (_ct34_r, _ct34_c) - ct34.r = _ct34_r - ct34.c = _ct34_c - ct34.moduli = list(_ct34_moduli)[:_ct34_m] - ct34.moduli_array = jnp.array( - ct34.moduli, dtype=getattr(ct34, "modulus_dtype", jnp.uint32) - ) - _ct35_arg_data = ct23.polynomial if hasattr(ct23, "polynomial") else ct23 - _ct35_arg_m_in = _ct35_arg_data.shape[-1] - _ct35_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct35_arg_m_in - ) - _ct35_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct35_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct35_arg_r) - ) - _ct35_arg_moduli = getattr(ct23, "moduli", v0.q_towers) - if isinstance(_ct35_arg_moduli, (int, np.integer)): - _ct35_arg_moduli = [int(_ct35_arg_moduli)] - ct35_arg = Polynomial( - { - "batch": _ct35_arg_data.shape[0], - "num_elements": _ct35_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct35_arg_m, - "precision": 32, - "degree_layout": (_ct35_arg_r, _ct35_arg_c), - }, - {"moduli": list(_ct35_arg_moduli)[:_ct35_arg_m]}, - ) - ct35_arg.polynomial = _ct35_arg_data.reshape( - _ct35_arg_data.shape[0], - _ct35_arg_data.shape[1], - _ct35_arg_r, - _ct35_arg_c, - _ct35_arg_m_in, - )[..., :_ct35_arg_m].copy() - ct35_arg.batch = ct35_arg.polynomial.shape[0] - ct35_arg.num_elements = ct35_arg.polynomial.shape[1] - ct35_arg.num_moduli = _ct35_arg_m - ct35_arg.degree_layout = (_ct35_arg_r, _ct35_arg_c) - ct35_arg.r = _ct35_arg_r - ct35_arg.c = _ct35_arg_c - ct35_arg.moduli = list(_ct35_arg_moduli)[:_ct35_arg_m] - ct35_arg.moduli_array = jnp.array( - ct35_arg.moduli, dtype=getattr(ct35_arg, "modulus_dtype", jnp.uint32) - ) - ct35_pt_ntt = ( - pt15.polynomial[0, 0, :, : ct35_arg.polynomial.shape[-1]] - .reshape(ct35_arg.r, ct35_arg.c, ct35_arg.polynomial.shape[-1]) - .astype(jnp.uint32) - ) - ct35_ptct = v0.ptct_mul[v0.max_level - 1] - ct35_ptct.set_plaintext(ct35_pt_ntt) - ct35_raw = ct35_ptct.mul(ct35_arg, use_bat=False) - _ct35_data = ( - ct35_raw.polynomial if hasattr(ct35_raw, "polynomial") else ct35_raw - ) - _ct35_m_in = _ct35_data.shape[-1] - _ct35_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct35_m_in - ) - _ct35_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct35_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct35_r) - ) - _ct35_moduli = getattr(ct35_raw, "moduli", v0.q_towers) - if isinstance(_ct35_moduli, (int, np.integer)): - _ct35_moduli = [int(_ct35_moduli)] - ct35 = Polynomial( - { - "batch": _ct35_data.shape[0], - "num_elements": _ct35_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct35_m, - "precision": 32, - "degree_layout": (_ct35_r, _ct35_c), - }, - {"moduli": list(_ct35_moduli)[:_ct35_m]}, - ) - ct35.polynomial = _ct35_data.reshape( - _ct35_data.shape[0], _ct35_data.shape[1], _ct35_r, _ct35_c, _ct35_m_in - )[..., :_ct35_m].copy() - ct35.batch = ct35.polynomial.shape[0] - ct35.num_elements = ct35.polynomial.shape[1] - ct35.num_moduli = _ct35_m - ct35.degree_layout = (_ct35_r, _ct35_c) - ct35.r = _ct35_r - ct35.c = _ct35_c - ct35.moduli = list(_ct35_moduli)[:_ct35_m] - ct35.moduli_array = jnp.array( - ct35.moduli, dtype=getattr(ct35, "modulus_dtype", jnp.uint32) - ) - _ct36_data = ct34.polynomial if hasattr(ct34, "polynomial") else ct34 - _ct36_m_in = _ct36_data.shape[-1] - _ct36_m = _ct36_m_in - _ct36_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct36_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct36_r) - ) - _ct36_moduli = getattr(ct34, "moduli", v0.q_towers) - if isinstance(_ct36_moduli, (int, np.integer)): - _ct36_moduli = [int(_ct36_moduli)] - ct36 = Polynomial( - { - "batch": _ct36_data.shape[0], - "num_elements": _ct36_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct36_m, - "precision": 32, - "degree_layout": (_ct36_r, _ct36_c), - }, - {"moduli": list(_ct36_moduli)[:_ct36_m]}, - ) - ct36.polynomial = _ct36_data.reshape( - _ct36_data.shape[0], _ct36_data.shape[1], _ct36_r, _ct36_c, _ct36_m_in - )[..., :_ct36_m].copy() - ct36.batch = ct36.polynomial.shape[0] - ct36.num_elements = ct36.polynomial.shape[1] - ct36.num_moduli = _ct36_m - ct36.degree_layout = (_ct36_r, _ct36_c) - ct36.r = _ct36_r - ct36.c = _ct36_c - ct36.moduli = list(_ct36_moduli)[:_ct36_m] - ct36.moduli_array = jnp.array( - ct36.moduli, dtype=getattr(ct36, "modulus_dtype", jnp.uint32) - ) - _ct36_rhs_data = ct35.polynomial if hasattr(ct35, "polynomial") else ct35 - _ct36_rhs_m_in = _ct36_rhs_data.shape[-1] - _ct36_rhs_m = _ct36_rhs_m_in - _ct36_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct36_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct36_rhs_r) - ) - _ct36_rhs_moduli = getattr(ct35, "moduli", v0.q_towers) - if isinstance(_ct36_rhs_moduli, (int, np.integer)): - _ct36_rhs_moduli = [int(_ct36_rhs_moduli)] - ct36_rhs = Polynomial( - { - "batch": _ct36_rhs_data.shape[0], - "num_elements": _ct36_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct36_rhs_m, - "precision": 32, - "degree_layout": (_ct36_rhs_r, _ct36_rhs_c), - }, - {"moduli": list(_ct36_rhs_moduli)[:_ct36_rhs_m]}, - ) - ct36_rhs.polynomial = _ct36_rhs_data.reshape( - _ct36_rhs_data.shape[0], - _ct36_rhs_data.shape[1], - _ct36_rhs_r, - _ct36_rhs_c, - _ct36_rhs_m_in, - )[..., :_ct36_rhs_m].copy() - ct36_rhs.batch = ct36_rhs.polynomial.shape[0] - ct36_rhs.num_elements = ct36_rhs.polynomial.shape[1] - ct36_rhs.num_moduli = _ct36_rhs_m - ct36_rhs.degree_layout = (_ct36_rhs_r, _ct36_rhs_c) - ct36_rhs.r = _ct36_rhs_r - ct36_rhs.c = _ct36_rhs_c - ct36_rhs.moduli = list(_ct36_rhs_moduli)[:_ct36_rhs_m] - ct36_rhs.moduli_array = jnp.array( - ct36_rhs.moduli, dtype=getattr(ct36_rhs, "modulus_dtype", jnp.uint32) - ) - ct36.add(ct36_rhs) - _moduli = jnp.array(ct36.moduli, dtype=jnp.uint32) - ct36.polynomial = jnp.where( - ct36.polynomial >= _moduli, ct36.polynomial - _moduli, ct36.polynomial - ) - _ct37_arg_data = ct36.polynomial if hasattr(ct36, "polynomial") else ct36 - _ct37_arg_m_in = _ct37_arg_data.shape[-1] - _ct37_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct37_arg_m_in - ) - _ct37_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct37_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct37_arg_r) - ) - _ct37_arg_moduli = getattr(ct36, "moduli", v0.q_towers) - if isinstance(_ct37_arg_moduli, (int, np.integer)): - _ct37_arg_moduli = [int(_ct37_arg_moduli)] - ct37_arg = Polynomial( - { - "batch": _ct37_arg_data.shape[0], - "num_elements": _ct37_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct37_arg_m, - "precision": 32, - "degree_layout": (_ct37_arg_r, _ct37_arg_c), - }, - {"moduli": list(_ct37_arg_moduli)[:_ct37_arg_m]}, - ) - ct37_arg.polynomial = _ct37_arg_data.reshape( - _ct37_arg_data.shape[0], - _ct37_arg_data.shape[1], - _ct37_arg_r, - _ct37_arg_c, - _ct37_arg_m_in, - )[..., :_ct37_arg_m].copy() - ct37_arg.batch = ct37_arg.polynomial.shape[0] - ct37_arg.num_elements = ct37_arg.polynomial.shape[1] - ct37_arg.num_moduli = _ct37_arg_m - ct37_arg.degree_layout = (_ct37_arg_r, _ct37_arg_c) - ct37_arg.r = _ct37_arg_r - ct37_arg.c = _ct37_arg_c - ct37_arg.moduli = list(_ct37_arg_moduli)[:_ct37_arg_m] - ct37_arg.moduli_array = jnp.array( - ct37_arg.moduli, dtype=getattr(ct37_arg, "modulus_dtype", jnp.uint32) - ) - ct37_raw = v0.he_rot[v0.max_level - 1, 6].rotate(ct37_arg) - _ct37_data = ( - ct37_raw.polynomial if hasattr(ct37_raw, "polynomial") else ct37_raw - ) - _ct37_m_in = _ct37_data.shape[-1] - _ct37_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct37_m_in - ) - _ct37_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct37_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct37_r) - ) - _ct37_moduli = getattr(ct37_raw, "moduli", v0.q_towers) - if isinstance(_ct37_moduli, (int, np.integer)): - _ct37_moduli = [int(_ct37_moduli)] - ct37 = Polynomial( - { - "batch": _ct37_data.shape[0], - "num_elements": _ct37_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct37_m, - "precision": 32, - "degree_layout": (_ct37_r, _ct37_c), - }, - {"moduli": list(_ct37_moduli)[:_ct37_m]}, - ) - ct37.polynomial = _ct37_data.reshape( - _ct37_data.shape[0], _ct37_data.shape[1], _ct37_r, _ct37_c, _ct37_m_in - )[..., :_ct37_m].copy() - ct37.batch = ct37.polynomial.shape[0] - ct37.num_elements = ct37.polynomial.shape[1] - ct37.num_moduli = _ct37_m - ct37.degree_layout = (_ct37_r, _ct37_c) - ct37.r = _ct37_r - ct37.c = _ct37_c - ct37.moduli = list(_ct37_moduli)[:_ct37_m] - ct37.moduli_array = jnp.array( - ct37.moduli, dtype=getattr(ct37, "modulus_dtype", jnp.uint32) - ) - _ct38_data = ct21.polynomial if hasattr(ct21, "polynomial") else ct21 - _ct38_m_in = _ct38_data.shape[-1] - _ct38_m = _ct38_m_in - _ct38_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct38_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct38_r) - ) - _ct38_moduli = getattr(ct21, "moduli", v0.q_towers) - if isinstance(_ct38_moduli, (int, np.integer)): - _ct38_moduli = [int(_ct38_moduli)] - ct38 = Polynomial( - { - "batch": _ct38_data.shape[0], - "num_elements": _ct38_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct38_m, - "precision": 32, - "degree_layout": (_ct38_r, _ct38_c), - }, - {"moduli": list(_ct38_moduli)[:_ct38_m]}, - ) - ct38.polynomial = _ct38_data.reshape( - _ct38_data.shape[0], _ct38_data.shape[1], _ct38_r, _ct38_c, _ct38_m_in - )[..., :_ct38_m].copy() - ct38.batch = ct38.polynomial.shape[0] - ct38.num_elements = ct38.polynomial.shape[1] - ct38.num_moduli = _ct38_m - ct38.degree_layout = (_ct38_r, _ct38_c) - ct38.r = _ct38_r - ct38.c = _ct38_c - ct38.moduli = list(_ct38_moduli)[:_ct38_m] - ct38.moduli_array = jnp.array( - ct38.moduli, dtype=getattr(ct38, "modulus_dtype", jnp.uint32) - ) - _ct38_rhs_data = ct24.polynomial if hasattr(ct24, "polynomial") else ct24 - _ct38_rhs_m_in = _ct38_rhs_data.shape[-1] - _ct38_rhs_m = _ct38_rhs_m_in - _ct38_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct38_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct38_rhs_r) - ) - _ct38_rhs_moduli = getattr(ct24, "moduli", v0.q_towers) - if isinstance(_ct38_rhs_moduli, (int, np.integer)): - _ct38_rhs_moduli = [int(_ct38_rhs_moduli)] - ct38_rhs = Polynomial( - { - "batch": _ct38_rhs_data.shape[0], - "num_elements": _ct38_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct38_rhs_m, - "precision": 32, - "degree_layout": (_ct38_rhs_r, _ct38_rhs_c), - }, - {"moduli": list(_ct38_rhs_moduli)[:_ct38_rhs_m]}, - ) - ct38_rhs.polynomial = _ct38_rhs_data.reshape( - _ct38_rhs_data.shape[0], - _ct38_rhs_data.shape[1], - _ct38_rhs_r, - _ct38_rhs_c, - _ct38_rhs_m_in, - )[..., :_ct38_rhs_m].copy() - ct38_rhs.batch = ct38_rhs.polynomial.shape[0] - ct38_rhs.num_elements = ct38_rhs.polynomial.shape[1] - ct38_rhs.num_moduli = _ct38_rhs_m - ct38_rhs.degree_layout = (_ct38_rhs_r, _ct38_rhs_c) - ct38_rhs.r = _ct38_rhs_r - ct38_rhs.c = _ct38_rhs_c - ct38_rhs.moduli = list(_ct38_rhs_moduli)[:_ct38_rhs_m] - ct38_rhs.moduli_array = jnp.array( - ct38_rhs.moduli, dtype=getattr(ct38_rhs, "modulus_dtype", jnp.uint32) - ) - ct38.add(ct38_rhs) - _moduli = jnp.array(ct38.moduli, dtype=jnp.uint32) - ct38.polynomial = jnp.where( - ct38.polynomial >= _moduli, ct38.polynomial - _moduli, ct38.polynomial - ) - _ct39_data = ct27.polynomial if hasattr(ct27, "polynomial") else ct27 - _ct39_m_in = _ct39_data.shape[-1] - _ct39_m = _ct39_m_in - _ct39_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct39_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct39_r) - ) - _ct39_moduli = getattr(ct27, "moduli", v0.q_towers) - if isinstance(_ct39_moduli, (int, np.integer)): - _ct39_moduli = [int(_ct39_moduli)] - ct39 = Polynomial( - { - "batch": _ct39_data.shape[0], - "num_elements": _ct39_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct39_m, - "precision": 32, - "degree_layout": (_ct39_r, _ct39_c), - }, - {"moduli": list(_ct39_moduli)[:_ct39_m]}, - ) - ct39.polynomial = _ct39_data.reshape( - _ct39_data.shape[0], _ct39_data.shape[1], _ct39_r, _ct39_c, _ct39_m_in - )[..., :_ct39_m].copy() - ct39.batch = ct39.polynomial.shape[0] - ct39.num_elements = ct39.polynomial.shape[1] - ct39.num_moduli = _ct39_m - ct39.degree_layout = (_ct39_r, _ct39_c) - ct39.r = _ct39_r - ct39.c = _ct39_c - ct39.moduli = list(_ct39_moduli)[:_ct39_m] - ct39.moduli_array = jnp.array( - ct39.moduli, dtype=getattr(ct39, "modulus_dtype", jnp.uint32) - ) - _ct39_rhs_data = ct33.polynomial if hasattr(ct33, "polynomial") else ct33 - _ct39_rhs_m_in = _ct39_rhs_data.shape[-1] - _ct39_rhs_m = _ct39_rhs_m_in - _ct39_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct39_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct39_rhs_r) - ) - _ct39_rhs_moduli = getattr(ct33, "moduli", v0.q_towers) - if isinstance(_ct39_rhs_moduli, (int, np.integer)): - _ct39_rhs_moduli = [int(_ct39_rhs_moduli)] - ct39_rhs = Polynomial( - { - "batch": _ct39_rhs_data.shape[0], - "num_elements": _ct39_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct39_rhs_m, - "precision": 32, - "degree_layout": (_ct39_rhs_r, _ct39_rhs_c), - }, - {"moduli": list(_ct39_rhs_moduli)[:_ct39_rhs_m]}, - ) - ct39_rhs.polynomial = _ct39_rhs_data.reshape( - _ct39_rhs_data.shape[0], - _ct39_rhs_data.shape[1], - _ct39_rhs_r, - _ct39_rhs_c, - _ct39_rhs_m_in, - )[..., :_ct39_rhs_m].copy() - ct39_rhs.batch = ct39_rhs.polynomial.shape[0] - ct39_rhs.num_elements = ct39_rhs.polynomial.shape[1] - ct39_rhs.num_moduli = _ct39_rhs_m - ct39_rhs.degree_layout = (_ct39_rhs_r, _ct39_rhs_c) - ct39_rhs.r = _ct39_rhs_r - ct39_rhs.c = _ct39_rhs_c - ct39_rhs.moduli = list(_ct39_rhs_moduli)[:_ct39_rhs_m] - ct39_rhs.moduli_array = jnp.array( - ct39_rhs.moduli, dtype=getattr(ct39_rhs, "modulus_dtype", jnp.uint32) - ) - ct39.add(ct39_rhs) - _moduli = jnp.array(ct39.moduli, dtype=jnp.uint32) - ct39.polynomial = jnp.where( - ct39.polynomial >= _moduli, ct39.polynomial - _moduli, ct39.polynomial - ) - _ct40_data = ct39.polynomial if hasattr(ct39, "polynomial") else ct39 - _ct40_m_in = _ct40_data.shape[-1] - _ct40_m = _ct40_m_in - _ct40_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct40_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct40_r) - ) - _ct40_moduli = getattr(ct39, "moduli", v0.q_towers) - if isinstance(_ct40_moduli, (int, np.integer)): - _ct40_moduli = [int(_ct40_moduli)] - ct40 = Polynomial( - { - "batch": _ct40_data.shape[0], - "num_elements": _ct40_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct40_m, - "precision": 32, - "degree_layout": (_ct40_r, _ct40_c), - }, - {"moduli": list(_ct40_moduli)[:_ct40_m]}, - ) - ct40.polynomial = _ct40_data.reshape( - _ct40_data.shape[0], _ct40_data.shape[1], _ct40_r, _ct40_c, _ct40_m_in - )[..., :_ct40_m].copy() - ct40.batch = ct40.polynomial.shape[0] - ct40.num_elements = ct40.polynomial.shape[1] - ct40.num_moduli = _ct40_m - ct40.degree_layout = (_ct40_r, _ct40_c) - ct40.r = _ct40_r - ct40.c = _ct40_c - ct40.moduli = list(_ct40_moduli)[:_ct40_m] - ct40.moduli_array = jnp.array( - ct40.moduli, dtype=getattr(ct40, "modulus_dtype", jnp.uint32) - ) - _ct40_rhs_data = ct37.polynomial if hasattr(ct37, "polynomial") else ct37 - _ct40_rhs_m_in = _ct40_rhs_data.shape[-1] - _ct40_rhs_m = _ct40_rhs_m_in - _ct40_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct40_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct40_rhs_r) - ) - _ct40_rhs_moduli = getattr(ct37, "moduli", v0.q_towers) - if isinstance(_ct40_rhs_moduli, (int, np.integer)): - _ct40_rhs_moduli = [int(_ct40_rhs_moduli)] - ct40_rhs = Polynomial( - { - "batch": _ct40_rhs_data.shape[0], - "num_elements": _ct40_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct40_rhs_m, - "precision": 32, - "degree_layout": (_ct40_rhs_r, _ct40_rhs_c), - }, - {"moduli": list(_ct40_rhs_moduli)[:_ct40_rhs_m]}, - ) - ct40_rhs.polynomial = _ct40_rhs_data.reshape( - _ct40_rhs_data.shape[0], - _ct40_rhs_data.shape[1], - _ct40_rhs_r, - _ct40_rhs_c, - _ct40_rhs_m_in, - )[..., :_ct40_rhs_m].copy() - ct40_rhs.batch = ct40_rhs.polynomial.shape[0] - ct40_rhs.num_elements = ct40_rhs.polynomial.shape[1] - ct40_rhs.num_moduli = _ct40_rhs_m - ct40_rhs.degree_layout = (_ct40_rhs_r, _ct40_rhs_c) - ct40_rhs.r = _ct40_rhs_r - ct40_rhs.c = _ct40_rhs_c - ct40_rhs.moduli = list(_ct40_rhs_moduli)[:_ct40_rhs_m] - ct40_rhs.moduli_array = jnp.array( - ct40_rhs.moduli, dtype=getattr(ct40_rhs, "modulus_dtype", jnp.uint32) - ) - ct40.add(ct40_rhs) - _moduli = jnp.array(ct40.moduli, dtype=jnp.uint32) - ct40.polynomial = jnp.where( - ct40.polynomial >= _moduli, ct40.polynomial - _moduli, ct40.polynomial - ) - _ct41_data = ct38.polynomial if hasattr(ct38, "polynomial") else ct38 - _ct41_m_in = _ct41_data.shape[-1] - _ct41_m = _ct41_m_in - _ct41_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct41_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct41_r) - ) - _ct41_moduli = getattr(ct38, "moduli", v0.q_towers) - if isinstance(_ct41_moduli, (int, np.integer)): - _ct41_moduli = [int(_ct41_moduli)] - ct41 = Polynomial( - { - "batch": _ct41_data.shape[0], - "num_elements": _ct41_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct41_m, - "precision": 32, - "degree_layout": (_ct41_r, _ct41_c), - }, - {"moduli": list(_ct41_moduli)[:_ct41_m]}, - ) - ct41.polynomial = _ct41_data.reshape( - _ct41_data.shape[0], _ct41_data.shape[1], _ct41_r, _ct41_c, _ct41_m_in - )[..., :_ct41_m].copy() - ct41.batch = ct41.polynomial.shape[0] - ct41.num_elements = ct41.polynomial.shape[1] - ct41.num_moduli = _ct41_m - ct41.degree_layout = (_ct41_r, _ct41_c) - ct41.r = _ct41_r - ct41.c = _ct41_c - ct41.moduli = list(_ct41_moduli)[:_ct41_m] - ct41.moduli_array = jnp.array( - ct41.moduli, dtype=getattr(ct41, "modulus_dtype", jnp.uint32) - ) - _ct41_rhs_data = ct40.polynomial if hasattr(ct40, "polynomial") else ct40 - _ct41_rhs_m_in = _ct41_rhs_data.shape[-1] - _ct41_rhs_m = _ct41_rhs_m_in - _ct41_rhs_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct41_rhs_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct41_rhs_r) - ) - _ct41_rhs_moduli = getattr(ct40, "moduli", v0.q_towers) - if isinstance(_ct41_rhs_moduli, (int, np.integer)): - _ct41_rhs_moduli = [int(_ct41_rhs_moduli)] - ct41_rhs = Polynomial( - { - "batch": _ct41_rhs_data.shape[0], - "num_elements": _ct41_rhs_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct41_rhs_m, - "precision": 32, - "degree_layout": (_ct41_rhs_r, _ct41_rhs_c), - }, - {"moduli": list(_ct41_rhs_moduli)[:_ct41_rhs_m]}, - ) - ct41_rhs.polynomial = _ct41_rhs_data.reshape( - _ct41_rhs_data.shape[0], - _ct41_rhs_data.shape[1], - _ct41_rhs_r, - _ct41_rhs_c, - _ct41_rhs_m_in, - )[..., :_ct41_rhs_m].copy() - ct41_rhs.batch = ct41_rhs.polynomial.shape[0] - ct41_rhs.num_elements = ct41_rhs.polynomial.shape[1] - ct41_rhs.num_moduli = _ct41_rhs_m - ct41_rhs.degree_layout = (_ct41_rhs_r, _ct41_rhs_c) - ct41_rhs.r = _ct41_rhs_r - ct41_rhs.c = _ct41_rhs_c - ct41_rhs.moduli = list(_ct41_rhs_moduli)[:_ct41_rhs_m] - ct41_rhs.moduli_array = jnp.array( - ct41_rhs.moduli, dtype=getattr(ct41_rhs, "modulus_dtype", jnp.uint32) - ) - ct41.add(ct41_rhs) - _moduli = jnp.array(ct41.moduli, dtype=jnp.uint32) - ct41.polynomial = jnp.where( - ct41.polynomial >= _moduli, ct41.polynomial - _moduli, ct41.polynomial - ) - v20 = [None] * 1 - _ct42_arg_data = ct41.polynomial if hasattr(ct41, "polynomial") else ct41 - _ct42_arg_m_in = _ct42_arg_data.shape[-1] - _ct42_arg_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 1) - if hasattr(v0, "_param_cache") - else _ct42_arg_m_in - ) - _ct42_arg_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct42_arg_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct42_arg_r) - ) - _ct42_arg_moduli = getattr(ct41, "moduli", v0.q_towers) - if isinstance(_ct42_arg_moduli, (int, np.integer)): - _ct42_arg_moduli = [int(_ct42_arg_moduli)] - ct42_arg = Polynomial( - { - "batch": _ct42_arg_data.shape[0], - "num_elements": _ct42_arg_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct42_arg_m, - "precision": 32, - "degree_layout": (_ct42_arg_r, _ct42_arg_c), - }, - {"moduli": list(_ct42_arg_moduli)[:_ct42_arg_m]}, - ) - ct42_arg.polynomial = _ct42_arg_data.reshape( - _ct42_arg_data.shape[0], - _ct42_arg_data.shape[1], - _ct42_arg_r, - _ct42_arg_c, - _ct42_arg_m_in, - )[..., :_ct42_arg_m].copy() - ct42_arg.batch = ct42_arg.polynomial.shape[0] - ct42_arg.num_elements = ct42_arg.polynomial.shape[1] - ct42_arg.num_moduli = _ct42_arg_m - ct42_arg.degree_layout = (_ct42_arg_r, _ct42_arg_c) - ct42_arg.r = _ct42_arg_r - ct42_arg.c = _ct42_arg_c - ct42_arg.moduli = list(_ct42_arg_moduli)[:_ct42_arg_m] - ct42_arg.moduli_array = jnp.array( - ct42_arg.moduli, dtype=getattr(ct42_arg, "modulus_dtype", jnp.uint32) - ) - ct42_raw = v0.he_rescale[v0.max_level - 1, v0.max_level - 2](ct42_arg) - _ct42_data = ( - ct42_raw.polynomial if hasattr(ct42_raw, "polynomial") else ct42_raw - ) - _ct42_m_in = _ct42_data.shape[-1] - _ct42_m = ( - v0._param_cache.num_q_at_level(v0.max_level - 2) - if hasattr(v0, "_param_cache") - else _ct42_m_in - ) - _ct42_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct42_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct42_r) - ) - _ct42_moduli = getattr(ct42_raw, "moduli", v0.q_towers) - if isinstance(_ct42_moduli, (int, np.integer)): - _ct42_moduli = [int(_ct42_moduli)] - ct42 = Polynomial( - { - "batch": _ct42_data.shape[0], - "num_elements": _ct42_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct42_m, - "precision": 32, - "degree_layout": (_ct42_r, _ct42_c), - }, - {"moduli": list(_ct42_moduli)[:_ct42_m]}, - ) - ct42.polynomial = _ct42_data.reshape( - _ct42_data.shape[0], _ct42_data.shape[1], _ct42_r, _ct42_c, _ct42_m_in - )[..., :_ct42_m].copy() - ct42.batch = ct42.polynomial.shape[0] - ct42.num_elements = ct42.polynomial.shape[1] - ct42.num_moduli = _ct42_m - ct42.degree_layout = (_ct42_r, _ct42_c) - ct42.r = _ct42_r - ct42.c = _ct42_c - ct42.moduli = list(_ct42_moduli)[:_ct42_m] - ct42.moduli_array = jnp.array( - ct42.moduli, dtype=getattr(ct42, "modulus_dtype", jnp.uint32) - ) - v20[0] = ct42 - v21 = v20 - return v21 - - -def matvec_chain( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, -) -> np.ndarray: - (v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) = ( - matvec_chain__preprocessing(v0, v1) - ) - v15 = matvec_chain__preprocessed( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 - ) - return v15 - - -def matvec_chain__encrypt__arg0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = np.full( - ( - 1, - 8, - ), - 0.000000e00, - dtype=np.float32, - ) - v6 = 0 - v7 = 1 - v8 = 8 - v9 = v5.copy() - for v10 in range(0, 8): - v12 = int(v10) - v13 = v2[v12] - v9[0, v12] = v13 - v15 = v9[0 : 0 + 1, 0 : 0 + 8].reshape(8) - pt = v0.encode(v15) - v0.public_key = v3 - ct_raw = v0.encrypt(pt) - _ct_data = ct_raw.polynomial if hasattr(ct_raw, "polynomial") else ct_raw - _ct_m_in = _ct_data.shape[-1] - _ct_m = _ct_m_in - _ct_r = ( - v0._param_cache.r - if hasattr(v0, "_param_cache") - else v0.parameters.get("r", int(np.sqrt(v0.degree))) - ) - _ct_c = ( - v0._param_cache.c - if hasattr(v0, "_param_cache") - else v0.parameters.get("c", v0.degree // _ct_r) - ) - _ct_moduli = getattr(ct_raw, "moduli", v0.q_towers) - if isinstance(_ct_moduli, (int, np.integer)): - _ct_moduli = [int(_ct_moduli)] - ct = Polynomial( - { - "batch": _ct_data.shape[0], - "num_elements": _ct_data.shape[1], - "degree": v0.degree, - "num_moduli": _ct_m, - "precision": 32, - "degree_layout": (_ct_r, _ct_c), - }, - {"moduli": list(_ct_moduli)[:_ct_m]}, - ) - ct.polynomial = _ct_data.reshape( - _ct_data.shape[0], _ct_data.shape[1], _ct_r, _ct_c, _ct_m_in - )[..., :_ct_m].copy() - ct.batch = ct.polynomial.shape[0] - ct.num_elements = ct.polynomial.shape[1] - ct.num_moduli = _ct_m - ct.degree_layout = (_ct_r, _ct_c) - ct.r = _ct_r - ct.c = _ct_c - ct.moduli = list(_ct_moduli)[:_ct_m] - ct.moduli_array = jnp.array( - ct.moduli, dtype=getattr(ct, "modulus_dtype", jnp.uint32) - ) - v16 = [ct] - return v16 - - -def matvec_chain__decrypt__result0( - v0: ckks.CKKSContext, - v1: dict, - v2: np.ndarray, - v3: np.ndarray, -) -> np.ndarray: - v4 = 0 - v5 = 8 - v6 = 1 - v7 = 0 - v8 = np.full((8,), 0.000000e00, dtype=np.float32) - ct = v2[0] - v0.secret_key = v3 - _num_moduli = ct.polynomial.shape[-1] - _q_sub = list(getattr(ct, "moduli", v0.q_towers))[:_num_moduli] - _ct_for_dec = Polynomial( - { - "batch": ct.polynomial.shape[0], - "num_elements": ct.polynomial.shape[1], - "degree": v0.degree, - "precision": 32, - "num_moduli": _num_moduli, - "degree_layout": (v0.degree,), - }, - {"moduli": _q_sub}, - ) - _ct_for_dec.set_batch_polynomial( - ct.polynomial.reshape( - ct.polynomial.shape[0], ct.polynomial.shape[1], v0.degree, _num_moduli - ) - ) - pt = v0.decrypt(_ct_for_dec) - v9 = v0.decode(pt, is_ntt=False).real.reshape(1, 8) - v10 = v8.copy() - for v11 in range(0, 8): - v13 = int(v11) - v14 = v9[0, v13] - v10[v13] = v14 - return v10 - - -def matvec_identity__generate_crypto_context( - public_key, - secret_key, - evaluation_key, -) -> ckks.CKKSContext: - params = { - "degree": 16, - "num_slots": 8, - "batch": 1, - "r": 4, - "c": 4, - "dnum": 3, - "numEvalMult": 1, - "scaling_factor": 563019763943521, - "q_towers": [1073742881, 1073742721, 1073741441, 1073741857, 524353], - "p_towers": [1073740609, 1073739937, 1073739649], - "composite_degree": 1, - "p": 30, - "max_bits_in_word": 61, - "max_bits_value": 9223372036854775295, - "noise_scale_degree": 1, - "CKKS_M_FACTOR": 1, - "public_key": public_key, - "secret_key": secret_key, - "evaluation_key": evaluation_key, - } - v0 = ckks.CKKSContext(params) - return v0 - - -def matvec_identity__configure_crypto_context( - v0: ckks.CKKSContext, -): - v0.program_initialization( - total_hemul_levels=v0.max_level, - total_rotation_indices=[1, 2, 3, 6], - dnum=3, - r=4, - c=4, - batch=1, - )