Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ struct ConvertConstant : public OpConversionPattern<ConstantOp> {
LogicalResult matchAndRewrite(
ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (isa<RNSPolynomialAttr>(op.getValue())) {
op->emitWarning(
"Native lowering for RNSPolynomialAttr is not implemented yet");
// TODO(#97): Implement native RNS constant lowering.
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto res = getCommonConversionInfo(op, typeConverter);
if (failed(res))
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Polynomial/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ cc_library(
"@heir//lib/Utils/Polynomial",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BytecodeOpInterface",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
],
)
Expand Down
50 changes: 45 additions & 5 deletions lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"

#include <cassert>
#include <cstdint>
#include <functional>
#include <string>

#include "lib/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "lib/Dialect/RNS/IR/RNSTypes.h"
#include "lib/Utils/Polynomial/Polynomial.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project
Expand Down Expand Up @@ -36,6 +40,8 @@ void FloatPolynomialAttr::print(AsmPrinter& p) const {
template <typename MonomialType>
using ParseCoefficientFn = std::function<OptionalParseResult(MonomialType&)>;

constexpr unsigned attrAPIntBitWidth = 64;

/// Try to parse a monomial. If successful, populate the fields of the outparam
/// `monomial` with the results, and the `variable` outparam with the parsed
/// variable name. Sets shouldParseMore to true if the monomial is followed by
Expand All @@ -58,7 +64,7 @@ ParseResult parseMonomial(
if (!parsedCoeffResult.has_value()) {
return failure();
}
monomial.setExponent(APInt(apintBitWidth, 0));
monomial.setExponent(APInt(attrAPIntBitWidth, 0));
isConstantTerm = true;
shouldParseMore = true;
return success();
Expand All @@ -72,7 +78,7 @@ ParseResult parseMonomial(
return failure();
}

monomial.setExponent(APInt(apintBitWidth, 0));
monomial.setExponent(APInt(attrAPIntBitWidth, 0));
isConstantTerm = true;
return success();
}
Expand All @@ -87,7 +93,7 @@ ParseResult parseMonomial(
}

// If there's a **, then the integer exponent is required.
APInt parsedExponent(apintBitWidth, 0);
APInt parsedExponent(attrAPIntBitWidth, 0);
if (failed(parser.parseInteger(parsedExponent))) {
parser.emitError(parser.getCurrentLocation(),
"found invalid integer exponent");
Expand All @@ -96,7 +102,7 @@ ParseResult parseMonomial(

monomial.setExponent(parsedExponent);
} else {
monomial.setExponent(APInt(apintBitWidth, 1));
monomial.setExponent(APInt(attrAPIntBitWidth, 1));
}

if (succeeded(parser.parseOptionalPlus())) {
Expand Down Expand Up @@ -160,7 +166,7 @@ Attribute IntPolynomialAttr::parse(AsmParser& parser, Type type) {
if (failed(parsePolynomialAttr<IntMonomial>(
parser, monomials, variables,
[&](IntMonomial& monomial) -> OptionalParseResult {
APInt parsedCoeff(apintBitWidth, 1);
APInt parsedCoeff(attrAPIntBitWidth, 1);
OptionalParseResult result =
parser.parseOptionalInteger(parsedCoeff);
monomial.setCoefficient(parsedCoeff);
Expand Down Expand Up @@ -251,6 +257,40 @@ ::mlir::OpAsmDialectInterface::AliasResult RingAttr::getAlias(
return AliasResult::FinalAlias;
}

RNSPolynomial RNSPolynomialAttr::getPolynomial() const {
auto polyType = dyn_cast<PolynomialType>(getType());
assert(polyType && "expected PolynomialType");
auto rnsType =
dyn_cast<rns::RNSType>(polyType.getRing().getCoefficientType());
assert(rnsType && "expected RNSType");
auto basisTypes = rnsType.getBasisTypes();

SmallVector<uint64_t> moduli;
moduli.reserve(basisTypes.size());
for (Type b : basisTypes) {
auto modType = cast<mod_arith::ModArithType>(b);
moduli.push_back(modType.getModulus().getValue().getZExtValue());
}

[[maybe_unused]] auto shape = getCoefficients().getType().getShape();
assert(shape[0] == static_cast<int64_t>(moduli.size()) &&
"number of limbs must match RNS basis size");

auto valuesRange = getCoefficients().getValues<APInt>();
SmallVector<uint64_t> flatValues;
flatValues.reserve(valuesRange.size());
for (const APInt& v : valuesRange) {
flatValues.push_back(v.getZExtValue());
}

RNSPolynomial::Representation representation =
getForm() == Form::COEFF ? RNSPolynomial::Representation::Coefficient
: RNSPolynomial::Representation::NTT;

return RNSPolynomial(std::move(flatValues), std::move(moduli),
representation);
}

} // namespace polynomial
} // namespace heir
} // namespace mlir
4 changes: 2 additions & 2 deletions lib/Dialect/Polynomial/IR/PolynomialAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

#include "lib/Dialect/ModArith/IR/ModArithAttributes.h"
#include "lib/Dialect/Polynomial/IR/PolynomialDialect.h"
#include "lib/Dialect/Polynomial/IR/PolynomialEnums.h.inc"
#include "lib/Dialect/RNS/IR/RNSAttributes.h"
#include "lib/Utils/Polynomial/Polynomial.h"

#include "lib/Utils/Polynomial/RNSPolynomial.h"
#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h.inc"
#include "lib/Dialect/Polynomial/IR/PolynomialEnums.h.inc"

#endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_
20 changes: 19 additions & 1 deletion lib/Dialect/Polynomial/IR/PolynomialAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,24 @@ def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_ro
);
let assemblyFormat = "`<` struct(params) `>`";
}

def Polynomial_RNSPolynomialAttr : Polynomial_Attr<"RNSPolynomial", "rns_polynomial", [TypedAttrInterface]> {
let summary = "A polynomial represented in RNS form";
let description = [{
A polynomial attribute represented as an array of RNS limbs.
}];
let parameters = (ins
"::mlir::DenseIntElementsAttr":$coefficients,
"::mlir::Type":$type,
DefaultValuedParameter<
"::mlir::heir::polynomial::Form",
"::mlir::heir::polynomial::Form::COEFF"
>:$form
);
let assemblyFormat = "`<` $coefficients (`,` $form^)? `>` `:` $type";
let extraClassDeclaration = [{
using ValueType = ::mlir::DenseIntElementsAttr;
::mlir::heir::polynomial::RNSPolynomial getPolynomial() const;
}];
}

#endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_TD_
24 changes: 19 additions & 5 deletions lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@
#include "lib/Utils/Polynomial/Polynomial.h"
#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

Expand Down Expand Up @@ -45,3 +49,13 @@ void PolynomialDialect::initialize() {
#include "lib/Dialect/Polynomial/IR/PolynomialOps.cpp.inc"
>();
}

Operation* PolynomialDialect::materializeConstant(OpBuilder& builder,
Attribute value, Type type,
Location loc) {
if (llvm::isa<TypedIntPolynomialAttr, TypedFloatPolynomialAttr,
TypedChebyshevPolynomialAttr, RNSPolynomialAttr>(value)) {
return ConstantOp::create(builder, loc, type, value);
}
return nullptr;
}
1 change: 1 addition & 0 deletions lib/Dialect/Polynomial/IR/PolynomialDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def Polynomial_Dialect : Dialect {

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
}

#endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_TD_
9 changes: 9 additions & 0 deletions lib/Dialect/Polynomial/IR/PolynomialOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,9 @@ LogicalResult EvalOp::verify() {
[&](TypedFloatPolynomialAttr floatAttr) {
return floatAttr.getValue().getPolynomial().getTerms().empty();
})
.Case<RNSPolynomialAttr>([&](RNSPolynomialAttr rnsAttr) {
return rnsAttr.getCoefficients().empty();
})
.Default([&](Attribute) { return false; });
if (empty) {
return emitError() << "Empty polynomials are not supported for eval op";
Expand Down Expand Up @@ -456,6 +459,10 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) {
result.addTypes(floatAttr.getType());
return success();
}
if (auto rnsAttr = dyn_cast<RNSPolynomialAttr>(attr)) {
result.addTypes(rnsAttr.getType());
return success();
}
return parser.emitError(loc, "expected a typed polynomial attribute");
}

Expand All @@ -471,6 +478,8 @@ void ConstantOp::print(OpAsmPrinter& p) {
} else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
p << "float";
floatPoly.getValue().print(p);
} else if (auto rnsPoly = dyn_cast<RNSPolynomialAttr>(getValue())) {
p.printAttribute(rnsPoly);
} else {
assert(false && "unexpected attribute type");
}
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Polynomial/IR/PolynomialOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,11 @@ def Polynomial_ModSwitchOp : Polynomial_Op<"mod_switch", [Pure, SameOperandsAndR
let hasVerifier = 1;
}


def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
Polynomial_TypedFloatPolynomialAttr,
Polynomial_TypedIntPolynomialAttr,
Polynomial_TypedChebyshevPolynomialAttr
Polynomial_TypedChebyshevPolynomialAttr,
Polynomial_RNSPolynomialAttr
]>;

def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure]> {
Expand Down
19 changes: 17 additions & 2 deletions lib/Utils/Polynomial/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,19 @@ package(

cc_library(
name = "Polynomial",
srcs = ["Polynomial.cpp"],
hdrs = ["Polynomial.h"],
srcs = [
"Polynomial.cpp",
"RNSPolynomial.cpp",
],
hdrs = [
"Polynomial.h",
"RNSPolynomial.h",
],
deps = [
":NTT",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:Dialect",
"@heir//lib/Utils:MathUtils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand All @@ -26,6 +36,11 @@ cc_test(
deps = [
":Polynomial",
"@googletest//:gtest_main",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:Dialect",
"@heir//lib/Dialect/RNS/IR:TypeInterfaces",
"@heir//lib/Utils:MathUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
Expand Down
Loading
Loading