Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions lib/Conversions/CheddarToEmitC/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")
load("@rules_cc//cc:cc_library.bzl", "cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "CheddarToEmitC",
srcs = ["CheddarToEmitC.cpp"],
hdrs = ["CheddarToEmitC.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/Cheddar/IR:Dialect",
"@heir//lib/Utils:ConversionUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithToEmitC",
"@llvm-project//mlir:EmitCDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefToEmitC",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFToEmitC",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
],
)

add_heir_transforms(
header_filename = "CheddarToEmitC.h.inc",
pass_name = "CheddarToEmitC",
td_file = "CheddarToEmitC.td",
)
1,440 changes: 1,440 additions & 0 deletions lib/Conversions/CheddarToEmitC/CheddarToEmitC.cpp

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions lib/Conversions/CheddarToEmitC/CheddarToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_H_
#define LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_H_

#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir::heir {

// Attaches MemRefElementTypeInterface as an external (marker-only) model to
// emitc::OpaqueType. Needed so that the cheddar-to-emitc type converter can
// form `memref<Nx!emitc.opaque<...>>` as the converted form of
// `memref<Nx!cheddar.*>` after bufferization. Call once at tool startup.
void registerCheddarToEmitCExternalModels(DialectRegistry& registry);

#define GEN_PASS_DECL
#include "lib/Conversions/CheddarToEmitC/CheddarToEmitC.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Conversions/CheddarToEmitC/CheddarToEmitC.h.inc"

} // namespace mlir::heir

#endif // LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_H_
34 changes: 34 additions & 0 deletions lib/Conversions/CheddarToEmitC/CheddarToEmitC.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_TD_
#define LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_TD_

include "mlir/Pass/PassBase.td"

def CheddarToEmitC : Pass<"cheddar-to-emitc"> {
let summary = "Lower the cheddar dialect to EmitC.";

let description = [{
Translates each cheddar op into an out-parameter-style `Context`/
`UserInterface` method call expressed in the EmitC dialect.
`mlir-translate --mlir-to-cpp` renders the resulting IR as host-side
C++ against the CHEDDAR library API.

Bufferized loop kernels are handled in the same conversion: the SCF and
Arith EmitC patterns are run alongside the cheddar patterns so that the
shared type converter sees cheddar types (e.g. an `scf.for` carrying an
`!cheddar.*` iter_arg lowers to a move-assigned `emitc.variable`, and
loop-index values feed `memref` subscripts without a stranded
`index -> emitc.size_t` cast).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a doctest


(* example filepath=tests/Conversions/CheddarToEmitC/doctest.mlir *)
}];

let dependentDialects = [
"::mlir::arith::ArithDialect",
"::mlir::emitc::EmitCDialect",
"::mlir::func::FuncDialect",
"::mlir::memref::MemRefDialect",
"::mlir::scf::SCFDialect",
];
}

#endif // LIB_CONVERSIONS_CHEDDARTOEMITC_CHEDDARTOEMITC_TD_
1 change: 1 addition & 0 deletions lib/Dialect/Cheddar/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:TransformUtils",
],
)

Expand Down
40 changes: 40 additions & 0 deletions lib/Dialect/Cheddar/IR/CheddarDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/InliningUtils.h" // from @llvm-project

// NOLINTNEXTLINE(misc-include-cleaner): Required to define CheddarOps

