diff --git a/WORKSPACE b/WORKSPACE index c411341ff0..453b3e1fc6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -109,3 +109,22 @@ http_archive( load("@googletest//:googletest_deps.bzl", "googletest_deps") googletest_deps() + +# compile_commands extracts the relevant compile data from bazel into +# `compile_commands.json` so that clangd, clang-tidy, etc., can use it. +# Whenever a build file changes, you must re-run +# +# bazel run @hedron_compile_commands//:refresh_all +# +# to ingest new data into these tools. +# +# See the project repo for more details and configuration options +# https://github.com/hedronvision/bazel-compile-commands-extractor +http_archive( + name = "hedron_compile_commands", + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/3dddf205a1f5cde20faf2444c1757abe0564ff4c.tar.gz", + strip_prefix = "bazel-compile-commands-extractor-3dddf205a1f5cde20faf2444c1757abe0564ff4c", + sha256 = "3cd0e49f0f4a6d406c1d74b53b7616f5e24f5fd319eafc1bf8eee6e14124d115", +) +load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") +hedron_compile_commands_setup() diff --git a/include/Dialect/Poly/IR/BUILD b/include/Dialect/Poly/IR/BUILD index 2623c86869..612ee73d3f 100644 --- a/include/Dialect/Poly/IR/BUILD +++ b/include/Dialect/Poly/IR/BUILD @@ -9,15 +9,19 @@ package( exports_files( [ + "PolyAttributes.h", "PolyDialect.h", "PolyOps.h", "PolyTypes.h", + "Polynomial.h", + "PolynomialDetail.h", ], ) td_library( name = "td_files", srcs = [ + "PolyAttributes.td", "PolyDialect.td", "PolyOps.td", "PolyTypes.td", @@ -52,6 +56,34 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "attributes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-attrdef-decls", + ], + "PolyAttributes.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + ], + "PolyAttributes.cpp.inc", + ), + ( + ["-gen-attrdef-doc"], + "PolyAttributes.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "PolyAttributes.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) + gentbl_cc_library( name = "types_inc_gen", tbl_outs = [ diff --git a/include/Dialect/Poly/IR/PolyAttributes.h b/include/Dialect/Poly/IR/PolyAttributes.h new file mode 100644 index 0000000000..0f9050ed1c --- /dev/null +++ b/include/Dialect/Poly/IR/PolyAttributes.h @@ -0,0 +1,10 @@ +#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_H_ +#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_H_ + +#include "include/Dialect/Poly/IR/PolyDialect.h" +#include "include/Dialect/Poly/IR/Polynomial.h" + +#define GET_ATTRDEF_CLASSES +#include "include/Dialect/Poly/IR/PolyAttributes.h.inc" + +#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_H_ diff --git a/include/Dialect/Poly/IR/PolyAttributes.td b/include/Dialect/Poly/IR/PolyAttributes.td new file mode 100644 index 0000000000..0a83eafad8 --- /dev/null +++ b/include/Dialect/Poly/IR/PolyAttributes.td @@ -0,0 +1,41 @@ +#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_TD_ +#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_TD_ + +include "PolyDialect.td" + +include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" + +class Poly_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +class Poly_Attr_With_Custom_Format traits = []> + : AttrDef { + let mnemonic = ?; +} + +def Polynomial_Attr : Poly_Attr<"Polynomial", "polynomial"> { + let summary = "An attribute containing a polynomial."; + let description = [{ + #poly = #poly.poly + }]; + + let parameters = (ins "Polynomial":$value); + + let builders = [ + AttrBuilderWithInferredContext<(ins "Polynomial":$value), [{ + return $_get(value.getContext(), value); + }]> + ]; + let extraClassDeclaration = [{ + using ValueType = Polynomial; + Polynomial getPolynomial() const { return getValue(); } + }]; + + let skipDefaultBuilders = 1; + let hasCustomAssemblyFormat = 1; +} + +#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_TD_ diff --git a/include/Dialect/Poly/IR/PolyDialect.td b/include/Dialect/Poly/IR/PolyDialect.td index 4744cd9312..4bab0bc98a 100644 --- a/include/Dialect/Poly/IR/PolyDialect.td +++ b/include/Dialect/Poly/IR/PolyDialect.td @@ -33,6 +33,7 @@ def Poly_Dialect : Dialect { let cppNamespace = "::mlir::heir::poly"; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; } #endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYDIALECT_TD_ diff --git a/include/Dialect/Poly/IR/PolyOps.td b/include/Dialect/Poly/IR/PolyOps.td index ceba500c72..6ec4b3cc83 100644 --- a/include/Dialect/Poly/IR/PolyOps.td +++ b/include/Dialect/Poly/IR/PolyOps.td @@ -1,6 +1,7 @@ #ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYOPS_TD_ #define HEIR_INCLUDE_DIALECT_POLY_IR_POLYOPS_TD_ +include "PolyAttributes.td" include "PolyDialect.td" include "PolyTypes.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -56,7 +57,11 @@ def Poly_MulOp : Poly_Op<"mul", [SameOperandsAndResultType]> { let summary = "Multiplication operation between polynomials."; let arguments = (ins - Variadic:$x + Poly:$lhs, + Poly:$rhs, + // TODO: upgrade this to a `#polynomial.ring` attribute that includes + // both the coefficient modulus and the polynomial modulus + Polynomial_Attr:$modulus ); let results = (outs diff --git a/include/Dialect/Poly/IR/PolyTypes.h b/include/Dialect/Poly/IR/PolyTypes.h index 5f9f4717c3..39b81dee3d 100644 --- a/include/Dialect/Poly/IR/PolyTypes.h +++ b/include/Dialect/Poly/IR/PolyTypes.h @@ -2,6 +2,7 @@ #define HEIR_INCLUDE_DIALECT_POLY_IR_POLYTYPES_H_ #include "include/Dialect/Poly/IR/PolyDialect.h" +#include "include/Dialect/Poly/IR/PolyAttributes.h" #define GET_TYPEDEF_CLASSES #include "include/Dialect/Poly/IR/PolyTypes.h.inc" diff --git a/include/Dialect/Poly/IR/PolyTypes.td b/include/Dialect/Poly/IR/PolyTypes.td index 6ff90afe8a..3140501315 100644 --- a/include/Dialect/Poly/IR/PolyTypes.td +++ b/include/Dialect/Poly/IR/PolyTypes.td @@ -2,14 +2,11 @@ #define HEIR_INCLUDE_DIALECT_POLY_IR_POLYTYPES_TD_ include "PolyDialect.td" +include "PolyAttributes.td" include "mlir/IR/DialectBase.td" include "mlir/IR/AttrTypeBase.td" -//===----------------------------------------------------------------------===// -// Poly type definitions -//===----------------------------------------------------------------------===// - // A base class for all types in this dialect class Poly_Type : TypeDef { @@ -21,10 +18,8 @@ def Poly : Poly_Type<"Polynomial", "poly"> { let description = [{ A type for polynomials in a polynomial quotient ring. - - The type is parametrized by the degree of the polynomial - }]; + } #endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYTYPES_TD_ diff --git a/include/Dialect/Poly/IR/Polynomial.h b/include/Dialect/Poly/IR/Polynomial.h new file mode 100644 index 0000000000..69c514493f --- /dev/null +++ b/include/Dialect/Poly/IR/Polynomial.h @@ -0,0 +1,144 @@ +#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIAL_H_ +#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIAL_H_ + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { + +class MLIRContext; + +namespace heir { +namespace poly { + +constexpr unsigned APINT_BIT_WIDTH = 64; + +namespace detail { +struct PolynomialStorage; +} // namespace detail + +class Monomial { + public: + Monomial(int64_t coeff, uint64_t expo) + : coefficient(APINT_BIT_WIDTH, coeff), exponent(APINT_BIT_WIDTH, expo) {} + + Monomial(APInt coeff, APInt expo) : coefficient(coeff), exponent(expo) {} + + Monomial() : coefficient(APINT_BIT_WIDTH, 0), exponent(APINT_BIT_WIDTH, 0) {} + + bool operator==(Monomial other) const { + return other.coefficient == coefficient && other.exponent == exponent; + } + bool operator!=(Monomial other) const { + return other.coefficient != coefficient || other.exponent != exponent; + } + + /// Monomials are ordered by exponent. + bool operator<(const Monomial &other) const { + return (exponent.ule(other.exponent)); + } + + // Prints polynomial to 'os'. + void print(raw_ostream &os) const; + + friend ::llvm::hash_code hash_value(Monomial arg); + + public: + APInt coefficient; + + // Always unsigned + APInt exponent; +}; + +/// A single-variable polynomial with integer coefficients. Polynomials are +/// immutable and uniqued. +/// +/// Eg: x^1024 + x + 1 +/// +/// The symbols used as the polynomial's indeterminate don't matter, so long as +/// it is used consistently throughout the polynomial. +class Polynomial { + public: + using ImplType = detail::PolynomialStorage; + + constexpr Polynomial() = default; + explicit Polynomial(ImplType *terms) : terms(terms) {} + + static Polynomial fromMonomials(ArrayRef monomials, + MLIRContext *context); + /// Returns a polynomial with coefficients given by `coeffs` + static Polynomial fromCoefficients(ArrayRef coeffs, + MLIRContext *context); + + /// Builds a monomial of the given degree. + static Polynomial monomialOfDegree(uint64_t degree, MLIRContext *context); + + MLIRContext *getContext() const; + + explicit operator bool() const { return terms != nullptr; } + bool operator==(Polynomial other) const { return other.terms == terms; } + bool operator!=(Polynomial other) const { return !(other.terms == terms); } + + Polynomial operator+(Polynomial other) const; + + // Prints polynomial to 'os'. + void print(raw_ostream &os) const; + void dump() const; + + ArrayRef getTerms() const; + + unsigned getDegree() const; + + friend ::llvm::hash_code hash_value(Polynomial arg); + + private: + ImplType *terms{nullptr}; +}; + +// Make Polynomial hashable. +inline ::llvm::hash_code hash_value(Polynomial arg) { + return ::llvm::hash_value(arg.terms); +} + +inline ::llvm::hash_code hash_value(Monomial arg) { + return ::llvm::hash_value(arg.coefficient) ^ ::llvm::hash_value(arg.exponent); +} + +inline raw_ostream &operator<<(raw_ostream &os, Polynomial polynomial) { + polynomial.print(os); + return os; +} + +} // namespace poly +} // namespace heir +} // namespace mlir + +namespace llvm { + +// Polynomials hash just like pointers +template <> +struct DenseMapInfo { + static mlir::heir::poly::Polynomial getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::heir::poly::Polynomial( + static_cast(pointer)); + } + static mlir::heir::poly::Polynomial getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::heir::poly::Polynomial( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::heir::poly::Polynomial val) { + return mlir::heir::poly::hash_value(val); + } + static bool isEqual(mlir::heir::poly::Polynomial LHS, + mlir::heir::poly::Polynomial RHS) { + return LHS == RHS; + } +}; + +} // namespace llvm + +#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIAL_H_ diff --git a/include/Dialect/Poly/IR/PolynomialDetail.h b/include/Dialect/Poly/IR/PolynomialDetail.h new file mode 100644 index 0000000000..93f3e498e0 --- /dev/null +++ b/include/Dialect/Poly/IR/PolynomialDetail.h @@ -0,0 +1,58 @@ +#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIALDETAIL_H_ +#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIALDETAIL_H_ + +#include "include/Dialect/Poly/IR/Polynomial.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/TrailingObjects.h" +#include "mlir/Support/StorageUniquer.h" +#include "mlir/Support/TypeID.h" + +namespace mlir { +namespace heir { +namespace poly { +namespace detail { + +// A Polynomial is stored as an ordered list of monomial terms, each of which +// is a tuple of coefficient and exponent. +struct PolynomialStorage final + : public StorageUniquer::BaseStorage, + public llvm::TrailingObjects { + /// The hash key used for uniquing. + using KeyTy = std::tuple>; + + unsigned numTerms; + + MLIRContext *context; + + /// The monomial terms for this polynomial. + ArrayRef terms() const { + return {getTrailingObjects(), numTerms}; + } + + bool operator==(const KeyTy &key) const { + return std::get<0>(key) == numTerms && std::get<1>(key) == terms(); + } + + // Constructs a PolynomialStorage from a key. The context must be set by the + // caller. + static PolynomialStorage *construct( + StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { + auto terms = std::get<1>(key); + auto byteSize = PolynomialStorage::totalSizeToAlloc(terms.size()); + auto *rawMem = allocator.allocate(byteSize, alignof(PolynomialStorage)); + auto *res = new (rawMem) PolynomialStorage(); + res->numTerms = std::get<0>(key); + std::uninitialized_copy(terms.begin(), terms.end(), + res->getTrailingObjects()); + return res; + } +}; + +} // namespace detail +} // namespace poly +} // namespace heir +} // namespace mlir + +MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::heir::poly::detail::PolynomialStorage) + +#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIALDETAIL_H_ diff --git a/lib/Dialect/Poly/IR/BUILD b/lib/Dialect/Poly/IR/BUILD index c57c3ec619..ef6ebafa3c 100644 --- a/lib/Dialect/Poly/IR/BUILD +++ b/lib/Dialect/Poly/IR/BUILD @@ -11,12 +11,16 @@ cc_library( "PolyDialect.cpp", ], hdrs = [ + "@heir//include/Dialect/Poly/IR:PolyAttributes.h", "@heir//include/Dialect/Poly/IR:PolyDialect.h", "@heir//include/Dialect/Poly/IR:PolyOps.h", "@heir//include/Dialect/Poly/IR:PolyTypes.h", ], includes = ["@heir//include"], deps = [ + ":PolyAttributes", + ":Polynomial", + "@heir//include/Dialect/Poly/IR:attributes_inc_gen", "@heir//include/Dialect/Poly/IR:dialect_inc_gen", "@heir//include/Dialect/Poly/IR:ops_inc_gen", "@heir//include/Dialect/Poly/IR:types_inc_gen", @@ -25,3 +29,49 @@ cc_library( "@llvm-project//mlir:InferTypeOpInterface", ], ) + +cc_library( + name = "PolyAttributes", + srcs = [ + "PolyAttributes.cpp", + ], + hdrs = [ + "@heir//include/Dialect/Poly/IR:PolyAttributes.h", + "@heir//include/Dialect/Poly/IR:PolyDialect.h", + ], + deps = [ + ":Polynomial", + "@heir//include/Dialect/Poly/IR:attributes_inc_gen", + "@heir//include/Dialect/Poly/IR:dialect_inc_gen", + "@llvm-project//mlir:AsmParser", + ], +) + +cc_library( + name = "Polynomial", + srcs = [ + "Polynomial.cpp", + ], + hdrs = [ + "@heir//include/Dialect/Poly/IR:Polynomial.h", + "@heir//include/Dialect/Poly/IR:PolynomialDetail.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_test( + name = "PolynomialTest", + size = "small", + srcs = ["PolynomialTest.cpp"], + deps = [ + ":Polynomial", + "@googletest//:gtest", + "@googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Dialect/Poly/IR/PolyAttributes.cpp b/lib/Dialect/Poly/IR/PolyAttributes.cpp new file mode 100644 index 0000000000..89b473e712 --- /dev/null +++ b/lib/Dialect/Poly/IR/PolyAttributes.cpp @@ -0,0 +1,144 @@ +#include "include/Dialect/Poly/IR/PolyAttributes.h" + +#include "llvm/include/llvm/ADT/SmallSet.h" +#include "llvm/include/llvm/ADT/StringExtras.h" +#include + +namespace mlir { +namespace heir { +namespace poly { + +void PolynomialAttr::print(AsmPrinter &p) const { + p << '<'; + p << getPolynomial(); + p << '>'; +} + +/// Try to parse a monomial. If successful, populate the fields of the outparam +/// `monomial` with the results, and the `variable` outparam with the parsed +/// variable name. +ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, + llvm::StringRef *variable, bool *isConstantTerm) { + APInt parsedCoeff(APINT_BIT_WIDTH, 0); + auto result = parser.parseOptionalInteger(parsedCoeff); + if (!result.has_value()) { + parsedCoeff = APInt(APINT_BIT_WIDTH, 1); + } else { + if (failed(*result)) { + parser.emitError(parser.getCurrentLocation(), + "Invalid integer coefficient."); + return failure(); + } + } + + // Variable name + result = parser.parseOptionalKeyword(variable); + if (!result.has_value() || failed(*result)) { + // we allow "failed" because it triggers when the next token is a +, + // which is allowed when the input is the constant term. + monomial.coefficient = parsedCoeff; + monomial.exponent = APInt(APINT_BIT_WIDTH, 0); + *isConstantTerm = true; + return success(); + } + + // Parse exponentiation symbol as ** + // We can't use caret because it's reserved for basic block identifiers + // If no star is present, it's treated as a polynomial with exponent 1 + if (failed(parser.parseOptionalStar())) { + monomial.coefficient = parsedCoeff; + monomial.exponent = APInt(APINT_BIT_WIDTH, 1); + return success(); + } + + // If there's one * there must be two + if (failed(parser.parseStar())) { + parser.emitError(parser.getCurrentLocation(), + "Exponents must be specified as a double-asterisk `**`."); + return failure(); + } + + // If there's a **, then the integer exponent is required. + APInt parsedExponent(APINT_BIT_WIDTH, 0); + if (failed(parser.parseInteger(parsedExponent))) { + parser.emitError(parser.getCurrentLocation(), + "Found invalid integer exponent."); + return failure(); + } + + monomial.coefficient = parsedCoeff; + monomial.exponent = parsedExponent; + return success(); +} + +mlir::Attribute mlir::heir::poly::PolynomialAttr::parse(AsmParser &parser, + Type type) { + if (failed(parser.parseLess())) + return {}; + + std::vector monomials; + llvm::SmallSet variables; + llvm::DenseSet exponents; + + while (true) { + Monomial parsedMonomial; + llvm::StringRef parsedVariableRef; + bool isConstantTerm = false; + if (failed(parseMonomial(parser, parsedMonomial, &parsedVariableRef, + &isConstantTerm))) { + return {}; + } + + if (!isConstantTerm) { + std::string parsedVariable = parsedVariableRef.str(); + variables.insert(parsedVariable); + } + monomials.push_back(parsedMonomial); + + if (exponents.count(parsedMonomial.exponent) > 0) { + llvm::SmallString<512> coeff_string; + parsedMonomial.exponent.toStringSigned(coeff_string); + parser.emitError(parser.getCurrentLocation(), + "At most one monomial may have exponent " + + coeff_string + ", but found multiple."); + return {}; + } + exponents.insert(parsedMonomial.exponent); + + // Parse optional +. If a + is absent, require > and break, otherwise forbid + // > and continue with the next monomial. + // ParseOptional{Plus, Greater} does not return an OptionalParseResult, so + // failed means that the token was not found. + if (failed(parser.parseOptionalPlus())) { + if (succeeded(parser.parseGreater())) { + break; + } else { + parser.emitError( + parser.getCurrentLocation(), + "Expected + and more monomials, or > to end polynomial attribute."); + return {}; + } + } else if (succeeded(parser.parseOptionalGreater())) { + parser.emitError( + parser.getCurrentLocation(), + "Expected another monomial after +, but found > ending attribute."); + return {}; + } + } + + if (variables.size() > 1) { + std::string vars = llvm::join(variables.begin(), variables.end(), ", "); + parser.emitError( + parser.getCurrentLocation(), + "Polynomials must have one indeterminate, but there were multiple: " + + vars); + } + + Polynomial poly = + Polynomial::fromMonomials(std::move(monomials), parser.getContext()); + return PolynomialAttr::get(poly); +} + +} // namespace poly +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Poly/IR/PolyDialect.cpp b/lib/Dialect/Poly/IR/PolyDialect.cpp index 1c24738067..9e3e61f1ec 100644 --- a/lib/Dialect/Poly/IR/PolyDialect.cpp +++ b/lib/Dialect/Poly/IR/PolyDialect.cpp @@ -1,12 +1,16 @@ -#include "include/Dialect/Poly/IR/PolyDialect.h" -#include "include/Dialect/Poly/IR/PolyTypes.h" -#include "include/Dialect/Poly/IR/PolyOps.h" - #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project +#include "include/Dialect/Poly/IR/PolyAttributes.h" +#include "include/Dialect/Poly/IR/PolyDialect.h" +#include "include/Dialect/Poly/IR/PolyOps.h" +#include "include/Dialect/Poly/IR/PolyTypes.h" +#include "include/Dialect/Poly/IR/PolynomialDetail.h" + #include "include/Dialect/Poly/IR/PolyDialect.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "include/Dialect/Poly/IR/PolyAttributes.cpp.inc" #define GET_TYPEDEF_CLASSES #include "include/Dialect/Poly/IR/PolyTypes.cpp.inc" #define GET_OP_CLASSES @@ -23,6 +27,10 @@ namespace poly { // Dialect construction: there is one instance per context and it registers its // operations, types, and interfaces here. void PolyDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "include/Dialect/Poly/IR/PolyAttributes.cpp.inc" + >(); addTypes< #define GET_TYPEDEF_LIST #include "include/Dialect/Poly/IR/PolyTypes.cpp.inc" @@ -31,6 +39,8 @@ void PolyDialect::initialize() { #define GET_OP_LIST #include "include/Dialect/Poly/IR/PolyOps.cpp.inc" >(); + + getContext()->getAttributeUniquer().registerParametricStorageType(); } } // namespace poly diff --git a/lib/Dialect/Poly/IR/Polynomial.cpp b/lib/Dialect/Poly/IR/Polynomial.cpp new file mode 100644 index 0000000000..7c4af3ecae --- /dev/null +++ b/lib/Dialect/Poly/IR/Polynomial.cpp @@ -0,0 +1,116 @@ +#include "include/Dialect/Poly/IR/Polynomial.h" + +#include "include/Dialect/Poly/IR/PolynomialDetail.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" + +namespace mlir { +namespace heir { +namespace poly { + +MLIRContext *Polynomial::getContext() const { return terms->context; } + +ArrayRef Polynomial::getTerms() const { return terms->terms(); } + +Polynomial Polynomial::fromMonomials(ArrayRef monomials, + MLIRContext *context) { + auto assignCtx = [context](detail::PolynomialStorage *storage) { + storage->context = context; + }; + + // A polynomial's terms are canonically stored in order of increasing degree. + llvm::OwningArrayRef monomials_copy = + llvm::OwningArrayRef(monomials); + std::sort(monomials_copy.begin(), monomials_copy.end()); + + StorageUniquer &uniquer = context->getAttributeUniquer(); + return Polynomial(uniquer.get( + assignCtx, monomials.size(), monomials_copy)); +} + +Polynomial Polynomial::fromCoefficients(ArrayRef coeffs, + MLIRContext *context) { + std::vector monomials; + for (size_t i = 0; i < coeffs.size(); i++) { + monomials.push_back(Monomial(coeffs[i], i)); + } + return Polynomial::fromMonomials(std::move(monomials), context); +} + +Polynomial Polynomial::monomialOfDegree(uint64_t degree, MLIRContext *context) { + return Polynomial::fromMonomials({Monomial(1, degree)}, context); +} + +void Polynomial::print(raw_ostream &os) const { + bool first = true; + for (auto term : terms->terms()) { + if (first) { + first = false; + } else { + os << " + "; + } + std::string coeff_to_print; + if (term.coefficient == 1 && term.exponent.uge(1)) { + coeff_to_print = ""; + } else { + llvm::SmallString<512> coeff_string; + term.coefficient.toStringSigned(coeff_string); + coeff_to_print = coeff_string.str(); + } + + if (term.exponent == 0) { + os << coeff_to_print; + } else if (term.exponent == 1) { + os << coeff_to_print << "x"; + } else { + os << coeff_to_print << "x**" << term.exponent; + } + } +} + +Polynomial Polynomial::operator+(Polynomial other) const { + ArrayRef thisTerms = terms->terms(); + ArrayRef otherTerms = other.terms->terms(); + + std::vector result; + auto thisTerm = thisTerms.begin(); + auto otherTerm = otherTerms.begin(); + + while (thisTerm != thisTerms.end() || otherTerm != otherTerms.end()) { + if (thisTerm == thisTerms.end()) { + result.push_back(*otherTerm); + otherTerm++; + continue; + } + + if (otherTerm == otherTerms.end()) { + result.push_back(*thisTerm); + thisTerm++; + continue; + } + + if (thisTerm->exponent == otherTerm->exponent) { + Monomial sum = Monomial(thisTerm->coefficient + otherTerm->coefficient, + thisTerm->exponent); + result.push_back(sum); + thisTerm++; + otherTerm++; + } else if (thisTerm->exponent.ule(otherTerm->exponent)) { + result.push_back(*thisTerm); + thisTerm++; + } else { // (thisTerm->exponent > otherTerm->exponent) + result.push_back(*otherTerm); + otherTerm++; + } + } + + return Polynomial::fromMonomials(result, getContext()); +} + +} // end namespace poly +} // end namespace heir +} // end namespace mlir + +MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::heir::poly::detail::PolynomialStorage); diff --git a/lib/Dialect/Poly/IR/PolynomialTest.cpp b/lib/Dialect/Poly/IR/PolynomialTest.cpp new file mode 100644 index 0000000000..1d38891ab4 --- /dev/null +++ b/lib/Dialect/Poly/IR/PolynomialTest.cpp @@ -0,0 +1,63 @@ +#include +#include + +#include "include/Dialect/Poly/IR/Polynomial.h" +#include "include/Dialect/Poly/IR/PolynomialDetail.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/StorageUniquer.h" + +namespace mlir::heir::poly { +namespace { + +using ::testing::ElementsAre; + +TEST(PolynomialTest, TestBuilder) { + mlir::MLIRContext context; + context.getAttributeUniquer() + .registerParametricStorageType(); + auto poly = Polynomial::fromCoefficients({1, 2, 3}, &context); + std::string result; + llvm::raw_string_ostream stream(result); + poly.print(stream); + + EXPECT_EQ(result, "1 + 2x + 3x**2"); +} + +TEST(PolynomialTest, TestBuilderFromMonomials) { + mlir::MLIRContext context; + context.getAttributeUniquer() + .registerParametricStorageType(); + auto m1 = Polynomial::monomialOfDegree(1024, &context); + auto m2 = Polynomial::monomialOfDegree(0, &context); + auto poly = m1 + m2; + + std::string result; + llvm::raw_string_ostream stream(result); + poly.print(stream); + + EXPECT_EQ(result, "1 + x**1024"); +} + +TEST(PolynomialTest, TestSortedDegree) { + mlir::MLIRContext context; + context.getAttributeUniquer() + .registerParametricStorageType(); + std::vector monomials; + monomials.push_back(Monomial(1, 9)); + monomials.push_back(Monomial(1, 1)); + monomials.push_back(Monomial(1, 3)); + monomials.push_back(Monomial(1, 0)); + + auto poly = Polynomial::fromMonomials(monomials, &context); + + std::vector actualDegrees; + for (auto term : poly.getTerms()) { + actualDegrees.push_back(term.exponent.getZExtValue()); + } + + EXPECT_THAT(actualDegrees, ElementsAre(0, 1, 3, 9)); +} + +} // namespace +} // namespace mlir::heir::poly diff --git a/tests/poly_ops.mlir b/tests/poly_ops.mlir index 8433e0cf06..115cfa8109 100644 --- a/tests/poly_ops.mlir +++ b/tests/poly_ops.mlir @@ -1,25 +1,27 @@ -// RUN: heir-opt %s | FileCheck %s +// RUN: heir-opt %s > %t +// RUN: FileCheck %s < %t // This simply tests for syntax. +#my_poly = #poly.polynomial<1 + x**1024> +#my_poly_2 = #poly.polynomial +#my_poly_3 = #poly.polynomial<2> +#my_poly_4 = #poly.polynomial module { -// CHECK-LABEL: func @fooFunc - func.func @fooFunc(%arg0: !poly.poly, %arg1: !poly.poly) -> !poly.poly { + func.func @test_multiply() -> i32 { %c0 = arith.constant 0 : index - %c3 = arith.constant 3 : index - %c1_i32 = arith.constant 1 : i32 - // CHECK: poly.mul - %0 = poly.mul(%arg0, %arg1) : !poly.poly - // CHECK: poly.extract_slice - %1 = poly.extract_slice(%0, %c0, %c3) : (!poly.poly, index, index) -> tensor<3xi32> - // CHECK: poly.from_coeffs - %2 = poly.from_coeffs(%1) : (tensor<3xi32>) -> !poly.poly - // CHECK: poly.add - %3 = poly.add(%arg0, %arg1, %2) : !poly.poly - // CHECK: poly.get_coeff + %two = arith.constant 2 : i32 + %five = arith.constant 5 : i32 + %coeffs1 = tensor.from_elements %two, %two, %five : tensor<3xi32> + %coeffs2 = tensor.from_elements %five, %five, %two : tensor<3xi32> + + %poly1 = poly.from_coeffs(%coeffs1) : (tensor<3xi32>) -> !poly.poly + %poly2 = poly.from_coeffs(%coeffs2) : (tensor<3xi32>) -> !poly.poly + + // CHECK: #poly.polynomial<1 + x**1024> + %3 = poly.mul(%poly1, %poly2) {modulus = #my_poly} : !poly.poly %4 = poly.get_coeff(%3, %c0) : (!poly.poly, index) -> i32 - // CHECK: poly.mul_constant - %5 = poly.mul_constant(%3, %4) : (!poly.poly, i32) -> !poly.poly - return %5 : !poly.poly + + return %4 : i32 } }