diff --git a/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp b/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp index 272053f5f4..99a78b1111 100644 --- a/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp +++ b/lib/Dialect/Polynomial/Conversions/PolynomialToModArith/PolynomialToModArith.cpp @@ -291,6 +291,12 @@ struct ConvertConstant : public OpConversionPattern { LogicalResult matchAndRewrite( ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (isa(op.getValue())) { + op->emitWarning( + "Native lowering for RNSPolynomialAttr is not implemented yet"); + // TODO(#97): Implement native RNS constant lowering. + return failure(); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto res = getCommonConversionInfo(op, typeConverter); if (failed(res)) diff --git a/lib/Dialect/Polynomial/IR/BUILD b/lib/Dialect/Polynomial/IR/BUILD index 05afcb55eb..45738e0331 100644 --- a/lib/Dialect/Polynomial/IR/BUILD +++ b/lib/Dialect/Polynomial/IR/BUILD @@ -23,8 +23,10 @@ cc_library( "@heir//lib/Utils/Polynomial", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", ], ) diff --git a/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp index 704ab8829e..22359a7744 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp +++ b/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -1,8 +1,12 @@ #include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include +#include #include #include +#include "lib/Dialect/Polynomial/IR/PolynomialTypes.h" +#include "lib/Dialect/RNS/IR/RNSTypes.h" #include "lib/Utils/Polynomial/Polynomial.h" #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project @@ -36,6 +40,8 @@ void FloatPolynomialAttr::print(AsmPrinter& p) const { template using ParseCoefficientFn = std::function; +constexpr unsigned attrAPIntBitWidth = 64; + /// Try to parse a monomial. If successful, populate the fields of the outparam /// `monomial` with the results, and the `variable` outparam with the parsed /// variable name. Sets shouldParseMore to true if the monomial is followed by @@ -58,7 +64,7 @@ ParseResult parseMonomial( if (!parsedCoeffResult.has_value()) { return failure(); } - monomial.setExponent(APInt(apintBitWidth, 0)); + monomial.setExponent(APInt(attrAPIntBitWidth, 0)); isConstantTerm = true; shouldParseMore = true; return success(); @@ -72,7 +78,7 @@ ParseResult parseMonomial( return failure(); } - monomial.setExponent(APInt(apintBitWidth, 0)); + monomial.setExponent(APInt(attrAPIntBitWidth, 0)); isConstantTerm = true; return success(); } @@ -87,7 +93,7 @@ ParseResult parseMonomial( } // If there's a **, then the integer exponent is required. - APInt parsedExponent(apintBitWidth, 0); + APInt parsedExponent(attrAPIntBitWidth, 0); if (failed(parser.parseInteger(parsedExponent))) { parser.emitError(parser.getCurrentLocation(), "found invalid integer exponent"); @@ -96,7 +102,7 @@ ParseResult parseMonomial( monomial.setExponent(parsedExponent); } else { - monomial.setExponent(APInt(apintBitWidth, 1)); + monomial.setExponent(APInt(attrAPIntBitWidth, 1)); } if (succeeded(parser.parseOptionalPlus())) { @@ -160,7 +166,7 @@ Attribute IntPolynomialAttr::parse(AsmParser& parser, Type type) { if (failed(parsePolynomialAttr( parser, monomials, variables, [&](IntMonomial& monomial) -> OptionalParseResult { - APInt parsedCoeff(apintBitWidth, 1); + APInt parsedCoeff(attrAPIntBitWidth, 1); OptionalParseResult result = parser.parseOptionalInteger(parsedCoeff); monomial.setCoefficient(parsedCoeff); @@ -251,6 +257,40 @@ ::mlir::OpAsmDialectInterface::AliasResult RingAttr::getAlias( return AliasResult::FinalAlias; } +RNSPolynomial RNSPolynomialAttr::getPolynomial() const { + auto polyType = dyn_cast(getType()); + assert(polyType && "expected PolynomialType"); + auto rnsType = + dyn_cast(polyType.getRing().getCoefficientType()); + assert(rnsType && "expected RNSType"); + auto basisTypes = rnsType.getBasisTypes(); + + SmallVector moduli; + moduli.reserve(basisTypes.size()); + for (Type b : basisTypes) { + auto modType = cast(b); + moduli.push_back(modType.getModulus().getValue().getZExtValue()); + } + + [[maybe_unused]] auto shape = getCoefficients().getType().getShape(); + assert(shape[0] == static_cast(moduli.size()) && + "number of limbs must match RNS basis size"); + + auto valuesRange = getCoefficients().getValues(); + SmallVector flatValues; + flatValues.reserve(valuesRange.size()); + for (const APInt& v : valuesRange) { + flatValues.push_back(v.getZExtValue()); + } + + RNSPolynomial::Representation representation = + getForm() == Form::COEFF ? RNSPolynomial::Representation::Coefficient + : RNSPolynomial::Representation::NTT; + + return RNSPolynomial(std::move(flatValues), std::move(moduli), + representation); +} + } // namespace polynomial } // namespace heir } // namespace mlir diff --git a/lib/Dialect/Polynomial/IR/PolynomialAttributes.h b/lib/Dialect/Polynomial/IR/PolynomialAttributes.h index 3dee317b1e..0955d23dac 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialAttributes.h +++ b/lib/Dialect/Polynomial/IR/PolynomialAttributes.h @@ -3,11 +3,11 @@ #include "lib/Dialect/ModArith/IR/ModArithAttributes.h" #include "lib/Dialect/Polynomial/IR/PolynomialDialect.h" +#include "lib/Dialect/Polynomial/IR/PolynomialEnums.h.inc" #include "lib/Dialect/RNS/IR/RNSAttributes.h" #include "lib/Utils/Polynomial/Polynomial.h" - +#include "lib/Utils/Polynomial/RNSPolynomial.h" #define GET_ATTRDEF_CLASSES #include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h.inc" -#include "lib/Dialect/Polynomial/IR/PolynomialEnums.h.inc" #endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_ diff --git a/lib/Dialect/Polynomial/IR/PolynomialAttributes.td b/lib/Dialect/Polynomial/IR/PolynomialAttributes.td index afc02fd6be..07ccc03426 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/lib/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -260,6 +260,24 @@ def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_ro ); let assemblyFormat = "`<` struct(params) `>`"; } - +def Polynomial_RNSPolynomialAttr : Polynomial_Attr<"RNSPolynomial", "rns_polynomial", [TypedAttrInterface]> { + let summary = "A polynomial represented in RNS form"; + let description = [{ + A polynomial attribute represented as an array of RNS limbs. + }]; + let parameters = (ins + "::mlir::DenseIntElementsAttr":$coefficients, + "::mlir::Type":$type, + DefaultValuedParameter< + "::mlir::heir::polynomial::Form", + "::mlir::heir::polynomial::Form::COEFF" + >:$form + ); + let assemblyFormat = "`<` $coefficients (`,` $form^)? `>` `:` $type"; + let extraClassDeclaration = [{ + using ValueType = ::mlir::DenseIntElementsAttr; + ::mlir::heir::polynomial::RNSPolynomial getPolynomial() const; + }]; +} #endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_TD_ diff --git a/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp index 77ec61fad6..e1d7c57d6e 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp +++ b/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp @@ -7,14 +7,18 @@ #include "lib/Utils/Polynomial/Polynomial.h" #include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project @@ -45,3 +49,13 @@ void PolynomialDialect::initialize() { #include "lib/Dialect/Polynomial/IR/PolynomialOps.cpp.inc" >(); } + +Operation* PolynomialDialect::materializeConstant(OpBuilder& builder, + Attribute value, Type type, + Location loc) { + if (llvm::isa(value)) { + return ConstantOp::create(builder, loc, type, value); + } + return nullptr; +} diff --git a/lib/Dialect/Polynomial/IR/PolynomialDialect.td b/lib/Dialect/Polynomial/IR/PolynomialDialect.td index dff5b50c01..560b0acf4c 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialDialect.td +++ b/lib/Dialect/Polynomial/IR/PolynomialDialect.td @@ -43,6 +43,7 @@ def Polynomial_Dialect : Dialect { let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; + let hasConstantMaterializer = 1; } #endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALDIALECT_TD_ diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index a893790a11..01ae8f66ac 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -389,6 +389,9 @@ LogicalResult EvalOp::verify() { [&](TypedFloatPolynomialAttr floatAttr) { return floatAttr.getValue().getPolynomial().getTerms().empty(); }) + .Case([&](RNSPolynomialAttr rnsAttr) { + return rnsAttr.getCoefficients().empty(); + }) .Default([&](Attribute) { return false; }); if (empty) { return emitError() << "Empty polynomials are not supported for eval op"; @@ -456,6 +459,10 @@ ParseResult ConstantOp::parse(OpAsmParser& parser, OperationState& result) { result.addTypes(floatAttr.getType()); return success(); } + if (auto rnsAttr = dyn_cast(attr)) { + result.addTypes(rnsAttr.getType()); + return success(); + } return parser.emitError(loc, "expected a typed polynomial attribute"); } @@ -471,6 +478,8 @@ void ConstantOp::print(OpAsmPrinter& p) { } else if (auto floatPoly = dyn_cast(getValue())) { p << "float"; floatPoly.getValue().print(p); + } else if (auto rnsPoly = dyn_cast(getValue())) { + p.printAttribute(rnsPoly); } else { assert(false && "unexpected attribute type"); } diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.td b/lib/Dialect/Polynomial/IR/PolynomialOps.td index 3678032de0..a00d493111 100644 --- a/lib/Dialect/Polynomial/IR/PolynomialOps.td +++ b/lib/Dialect/Polynomial/IR/PolynomialOps.td @@ -291,11 +291,11 @@ def Polynomial_ModSwitchOp : Polynomial_Op<"mod_switch", [Pure, SameOperandsAndR let hasVerifier = 1; } - def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[ Polynomial_TypedFloatPolynomialAttr, Polynomial_TypedIntPolynomialAttr, - Polynomial_TypedChebyshevPolynomialAttr + Polynomial_TypedChebyshevPolynomialAttr, + Polynomial_RNSPolynomialAttr ]>; def Polynomial_ConstantOp : Op { diff --git a/lib/Utils/Polynomial/BUILD b/lib/Utils/Polynomial/BUILD index dcefd33784..9fc01b3b66 100644 --- a/lib/Utils/Polynomial/BUILD +++ b/lib/Utils/Polynomial/BUILD @@ -11,9 +11,19 @@ package( cc_library( name = "Polynomial", - srcs = ["Polynomial.cpp"], - hdrs = ["Polynomial.h"], + srcs = [ + "Polynomial.cpp", + "RNSPolynomial.cpp", + ], + hdrs = [ + "Polynomial.h", + "RNSPolynomial.h", + ], deps = [ + ":NTT", + "@heir//lib/Dialect/ModArith/IR:Dialect", + "@heir//lib/Dialect/RNS/IR:Dialect", + "@heir//lib/Utils:MathUtils", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -26,6 +36,11 @@ cc_test( deps = [ ":Polynomial", "@googletest//:gtest_main", + "@heir//lib/Dialect/ModArith/IR:Dialect", + "@heir//lib/Dialect/RNS/IR:Dialect", + "@heir//lib/Dialect/RNS/IR:TypeInterfaces", + "@heir//lib/Utils:MathUtils", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], ) diff --git a/lib/Utils/Polynomial/PolynomialTest.cpp b/lib/Utils/Polynomial/PolynomialTest.cpp index 34fab093dc..4da592a69a 100644 --- a/lib/Utils/Polynomial/PolynomialTest.cpp +++ b/lib/Utils/Polynomial/PolynomialTest.cpp @@ -1,7 +1,18 @@ #include "gtest/gtest.h" // from @googletest +#include "lib/Dialect/ModArith/IR/ModArithAttributes.h" +#include "lib/Dialect/ModArith/IR/ModArithDialect.h" +#include "lib/Dialect/ModArith/IR/ModArithTypes.h" +#include "lib/Dialect/RNS/IR/RNSAttributes.h" +#include "lib/Dialect/RNS/IR/RNSDialect.h" +#include "lib/Dialect/RNS/IR/RNSTypes.h" +#include "lib/Utils/MathUtils.h" #include "lib/Utils/Polynomial/Polynomial.h" -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "lib/Utils/Polynomial/RNSPolynomial.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace heir { @@ -133,6 +144,149 @@ TEST(PolynomialTest, TestComposeSkippingDegree) { EXPECT_EQ(expected, result); } +TEST(RNSPolynomialTest, Arithmetic) { + SmallVector moduli = {17, 13}; + SmallVector coeffs1 = {1, 2, 3, 4, 5, 6}; + SmallVector coeffs2 = {16, 15, 14, 12, 11, 10}; + + RNSPolynomial poly1(coeffs1, moduli); + RNSPolynomial poly2(coeffs2, moduli); + + RNSPolynomial sum = poly1.add(poly2); + SmallVector expectedSum = {0, 0, 0, 3, 3, 3}; + EXPECT_EQ(sum.getData(), llvm::ArrayRef(expectedSum)); + + RNSPolynomial diff = poly1.sub(poly2); + SmallVector expectedDiff = {2, 4, 6, 5, 7, 9}; + EXPECT_EQ(diff.getData(), llvm::ArrayRef(expectedDiff)); +} + +TEST(RNSPolynomialTest, TestRepresentation) { + SmallVector moduli = {17, 13}; + SmallVector coeffs = {1, 2, 3, 4, 5, 6}; + + // Default representation should be Coefficient + RNSPolynomial polyCoeff(coeffs, moduli); + EXPECT_EQ(polyCoeff.getRepresentation(), + RNSPolynomial::Representation::Coefficient); + EXPECT_FALSE(polyCoeff.isNtt()); + + // Explicit NTT representation + RNSPolynomial polyNtt(coeffs, moduli, RNSPolynomial::Representation::NTT); + EXPECT_EQ(polyNtt.getRepresentation(), RNSPolynomial::Representation::NTT); + EXPECT_TRUE(polyNtt.isNtt()); + + // Test that adding/subtracting mismatched representations asserts + EXPECT_DEBUG_DEATH(polyCoeff.add(polyNtt), + "Representations must match for arithmetic"); + EXPECT_DEBUG_DEATH(polyCoeff.sub(polyNtt), + "Representations must match for arithmetic"); +} + +TEST(RNSPolynomialTest, TestConversions) { + SmallVector moduli = {17, 41}; + SmallVector coeffs = {1, 2, 3, 4, 5, 6, 7, 8}; + + RNSPolynomial poly(coeffs, moduli); + + // Test round-trip toNtt -> toCoefficient + RNSPolynomial ntt = poly.toNtt(); + EXPECT_TRUE(ntt.isNtt()); + + RNSPolynomial roundtrip = ntt.toCoefficient(); + EXPECT_FALSE(roundtrip.isNtt()); + EXPECT_EQ(roundtrip, poly); +} + +TEST(RNSPolynomialTest, TestMul) { + SmallVector moduli = {17, 41}; + SmallVector coeffs1 = {1, 2, 0, 0, 3, 4, 0, 0}; + SmallVector coeffs2 = {5, 6, 0, 0, 7, 8, 0, 0}; + + RNSPolynomial poly1(coeffs1, moduli); + RNSPolynomial poly2(coeffs2, moduli); + + // Test multiplication in Coefficient form (uses NTT under the hood) + RNSPolynomial prodCoeff = poly1.mul(poly2); + EXPECT_FALSE(prodCoeff.isNtt()); + SmallVector expectedProd = {5, 16, 12, 0, 21, 11, 32, 0}; + EXPECT_EQ(prodCoeff.getData(), llvm::ArrayRef(expectedProd)); + + // Test multiplication in NTT form + RNSPolynomial ntt1 = poly1.toNtt(); + RNSPolynomial ntt2 = poly2.toNtt(); + RNSPolynomial prodNtt = ntt1.mul(ntt2); + EXPECT_TRUE(prodNtt.isNtt()); + EXPECT_EQ(prodNtt.toCoefficient(), prodCoeff); +} + +TEST(RNSPolynomialTest, TestPrecomputedRoots) { + mlir::MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + SmallVector moduli = {17, 41}; + SmallVector coeffs = {1, 2, 3, 4, 5, 6, 7, 8}; + + RNSPolynomial poly(coeffs, moduli); + + auto i64Type = mlir::IntegerType::get(&context, 64); + auto mod17Type = mlir::heir::mod_arith::ModArithType::get( + &context, mlir::IntegerAttr::get(i64Type, 17)); + auto mod41Type = mlir::heir::mod_arith::ModArithType::get( + &context, mlir::IntegerAttr::get(i64Type, 41)); + + auto rnsType = + mlir::heir::rns::RNSType::get(&context, {mod17Type, mod41Type}); + + // 1. Test with roots returned by findPrimitive2nthRoot (should match + // on-the-fly) + auto root16_17 = mlir::heir::findPrimitive2nthRoot(mlir::APInt(64, 17), 4); + auto root16_41 = mlir::heir::findPrimitive2nthRoot(mlir::APInt(64, 41), 4); + ASSERT_TRUE(root16_17.has_value()); + ASSERT_TRUE(root16_41.has_value()); + + auto root17Attr = mlir::heir::mod_arith::ModArithAttr::get( + &context, mod17Type, + mlir::IntegerAttr::get(i64Type, root16_17->getZExtValue())); + auto root41Attr = mlir::heir::mod_arith::ModArithAttr::get( + &context, mod41Type, + mlir::IntegerAttr::get(i64Type, root16_41->getZExtValue())); + + auto rnsAttr = + mlir::heir::rns::RNSAttr::get(rnsType, {root17Attr, root41Attr}); + + // Test toNtt with precomputed roots (matching on-the-fly) + RNSPolynomial ntt = poly.toNtt(rnsAttr); + EXPECT_TRUE(ntt.isNtt()); + + // Compare with on-the-fly computation + RNSPolynomial nttOnTheFly = poly.toNtt(); + EXPECT_EQ(ntt, nttOnTheFly); + + // Test toCoefficient with precomputed roots + RNSPolynomial roundtrip = ntt.toCoefficient(rnsAttr); + EXPECT_FALSE(roundtrip.isNtt()); + EXPECT_EQ(roundtrip, poly); + + // 2. Test with DIFFERENT valid roots (should round-trip, but might not match + // on-the-fly) We know 9 is primitive 8-th root mod 17, and 3 is primitive + // 8-th root mod 41. + auto diffRoot17Attr = mlir::heir::mod_arith::ModArithAttr::get( + &context, mod17Type, mlir::IntegerAttr::get(i64Type, 9)); + auto diffRoot41Attr = mlir::heir::mod_arith::ModArithAttr::get( + &context, mod41Type, mlir::IntegerAttr::get(i64Type, 3)); + auto diffRnsAttr = + mlir::heir::rns::RNSAttr::get(rnsType, {diffRoot17Attr, diffRoot41Attr}); + + RNSPolynomial nttDiff = poly.toNtt(diffRnsAttr); + EXPECT_TRUE(nttDiff.isNtt()); + + RNSPolynomial roundtripDiff = nttDiff.toCoefficient(diffRnsAttr); + EXPECT_FALSE(roundtripDiff.isNtt()); + EXPECT_EQ(roundtripDiff, poly); +} + } // namespace } // namespace polynomial } // namespace heir diff --git a/lib/Utils/Polynomial/RNSPolynomial.cpp b/lib/Utils/Polynomial/RNSPolynomial.cpp new file mode 100644 index 0000000000..7f41ffeb28 --- /dev/null +++ b/lib/Utils/Polynomial/RNSPolynomial.cpp @@ -0,0 +1,197 @@ +#include "lib/Utils/Polynomial/RNSPolynomial.h" + +#include +#include +#include + +#include "lib/Dialect/ModArith/IR/ModArithAttributes.h" +#include "lib/Dialect/RNS/IR/RNSAttributes.h" +#include "lib/Utils/MathUtils.h" +#include "lib/Utils/Polynomial/NTT.h" +#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace polynomial { + +RNSPolynomial::RNSPolynomial(llvm::SmallVector data, + llvm::SmallVector moduli, + Representation representation) + : data(std::move(data)), + moduli(std::move(moduli)), + representation(representation) { + assert(!this->moduli.empty() && "moduli cannot be empty"); + assert(this->data.size() % this->moduli.size() == 0 && + "numLimbs must divide data size"); + numCoeffs = this->data.size() / this->moduli.size(); +} + +RNSPolynomial RNSPolynomial::add(const RNSPolynomial& other) const { + assert(representation == other.representation && + "Representations must match for arithmetic"); + assert(moduli == other.moduli && "Moduli must match for addition"); + assert(numCoeffs == other.numCoeffs && "Number of coefficients must match"); + + llvm::SmallVector resultData; + resultData.reserve(data.size()); + + for (size_t limbIdx = 0; limbIdx < getNumLimbs(); ++limbIdx) { + uint64_t modulus = moduli[limbIdx]; + for (size_t coeffIdx = 0; coeffIdx < numCoeffs; ++coeffIdx) { + uint64_t a = getElement(limbIdx, coeffIdx); + uint64_t b = other.getElement(limbIdx, coeffIdx); + uint64_t sum = a + b; + if (sum >= modulus) { + sum -= modulus; + } + resultData.push_back(sum); + } + } + + return RNSPolynomial(std::move(resultData), moduli, representation); +} + +RNSPolynomial RNSPolynomial::sub(const RNSPolynomial& other) const { + assert(representation == other.representation && + "Representations must match for arithmetic"); + assert(moduli == other.moduli && "Moduli must match for subtraction"); + assert(numCoeffs == other.numCoeffs && "Number of coefficients must match"); + + llvm::SmallVector resultData; + resultData.reserve(data.size()); + + for (size_t limbIdx = 0; limbIdx < getNumLimbs(); ++limbIdx) { + uint64_t modulus = moduli[limbIdx]; + for (size_t coeffIdx = 0; coeffIdx < numCoeffs; ++coeffIdx) { + uint64_t a = getElement(limbIdx, coeffIdx); + uint64_t b = other.getElement(limbIdx, coeffIdx); + uint64_t diff = a; + if (diff < b) { + diff += modulus; + } + diff -= b; + resultData.push_back(diff); + } + } + + return RNSPolynomial(std::move(resultData), moduli, representation); +} + +RNSPolynomial RNSPolynomial::mul(const RNSPolynomial& other) const { + assert(moduli == other.moduli && "Moduli must match for multiplication"); + assert(numCoeffs == other.numCoeffs && "Number of coefficients must match"); + + if (representation == Representation::NTT && + other.representation == Representation::NTT) { + llvm::SmallVector resultData; + resultData.reserve(data.size()); + + for (size_t limbIdx = 0; limbIdx < getNumLimbs(); ++limbIdx) { + uint64_t modulus = moduli[limbIdx]; + for (size_t coeffIdx = 0; coeffIdx < numCoeffs; ++coeffIdx) { + uint64_t a = getElement(limbIdx, coeffIdx); + uint64_t b = other.getElement(limbIdx, coeffIdx); + unsigned __int128 prod = (unsigned __int128)a * b; + uint64_t res = prod % modulus; + resultData.push_back(res); + } + } + return RNSPolynomial(std::move(resultData), moduli, representation); + } + + if (representation == Representation::Coefficient && + other.representation == Representation::Coefficient) { + return toNtt().mul(other.toNtt()).toCoefficient(); + } + + assert(false && "Mismatched representations or unsupported conversion"); + return RNSPolynomial(); +} + +RNSPolynomial RNSPolynomial::toNtt(rns::RNSAttr rootAttr) const { + assert(representation == Representation::Coefficient && + "Already in NTT representation"); + + llvm::SmallVector resultData; + resultData.reserve(data.size()); + + if (rootAttr) { + assert(rootAttr.getValues().size() == getNumLimbs() && + "mismatch in number of limbs for root attribute"); + } + + for (size_t limbIdx = 0; limbIdx < getNumLimbs(); ++limbIdx) { + uint64_t modulus = moduli[limbIdx]; + uint64_t rootOfUnity; + + if (rootAttr) { + auto rootMA = + llvm::cast(rootAttr.getValues()[limbIdx]); + rootOfUnity = rootMA.getValue().getValue().getZExtValue(); + } else { + llvm::APInt qAp(64, modulus); + std::optional rootOpt = + findPrimitive2nthRoot(qAp, numCoeffs); + assert(rootOpt.has_value() && "Primitive 2n-th root of unity not found"); + rootOfUnity = rootOpt->getZExtValue(); + } + + std::vector limbCoeffs; + limbCoeffs.reserve(numCoeffs); + for (size_t coeffIdx = 0; coeffIdx < numCoeffs; ++coeffIdx) { + limbCoeffs.push_back(getElement(limbIdx, coeffIdx)); + } + + nttInPlace(limbCoeffs, modulus, rootOfUnity); + resultData.insert(resultData.end(), limbCoeffs.begin(), limbCoeffs.end()); + } + + return RNSPolynomial(std::move(resultData), moduli, Representation::NTT); +} + +RNSPolynomial RNSPolynomial::toCoefficient(rns::RNSAttr rootAttr) const { + assert(representation == Representation::NTT && + "Already in Coefficient representation"); + + llvm::SmallVector resultData; + resultData.reserve(data.size()); + + if (rootAttr) { + assert(rootAttr.getValues().size() == getNumLimbs() && + "mismatch in number of limbs for root attribute"); + } + + for (size_t limbIdx = 0; limbIdx < getNumLimbs(); ++limbIdx) { + uint64_t modulus = moduli[limbIdx]; + uint64_t rootOfUnity; + + if (rootAttr) { + auto rootMA = + llvm::cast(rootAttr.getValues()[limbIdx]); + rootOfUnity = rootMA.getValue().getValue().getZExtValue(); + } else { + llvm::APInt qAp(64, modulus); + std::optional rootOpt = + findPrimitive2nthRoot(qAp, numCoeffs); + assert(rootOpt.has_value() && "Primitive 2n-th root of unity not found"); + rootOfUnity = rootOpt->getZExtValue(); + } + + std::vector limbCoeffs; + limbCoeffs.reserve(numCoeffs); + for (size_t coeffIdx = 0; coeffIdx < numCoeffs; ++coeffIdx) { + limbCoeffs.push_back(getElement(limbIdx, coeffIdx)); + } + + inttInPlace(limbCoeffs, modulus, rootOfUnity); + resultData.insert(resultData.end(), limbCoeffs.begin(), limbCoeffs.end()); + } + + return RNSPolynomial(std::move(resultData), moduli, + Representation::Coefficient); +} + +} // namespace polynomial +} // namespace heir +} // namespace mlir diff --git a/lib/Utils/Polynomial/RNSPolynomial.h b/lib/Utils/Polynomial/RNSPolynomial.h new file mode 100644 index 0000000000..57580676a0 --- /dev/null +++ b/lib/Utils/Polynomial/RNSPolynomial.h @@ -0,0 +1,107 @@ +#ifndef LIB_UTILS_POLYNOMIAL_RNSPOLYNOMIAL_H_ +#define LIB_UTILS_POLYNOMIAL_RNSPOLYNOMIAL_H_ + +#include "lib/Dialect/RNS/IR/RNSAttributes.h" +#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project +#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace polynomial { + +/// A class representing a polynomial in Residue Number System (RNS) form. +/// +/// An RNS polynomial is represented by its residues modulo a set of coprime +/// moduli (limbs). For a polynomial of degree < N and an RNS basis of k moduli +/// [q_0, ..., q_{k-1}], the polynomial is stored as k independent polynomials +/// (limbs), where the i-th limb is a polynomial with coefficients modulo q_i. +/// +/// The coefficients of all limbs are stored in a single flat array (`data`) +/// in limb-major order: +/// [limb_0_coeff_0, limb_0_coeff_1, ..., limb_0_coeff_{N-1}, +/// limb_1_coeff_0, ..., limb_{k-1}_coeff_{N-1}] +/// +/// An RNS polynomial can be in either Coefficient representation (storing the +/// polynomial coefficients directly) or NTT representation (storing the +/// evaluations/slots of the polynomial at the roots of unity). Arithmetic +/// operations (add, sub) are only valid between polynomials in the same +/// representation. +class RNSPolynomial { + public: + /// The representation form of the RNS polynomial. + enum class Representation { Coefficient, NTT }; + + RNSPolynomial() = default; + RNSPolynomial(llvm::SmallVector data, + llvm::SmallVector moduli, + Representation representation = Representation::Coefficient); + + /// Returns the flat data array. + llvm::ArrayRef getData() const { return data; } + + /// Returns the moduli (RNS basis). + llvm::ArrayRef getModuli() const { return moduli; } + + /// Returns the number of limbs (moduli) in the RNS basis. + size_t getNumLimbs() const { return moduli.size(); } + + /// Returns the number of coefficients (or slots) per limb. + unsigned getNumCoeffs() const { return numCoeffs; } + + /// Returns the representation form of the polynomial. + Representation getRepresentation() const { return representation; } + + /// Returns true if the polynomial is in NTT representation. + bool isNtt() const { return representation == Representation::NTT; } + + /// Returns the element (coefficient or evaluation slot) at the given limb + /// and coefficient index. + uint64_t getElement(size_t limbIdx, size_t coeffIdx) const { + return data[limbIdx * numCoeffs + coeffIdx]; + } + + /// Performs modular addition limb-wise. + RNSPolynomial add(const RNSPolynomial& other) const; + + /// Performs modular subtraction limb-wise. + RNSPolynomial sub(const RNSPolynomial& other) const; + + /// Performs modular multiplication limb-wise. + RNSPolynomial mul(const RNSPolynomial& other) const; + + /// Convert the polynomial to NTT representation. + RNSPolynomial toNtt(rns::RNSAttr rootAttr = nullptr) const; + + /// Convert the polynomial to Coefficient representation. + RNSPolynomial toCoefficient(rns::RNSAttr rootAttr = nullptr) const; + + bool operator==(const RNSPolynomial& other) const { + return data == other.data && moduli == other.moduli && + numCoeffs == other.numCoeffs && + representation == other.representation; + } + bool operator!=(const RNSPolynomial& other) const { + return !(*this == other); + } + + private: + /// Flat array holding the polynomial data. + /// In Coefficient form, this holds the coefficients in limb-major order. + /// In NTT form, this holds the evaluations (slots) in limb-major order. + llvm::SmallVector data; + + /// The moduli (RNS basis). + llvm::SmallVector moduli; + + unsigned numCoeffs = 0; + + /// The representation form of the polynomial. + Representation representation = Representation::Coefficient; +}; + +} // namespace polynomial +} // namespace heir +} // namespace mlir + +#endif // LIB_UTILS_POLYNOMIAL_RNSPOLYNOMIAL_H_ diff --git a/tests/Dialect/Polynomial/Conversions/polynomial_to_mod_arith/rns_constant_failure.mlir b/tests/Dialect/Polynomial/Conversions/polynomial_to_mod_arith/rns_constant_failure.mlir new file mode 100644 index 0000000000..9f09d2287f --- /dev/null +++ b/tests/Dialect/Polynomial/Conversions/polynomial_to_mod_arith/rns_constant_failure.mlir @@ -0,0 +1,18 @@ +// RUN: heir-opt --polynomial-to-mod-arith --verify-diagnostics %s + +!rns_basis_0 = !mod_arith.int<17 : i32> +!rns_basis_1 = !mod_arith.int<19 : i32> +!rns_ty = !rns.rns + +#poly_mod = #polynomial.int_polynomial<1 + x**4> +#ring = #polynomial.ring +!rns_poly_ty = !polynomial.polynomial + +#rns_poly = #polynomial.rns_polynomial : tensor<2x2xi32>> : !rns_poly_ty + +func.func @test_rns_poly_failure() -> !rns_poly_ty { + // expected-warning@+2 {{Native lowering for RNSPolynomialAttr is not implemented yet}} + // expected-error@+1 {{failed to legalize operation}} + %0 = polynomial.constant #rns_poly + return %0 : !rns_poly_ty +} diff --git a/tests/Dialect/Polynomial/IR/attribute_parsing.mlir b/tests/Dialect/Polynomial/IR/attribute_parsing.mlir new file mode 100644 index 0000000000..e0fc14d8a4 --- /dev/null +++ b/tests/Dialect/Polynomial/IR/attribute_parsing.mlir @@ -0,0 +1,31 @@ +// RUN: heir-opt %s | FileCheck %s + +!rns_basis_0 = !mod_arith.int<17 : i32> +!rns_basis_1 = !mod_arith.int<19 : i32> +!rns_ty = !rns.rns +!rns_poly_ty = !polynomial.polynomial> + +#rns_poly = #polynomial.rns_polynomial : tensor<2x2xi32>> : !rns_poly_ty +#rns_poly_coeff = #polynomial.rns_polynomial : tensor<2x2xi32>, coeff> : !rns_poly_ty +#rns_poly_eval = #polynomial.rns_polynomial : tensor<2x2xi32>, eval> : !rns_poly_ty + +// CHECK: func @test_rns_poly +func.func @test_rns_poly() -> !rns_poly_ty { + // CHECK: polynomial.constant #polynomial : tensor<2x2xi32>> : {{![a-zA-Z0-9_]+}}> : {{![a-zA-Z0-9_]+}} + %0 = polynomial.constant #rns_poly + return %0 : !rns_poly_ty +} + +// CHECK: func @test_rns_poly_coeff +func.func @test_rns_poly_coeff() -> !rns_poly_ty { + // CHECK: polynomial.constant #polynomial : tensor<2x2xi32>> : {{![a-zA-Z0-9_]+}}> : {{![a-zA-Z0-9_]+}} + %0 = polynomial.constant #rns_poly_coeff + return %0 : !rns_poly_ty +} + +// CHECK: func @test_rns_poly_eval +func.func @test_rns_poly_eval() -> !rns_poly_ty { + // CHECK: polynomial.constant #polynomial : tensor<2x2xi32>, eval> : {{![a-zA-Z0-9_]+}}> : {{![a-zA-Z0-9_]+}} + %0 = polynomial.constant #rns_poly_eval + return %0 : !rns_poly_ty +}