Expand All @@ -22,6 +24,42 @@ namespace mlir {
namespace heir {
namespace cheddar {

namespace {
// Cheddar setup/keygen ops interact with stateful handles and (secret) key
// material rather than being pure value computations: create_context and
// create_user_interface mint distinct objects, prepare_rot_key generates a
// rotation key as a side effect, and encrypt/decrypt touch the user interface.
// Duplicating any of these into multiple call sites would create divergent
// contexts or redundantly (and observably) regenerate keys, so they must not be
// *cloned*. The pure ciphertext-algebra ops have value semantics and are always
// safe to clone.
bool isStatefulHandleOp(Operation* op) {
return isa<CreateContextOp, CreateUserInterfaceOp, PrepareRotKeyOp, EncryptOp,
DecryptOp>(op);
}

// Lets the inliner fold client functions that contain cheddar ops (e.g. the
// per-`func` preprocessing/compute decomposition back into a combined entry
// point). Without an interface the inliner treats every cheddar op as
// illegal-to-inline and blocks all such inlining.
struct CheddarInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;

// An op may be inlined as long as we are not duplicating a stateful op:
// moving (wouldBeCloned == false) preserves single execution and order and is
// always fine; cloning is only safe for the pure ciphertext-algebra ops.
bool isLegalToInline(Operation* op, Region*, bool wouldBeCloned,
IRMapping&) const final {
return !wouldBeCloned || !isStatefulHandleOp(op);
}
// Cheddar ops carry no nested regions or control flow, so a callee body is
// structurally inlinable; per-op cloning safety is enforced above.
bool isLegalToInline(Region*, Region*, bool, IRMapping&) const final {
return true;
}
};
} // namespace

void CheddarDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
Expand All @@ -32,6 +70,8 @@ void CheddarDialect::initialize() {
#define GET_OP_LIST
#include "lib/Dialect/Cheddar/IR/CheddarOps.cpp.inc"
>();

addInterfaces<CheddarInlinerInterface>();
}

} // namespace cheddar
Expand Down
71 changes: 67 additions & 4 deletions lib/Dialect/Cheddar/IR/CheddarOps.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "lib/Dialect/Cheddar/IR/CheddarOps.h"

#include <algorithm>

#include "lib/Dialect/Cheddar/IR/CheddarTypes.h"
#include "lib/Utils/RotationUtils.h"
#include "lib/Utils/Utils.h"
Expand Down Expand Up @@ -27,11 +29,15 @@ ::llvm::SmallVector<::mlir::OpFoldResult> HRotAddOp::getRotationIndices() {

::llvm::SmallVector<::mlir::OpFoldResult>
LinearTransformOp::getRotationIndices() {
auto diagonalsType = cast<RankedTensorType>(getDiagonals().getType());
// diagonals is TensorOrMemRef and is bufferized to a memref downstream, so
// match on ShapedType rather than RankedTensorType (which would crash post
// bufferization). The slot count is the row width (second dim).
auto diagonalsType = cast<ShapedType>(getDiagonals().getType());
int64_t slots = diagonalsType.getShape()[1];
int64_t logBSGS = getLogBabyStepGiantStepRatio().getInt();
auto rotations = lintransRotationIndices(
getDiagonalIndicesAttr().asArrayRef(), slots, logBSGS);
// Derive the key set from the op's baby-step count, the same value the
// emitter hands to CHEDDAR, so key generation and evaluation agree.
auto rotations = lintransRotationIndicesWithBabyStep(
getDiagonalIndicesAttr().asArrayRef(), slots, getBs().getInt());
SmallVector<OpFoldResult> result;
result.reserve(rotations.size());
auto* mlirCtx = (*this)->getContext();
Expand All @@ -41,6 +47,63 @@ LinearTransformOp::getRotationIndices() {
return result;
}

LogicalResult LinearTransformOp::verify() {
// `diagonals` is the matrix as a set of non-zero diagonals: one row per
// diagonal, each row `slots` wide. getRotationIndices() and the emitter both
// index getShape()[1], so a non-2D operand must be rejected here rather than
// crashing them.
auto diagonalsType = cast<ShapedType>(getDiagonals().getType());
if (diagonalsType.getRank() != 2)
return emitOpError(
"expected `diagonals` to be 2D (one row per diagonal), but got "
"rank ")
<< diagonalsType.getRank();

auto indices = getDiagonalIndicesAttr().asArrayRef();
int64_t numRows = diagonalsType.getShape()[0];
if (numRows != static_cast<int64_t>(indices.size()))
return emitOpError("expected one `diagonal_indices` entry per `diagonals` "
"row, but got ")
<< indices.size() << " indices for " << numRows << " rows";

int64_t bs = getBs().getInt();
int64_t gs = getGs().getInt();
if (bs < 1 || gs < 1)
return emitOpError("expected `bs` and `gs` to be >= 1, but got bs=")
<< bs << " gs=" << gs;

// The BSGS decomposition only reaches diagonals within the `bs * gs` grid;
// anything past it would be silently dropped. (Indices are non-negative,
// with wrap-around diagonals encoded as `slot - k`.)
if (!indices.empty()) {
int32_t maxIdx = *std::max_element(indices.begin(), indices.end());
if (bs * gs <= maxIdx)
return emitOpError("BSGS grid `bs * gs` (")
<< (bs * gs) << ") must exceed the largest diagonal index ("
<< maxIdx << ") so the decomposition covers every diagonal";
}
return success();
}

LogicalResult EvalPolyOp::verify() {
if (getCoefficients().empty())
return emitOpError("expected a non-empty `coefficients` array");

int64_t level = getLevel().getInt();
int64_t outputLevel = getOutputLevel().getInt();
if (level < 0 || outputLevel < 0)
return emitOpError(
"expected non-negative `level`/`outputLevel`, but got level=")
<< level << " outputLevel=" << outputLevel;
// EvalPoly consumes multiplicative depth, so the result lands at or below the
// input level -- it can never raise it.
if (outputLevel > level)
return emitOpError("expected `outputLevel` (")
<< outputLevel << ") <= `level` (" << level
<< "): the polynomial consumes depth, it cannot raise the level";
return success();
}

} // namespace cheddar
} // namespace heir
} // namespace mlir
Loading
Loading