Skip to content

Commit

Permalink
feat: add constant propagation for different unary and binary ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost committed Jul 3, 2024
1 parent d5c1c58 commit 2cec11f
Show file tree
Hide file tree
Showing 3 changed files with 744 additions and 0 deletions.
150 changes: 150 additions & 0 deletions src/Dialect/ONNX/Transforms/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -122,6 +123,16 @@ Value createReplacingConstantOp(
template <typename T>
using EnableNotBool = std::enable_if_t<!std::is_same_v<T, bool>>;

template <typename T>
using EnableBool = std::enable_if_t<std::is_same_v<T, bool>>;

template <typename T>
using EnableInteger =
std::enable_if_t<std::is_integral_v<T> && !std::is_same_v<T, bool>>;

template <typename T>
using EnableFloatingPoint = std::enable_if_t<std::is_floating_point_v<T>>;

/// Checks whether a variadic value is produced by dense ONNXConstantOps.
bool isVariadicOperandFromDenseONNXConstantOp(ValueRange operands) {
return llvm::all_of(operands, [](Value v) { return isDenseONNXConstant(v); });
Expand All @@ -134,6 +145,46 @@ Value ConstZeroTensor(
type, rewriter.getZeroAttr(type.getElementType())));
}

template <typename GetFPConstFunc =
std::function<APFloat(const llvm::fltSemantics &, bool)>,
typename GetIntConstFunc = std::function<APInt(unsigned)>>
Value GetClipConstantOfType(PatternRewriter &rewriter, ShapedType type,
Value value, GetFPConstFunc fpConstantFunc, bool isNegative,
GetIntConstFunc intConstantFunc) {
auto elemType = type.getElementType();
if (auto floatType = dyn_cast<FloatType>(elemType)) {
auto fpValue =
fpConstantFunc(floatType.getFloatSemantics(), /*Negative=*/isNegative);
return rewriter.create<ONNXConstantOp>(value.getLoc(), Attribute(),
DenseElementsAttr::get(type, llvm::ArrayRef(fpValue)));
}
auto intValue = intConstantFunc(elemType.getIntOrFloatBitWidth());
return rewriter.create<ONNXConstantOp>(value.getLoc(), Attribute(),
DenseElementsAttr::get(type, llvm::ArrayRef(intValue)));
}

Value CreateMaximumValueForClip(
PatternRewriter &rewriter, ShapedType type, Value value) {

// Return 'value' if exists, as there is no need to clip to largest.
if (!isNoneValue(value))
return value;

return GetClipConstantOfType(rewriter, type, value, llvm::APFloat::getLargest,
false, llvm::APInt::getMaxValue);
}

Value CreateMinimumValueForClip(
PatternRewriter &rewriter, ShapedType type, Value value) {

// Return 'value' if exists, as there is no need to clip to lowest.
if (!isNoneValue(value))
return value;

return GetClipConstantOfType(rewriter, type, value, llvm::APFloat::getLargest,
true, llvm::APInt::getMinValue);
}

