Skip to content

Commit

Permalink
Add polynomial and ring attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Aug 1, 2023
1 parent 84d6b77 commit 71c8b6f
Show file tree
Hide file tree
Showing 16 changed files with 748 additions and 29 deletions.
19 changes: 19 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,22 @@ http_archive(
load("@googletest//:googletest_deps.bzl", "googletest_deps")

googletest_deps()

# compile_commands extracts the relevant compile data from bazel into
# `compile_commands.json` so that clangd, clang-tidy, etc., can use it.
# Whenever a build file changes, you must re-run
#
# bazel run @hedron_compile_commands//:refresh_all
#
# to ingest new data into these tools.
#
# See the project repo for more details and configuration options
# https://github.com/hedronvision/bazel-compile-commands-extractor
http_archive(
name = "hedron_compile_commands",
url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/3dddf205a1f5cde20faf2444c1757abe0564ff4c.tar.gz",
strip_prefix = "bazel-compile-commands-extractor-3dddf205a1f5cde20faf2444c1757abe0564ff4c",
sha256 = "3cd0e49f0f4a6d406c1d74b53b7616f5e24f5fd319eafc1bf8eee6e14124d115",
)
load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup")
hedron_compile_commands_setup()
32 changes: 32 additions & 0 deletions include/Dialect/Poly/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@ package(

exports_files(
[
"PolyAttributes.h",
"PolyDialect.h",
"PolyOps.h",
"PolyTypes.h",
"Polynomial.h",
"PolynomialDetail.h",
],
)

td_library(
name = "td_files",
srcs = [
"PolyAttributes.td",
"PolyDialect.td",
"PolyOps.td",
"PolyTypes.td",
Expand Down Expand Up @@ -52,6 +56,34 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "attributes_inc_gen",
tbl_outs = [
(
[
"-gen-attrdef-decls",
],
"PolyAttributes.h.inc",
),
(
[
"-gen-attrdef-defs",
],
"PolyAttributes.cpp.inc",
),
(
["-gen-attrdef-doc"],
"PolyAttributes.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "PolyAttributes.td",
deps = [
":dialect_inc_gen",
":td_files",
],
)

gentbl_cc_library(
name = "types_inc_gen",
tbl_outs = [
Expand Down
10 changes: 10 additions & 0 deletions include/Dialect/Poly/IR/PolyAttributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_H_
#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_H_

#include "include/Dialect/Poly/IR/PolyDialect.h"
#include "include/Dialect/Poly/IR/Polynomial.h"

#define GET_ATTRDEF_CLASSES
#include "include/Dialect/Poly/IR/PolyAttributes.h.inc"

#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_H_
71 changes: 71 additions & 0 deletions include/Dialect/Poly/IR/PolyAttributes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_TD_
#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_TD_

include "PolyDialect.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/AttrTypeBase.td"

class Poly_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Poly_Dialect, name, traits> {
let mnemonic = attrMnemonic;
}

class Poly_Attr_With_Custom_Format<string name, list<Trait> traits = []>
: AttrDef<Poly_Dialect, name, traits> {
let mnemonic = ?;
}

def Polynomial_Attr : Poly_Attr<"Polynomial", "polynomial"> {
let summary = "An attribute containing a polynomial.";
let description = [{
#poly = #poly.poly<x**1024 + 1>
}];

let parameters = (ins "Polynomial":$value);

let builders = [
AttrBuilderWithInferredContext<(ins "Polynomial":$value), [{
return $_get(value.getContext(), value);
}]>
];
let extraClassDeclaration = [{
using ValueType = Polynomial;
Polynomial getPolynomial() const { return getValue(); }
}];

let skipDefaultBuilders = 1;
let hasCustomAssemblyFormat = 1;
}

def Ring_Attr : Poly_Attr<"Ring", "ring"> {
let summary = "An attribute specifying a ring.";
let description = [{
An attribute specifying a polynomial quotient ring with integer
coefficients, $\mathbb{Z}/n\mathbb{Z}[x] / (p(x))$.

`cmod` is the coefficient modulus $n$, and `ideal` is the ring ideal
$(p(x))$. Because all ideals in a single-variable polynomial ring are
principal, the ideal is defined by a single polynomial.

#ring = #poly.ring<cmod=1234, ideal=#poly.polynomial<x**1024 + 1>>
}];

let parameters = (ins "APInt": $cmod, "Polynomial":$ideal);

let builders = [
AttrBuilderWithInferredContext<(ins "APInt": $cmod, "Polynomial":$ideal), [{
return $_get(ideal.getContext(), cmod, ideal);
}]>
];
let extraClassDeclaration = [{
Polynomial ideal() const { return getIdeal(); }
APInt coefficientModulus() const { return getCmod(); }
}];

let skipDefaultBuilders = 1;
let hasCustomAssemblyFormat = 1;
}


#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYATTRIBUTES_TD_
1 change: 1 addition & 0 deletions include/Dialect/Poly/IR/PolyDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def Poly_Dialect : Dialect {
let cppNamespace = "::mlir::heir::poly";

let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
}

#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYDIALECT_TD_
5 changes: 4 additions & 1 deletion include/Dialect/Poly/IR/PolyOps.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYOPS_TD_
#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYOPS_TD_

include "PolyAttributes.td"
include "PolyDialect.td"
include "PolyTypes.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
Expand Down Expand Up @@ -56,7 +57,9 @@ def Poly_MulOp : Poly_Op<"mul", [SameOperandsAndResultType]> {
let summary = "Multiplication operation between polynomials.";

let arguments = (ins
Variadic<Poly>:$x
Poly:$lhs,
Poly:$rhs,
Ring_Attr:$ring
);

let results = (outs
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/Poly/IR/PolyTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYTYPES_H_

#include "include/Dialect/Poly/IR/PolyDialect.h"
#include "include/Dialect/Poly/IR/PolyAttributes.h"

#define GET_TYPEDEF_CLASSES
#include "include/Dialect/Poly/IR/PolyTypes.h.inc"
Expand Down
9 changes: 2 additions & 7 deletions include/Dialect/Poly/IR/PolyTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYTYPES_TD_

include "PolyDialect.td"
include "PolyAttributes.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/AttrTypeBase.td"

//===----------------------------------------------------------------------===//
// Poly type definitions
//===----------------------------------------------------------------------===//

// A base class for all types in this dialect
class Poly_Type<string name, string typeMnemonic>
: TypeDef<Poly_Dialect, name> {
Expand All @@ -21,10 +18,8 @@ def Poly : Poly_Type<"Polynomial", "poly"> {

let description = [{
A type for polynomials in a polynomial quotient ring.

The type is parametrized by the degree of the polynomial

}];

}

#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYTYPES_TD_
139 changes: 139 additions & 0 deletions include/Dialect/Poly/IR/Polynomial.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#ifndef HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIAL_H_
#define HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIAL_H_

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include "mlir/Support/LLVM.h"

namespace mlir {

class MLIRContext;

namespace heir {
namespace poly {

constexpr unsigned APINT_BIT_WIDTH = 64;

namespace detail {
struct PolynomialStorage;
} // namespace detail

class Monomial {
public:
Monomial(int64_t coeff, uint64_t expo)
: coefficient(APINT_BIT_WIDTH, coeff), exponent(APINT_BIT_WIDTH, expo) {}

Monomial(APInt coeff, APInt expo) : coefficient(coeff), exponent(expo) {}

Monomial() : coefficient(APINT_BIT_WIDTH, 0), exponent(APINT_BIT_WIDTH, 0) {}

bool operator==(Monomial other) const {
return other.coefficient == coefficient && other.exponent == exponent;
}
bool operator!=(Monomial other) const {
return other.coefficient != coefficient || other.exponent != exponent;
}

/// Monomials are ordered by exponent.
bool operator<(const Monomial &other) const {
return (exponent.ule(other.exponent));
}

// Prints polynomial to 'os'.
void print(raw_ostream &os) const;

friend ::llvm::hash_code hash_value(Monomial arg);

public:
APInt coefficient;

// Always unsigned
APInt exponent;
};

/// A single-variable polynomial with integer coefficients. Polynomials are
/// immutable and uniqued.
///
/// Eg: x^1024 + x + 1
///
/// The symbols used as the polynomial's indeterminate don't matter, so long as
/// it is used consistently throughout the polynomial.
class Polynomial {
public:
using ImplType = detail::PolynomialStorage;

constexpr Polynomial() = default;
explicit Polynomial(ImplType *terms) : terms(terms) {}

static Polynomial fromMonomials(ArrayRef<Monomial> monomials,
MLIRContext *context);
/// Returns a polynomial with coefficients given by `coeffs`
static Polynomial fromCoefficients(ArrayRef<int64_t> coeffs,
MLIRContext *context);

MLIRContext *getContext() const;

explicit operator bool() const { return terms != nullptr; }
bool operator==(Polynomial other) const { return other.terms == terms; }
bool operator!=(Polynomial other) const { return !(other.terms == terms); }

// Prints polynomial to 'os'.
void print(raw_ostream &os) const;
void dump() const;

ArrayRef<Monomial> getTerms() const;

unsigned getDegree() const;

friend ::llvm::hash_code hash_value(Polynomial arg);

private:
ImplType *terms{nullptr};
};

// Make Polynomial hashable.
inline ::llvm::hash_code hash_value(Polynomial arg) {
return ::llvm::hash_value(arg.terms);
}

inline ::llvm::hash_code hash_value(Monomial arg) {
return ::llvm::hash_value(arg.coefficient) ^ ::llvm::hash_value(arg.exponent);
}

inline raw_ostream &operator<<(raw_ostream &os, Polynomial polynomial) {
polynomial.print(os);
return os;
}

} // namespace poly
} // namespace heir
} // namespace mlir

namespace llvm {

// Polynomials hash just like pointers
template <>
struct DenseMapInfo<mlir::heir::poly::Polynomial> {
static mlir::heir::poly::Polynomial getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::heir::poly::Polynomial(
static_cast<mlir::heir::poly::Polynomial::ImplType *>(pointer));
}
static mlir::heir::poly::Polynomial getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::heir::poly::Polynomial(
static_cast<mlir::heir::poly::Polynomial::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::heir::poly::Polynomial val) {
return mlir::heir::poly::hash_value(val);
}
static bool isEqual(mlir::heir::poly::Polynomial LHS,
mlir::heir::poly::Polynomial RHS) {
return LHS == RHS;
}
};

} // namespace llvm

#endif // HEIR_INCLUDE_DIALECT_POLY_IR_POLYNOMIAL_H_
Loading

0 comments on commit 71c8b6f

Please sign in to comment.