WideNum asWideNum(double n, Type elemType) {
return wideZeroDispatch(elemType, [n](auto wideZero) {
using cpptype = decltype(wideZero);
Expand Down Expand Up @@ -206,6 +257,31 @@ struct ElementWiseBinaryOpImpl<ONNXDivOp, T, EnableNotBool<T>> {
static T eval(T lhs, T rhs) { return lhs / rhs; }
};

template <typename T>
struct ElementWiseBinaryOpImpl<ONNXBitwiseAndOp, T, EnableInteger<T>> {
static T eval(T lhs, T rhs) { return lhs & rhs; }
};

template <typename T>
struct ElementWiseBinaryOpImpl<ONNXBitwiseOrOp, T, EnableInteger<T>> {
static T eval(T lhs, T rhs) { return lhs | rhs; }
};

template <typename T>
struct ElementWiseBinaryOpImpl<ONNXAndOp, T, EnableBool<T>> {
static T eval(T lhs, T rhs) { return lhs && rhs; }
};

template <typename T>
struct ElementWiseBinaryOpImpl<ONNXOrOp, T, EnableBool<T>> {
static T eval(T lhs, T rhs) { return lhs || rhs; }
};

template <typename T>
struct ElementWiseBinaryOpImpl<ONNXXorOp, T, EnableBool<T>> {
static T eval(T lhs, T rhs) { return lhs != rhs; }
};

template <typename T>
struct ElementWiseBinaryOpImpl<ONNXMinOp, T> {
static T eval(T lhs, T rhs) { return std::min<T>(lhs, rhs); }
Expand Down Expand Up @@ -287,6 +363,30 @@ constexpr auto subCombiner(Type elemType) {
return elementwiseBinaryOpCombiner<ONNXSubOp>(elemType);
}

/// Precondition: min values must be less than max values if both exist.
bool satisfiesMinLessThanMaxRequirement(ShapedType type, Value min, Value max) {
if (isNoneValue(min) || isNoneValue(max))
return true;

if (!isDenseONNXConstant(min) || !isDenseONNXConstant(max)) {
return false;
}
auto minValues = getConstValueElements(min);
auto maxValues = getConstValueElements(max);

MLIRContext *ctx = min.getContext();
OnnxElementsAttrBuilder elementsBuilder(ctx);
ElementsAttr resultElements = elementsBuilder.combine(minValues, maxValues,
type.clone(IntegerType::get(ctx, 1)),
elementwiseBinaryOpCombiner<ONNXLessOp>(minValues.getElementType()));
auto denseElems =
OnnxElementsAttrBuilder::toDenseElementsAttr(resultElements);

if (denseElems.isSplat())
return denseElems.getSplatValue<bool>();
return false;
}

/// Do element-wise binary calculation of 'lhs' and 'rhs' values and create an
/// ONNXConstantOp for the result.
template <typename ElementwiseBinaryOp>
Expand Down Expand Up @@ -340,11 +440,56 @@ struct ElementWiseUnaryOpImpl {
static T eval(T val) { llvm_unreachable("unsupported op or type"); }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXBitwiseNotOp, T, EnableInteger<T>> {
static T eval(T val) { return ~val; }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXCeilOp, T, EnableNotBool<T>> {
static T eval(T val) { return ceil(val); }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXCosOp, T, EnableFloatingPoint<T>> {
static T eval(T val) { return cos(val); }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXErfOp, T, EnableNotBool<T>> {
static T eval(T val) { return std::erf(val); }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXExpOp, T, EnableFloatingPoint<T>> {
static T eval(T val) { return std::exp(val); }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXFloorOp, T, EnableNotBool<T>> {
static T eval(T val) { return floor(val); }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXLogOp, T, EnableFloatingPoint<T>> {
static T eval(T val) { return log(val); }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXNegOp, T, EnableNotBool<T>> {
static T eval(T val) { return -val; }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXNotOp, T, EnableBool<T>> {
static T eval(T val) { return !val; }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXSinOp, T, EnableFloatingPoint<T>> {
static T eval(T val) { return sin(val); }
};

template <>
struct ElementWiseUnaryOpImpl<ONNXSqrtOp, double> {
static double eval(double val) { return sqrt(val); }
Expand All @@ -355,6 +500,11 @@ struct ElementWiseUnaryOpImpl<ONNXReluOp, T, EnableNotBool<T>> {
static T eval(T val) { return (val < 0) ? 0 : val; }
};

template <typename T>
struct ElementWiseUnaryOpImpl<ONNXReciprocalOp, T, EnableFloatingPoint<T>> {
static T eval(T val) { return 1 / val; }
};

template <typename ElementwiseUnaryOp>
auto elementwiseUnaryOpFunction(Type elemType) {
return getWideNumWrappedTemplateFunction<ElementWiseUnaryOpImpl,
Expand Down
Loading

0 comments on commit 2cec11f

Please sign in to comment.