
Commit

Merge pull request #1 from newling/newling-update-added-support-for-torch-arange-float-module

Formatting and stylistic changes
Abhishek-TyRnT authored Feb 26, 2024
2 parents 89e02c1 + 9b4ae1e commit 4cd1632
Showing 12 changed files with 342 additions and 34 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenNormScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}
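
For context, `aten::norm.Scalar` computes the p-norm over all elements of its input tensor. A hedged summary of the reduction that the lowering added in this commit implements, assuming a float exponent p (per the PyTorch documentation):

\|x\|_p = \Big( \sum_i |x_i|^p \Big)^{1/p}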

def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
AllowsTypeRefinement,
HasValueSemantics,
53 changes: 45 additions & 8 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
elementType.getIntOrFloatBitWidth())));
}

if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

if (isa<AtenAllDimOp>(op)) {
@@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (intType.isSigned())
return b.create<arith::MinSIOp>(loc, self, result);
}
} else if (isa<AtenNormScalarOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `p` is zero or one.
Value elem = payloadArgs[0];
Value result = payloadArgs[1];

// TODO: Fix this part to support complex elements.
if (elem.getType().isa<mlir::ComplexType>()) {
op->emitError("lowering of complex input type for torch.aten.norm.Scalar "
"is currently unimplemented");
return nullptr;
}

Value self = convertScalarToDtype(b, loc, elem, resultElementType);

auto abs = b.create<math::AbsFOp>(loc, self);
AtenNormScalarOp::Adaptor adaptor(operands);
Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, p);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenLinalgVectorNormOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `ord` is zero or one.
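
As a reference for the aten.norm.Scalar payload above, a minimal self-contained C++ sketch (not torch-mlir code; names are illustrative) of the scalar computation that the two linalg.generic ops together perform:

#include <cmath>
#include <vector>

// Reference semantics for aten.norm.Scalar on a flat float buffer: the first
// reduction accumulates |x|^p (abs -> pow -> add, matching the payload above,
// starting from a zero init element), and the follow-up linalg.generic
// applies the 1/p root.
double normScalarReference(const std::vector<double> &xs, double p) {
  double sumOfPowers = 0.0;
  for (double x : xs)
    sumOfPowers += std::pow(std::abs(x), p);
  return std::pow(sumOfPowers, 1.0 / p);
}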
@@ -433,7 +454,7 @@ class ConvertReductionOp : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};

if (isa<AtenMaxOp, AtenMinOp, AtenSumOp>(op)) {
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();

@@ -484,10 +505,12 @@ class ConvertReductionOp : public ConversionPattern {
return err ? Value{} : powOp;
}

FailureOr<Value> createSecondReductionForVectorNormOp(
Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
template <typename TOp>
FailureOr<Value>
createSecondReductionForNormOp(Location loc, Type elemType, TOp op,
Value ordOp, Value firstReduction,
const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
// Cast `ord` to float so that we can readily pass it to math.powf.
Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);
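
The renamed helper is now a template because aten.norm.Scalar and aten.linalg_vector_norm both supply a scalar exponent and share the same finishing step. A small hypothetical sketch of the idea (the exponent() accessor below is a stand-in, not the real adaptor API, which uses getP() / getOrd()):

#include <cmath>

// Hypothetical op-like stand-ins; only the exponent matters to the second
// reduction, so one template body can serve both op types.
struct NormScalarLike { double exponent() const { return 2.0; } };
struct VectorNormLike { double exponent() const { return 1.5; } };

template <typename TOp>
double finishNorm(double sumOfPowers, const TOp &op) {
  // pow(sum |x|^p, 1/p): the second of the two reductions.
  return std::pow(sumOfPowers, 1.0 / op.exponent());
}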

@@ -544,13 +567,15 @@ class ConvertReductionOp : public ConversionPattern {
LogicalResult
validateReductionElementType(Operation *op, Type elemType,
ConversionPatternRewriter &rewriter) const {
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op)) &&
!elemType.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
elemType.getIntOrFloatBitWidth() == 8)
return rewriter.notifyMatchFailure(op, "uint8 is not supported");

// No checks for all other reduction operations
return success();
}
@@ -587,11 +612,22 @@ class ConvertReductionOp : public ConversionPattern {
return rewriter.notifyMatchFailure(
op, "failed to create linalg.generic operation for reduction");

// If this is an aten.norm.Scalar op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (isa<AtenNormScalarOp>(op)) {
AtenNormScalarOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
reduceOp = *secondReduceOp;
}

// If this is an aten.linalg_vector_norm op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
@@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenMaxOp>();
target.addIllegalOp<AtenMinOp>();
target.addIllegalOp<AtenAllDimOp>();
target.addIllegalOp<AtenNormScalarOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
target.addIllegalOp<AtenFrobeniusNormDimOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
153 changes: 131 additions & 22 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -8,24 +8,23 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <optional>

using namespace mlir;
using namespace mlir::torch;
@@ -4067,28 +4066,138 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: pin_memory must be either None or false");
}

int64_t start, step, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
// Stores a range value (a start, end, or step value) and whether it was
// initialized from a constant integer, a constant float, or neither.
class ConstRangeValue {
public:
explicit ConstRangeValue(double v)
: vDouble(v), fromDouble(true), vInt(static_cast<int64_t>(v)),
fromInt(false) {}

explicit ConstRangeValue(int64_t v)
: vDouble(static_cast<double>(v)), fromDouble(false), vInt(v),
fromInt(true) {}

// Constructor for the case where there is no constant value to use.
ConstRangeValue()
: vDouble(0), fromDouble(false), vInt(0), fromInt(false) {}

static ConstRangeValue fromValue(Value v) {
int64_t intVal{0};
double floatVal{0.0};
if (matchPattern(v, m_TorchConstantFloat(&floatVal))) {
return ConstRangeValue(floatVal);
} else if (matchPattern(v, m_TorchConstantInt(&intVal))) {
return ConstRangeValue(intVal);
}
return ConstRangeValue();
}

bool hasConstInt() const { return fromInt; }
bool hasConstDouble() const { return fromDouble; }
bool hasConst() const { return fromInt || fromDouble; }
double getDouble() const { return vDouble; }
int64_t getInt() const { return vInt; }

private:
double vDouble;
bool fromDouble;
int64_t vInt;
bool fromInt;
};
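
A small usage sketch of the tri-state behaviour, assuming the ConstRangeValue class defined above (values are illustrative only):

#include <cstdint>

// Built from a float constant: the double is authoritative; the int member is
// only a truncated convenience copy.
ConstRangeValue fromFloat(2.5);
// fromFloat.hasConstDouble() == true, fromFloat.hasConstInt() == false
// fromFloat.getDouble() == 2.5,       fromFloat.getInt() == 2

// Built from an int constant: both views agree exactly.
ConstRangeValue fromInt(int64_t{3});
// fromInt.hasConstInt() == true, fromInt.getDouble() == 3.0

// Default-constructed: no constant was matched, so hasConst() == false.
ConstRangeValue unknown;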

auto start = ConstRangeValue::fromValue(op.getStart());
if (!start.hasConst()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `start` should be a torch constant int");
op, "unimplemented: case where `start` is not a constant int or float");
}

if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
auto end = ConstRangeValue::fromValue(op.getEnd());
if (!end.hasConst()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `end` should be a torch constant int");
op,
"unimplemented: case where value `end` is not a constant int or float");
}

if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
auto step = ConstRangeValue::fromValue(op.getStep());
if (!step.hasConst()) {
return rewriter.notifyMatchFailure(op,
"unimplemented: case where value `step` "
"is not a constant int or float");
}

auto getRange = [](auto start, auto end, auto step) {
// Initialize a small vector of the same type as start:
using T = decltype(start);
SmallVector<T> values;

uint64_t counter{0};
if (start == end) {
return values;
}
assert(step != T(0));
values.reserve(
1 + static_cast<size_t>(std::abs((end - start) / std::abs(step))));
if (step > 0) {
while (start + T(counter) * step < end) {
values.push_back(start + counter * step);
counter++;
}
} else {
while (start + T(counter) * step > end) {
values.push_back(start + counter * step);
counter++;
}
}
return values;
};
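
To make getRange concrete, a self-contained C++ sketch with the same half-open [start, end) semantics and a few expected outputs (illustrative only, not the patch itself):

#include <cstdint>
#include <vector>

// Each element is computed as start + counter * step rather than by repeated
// addition, which keeps floating-point error from accumulating.
template <typename T>
std::vector<T> makeRange(T start, T end, T step) {
  std::vector<T> values;
  for (uint64_t counter = 0; (step > 0) ? (start + T(counter) * step < end)
                                        : (start + T(counter) * step > end);
       ++counter)
    values.push_back(start + T(counter) * step);
  return values;
}

// makeRange<int64_t>(0, 5, 2)       -> {0, 2, 4}
// makeRange<int64_t>(5, 0, -2)      -> {5, 3, 1}
// makeRange<double>(0.0, 1.0, 0.25) -> {0.0, 0.25, 0.5, 0.75}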

const auto isIntType =
resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();

const auto isDoubleType =
resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();

auto maybeResult = [&]() -> std::optional<Value> {
// Integer output type, and start / end / range are all integers.
if (isIntType && start.hasConstInt() && end.hasConstInt() &&
step.hasConstInt()) {
auto values = getRange(start.getInt(), end.getInt(), step.getInt());
return tosa::getConstTensor<int64_t>(rewriter, op, values, values.size());
}

// Get a double range.
auto values =
getRange(start.getDouble(), end.getDouble(), step.getDouble());
if (isIntType) {
SmallVector<int64_t> values_i64;
values_i64.reserve(values.size());
for (auto v : values) {
values_i64.push_back(static_cast<int64_t>(v));
}
return tosa::getConstTensor<int64_t>(rewriter, op, values_i64,
values.size());
}

if (!isDoubleType) {
return {};
}

SmallVector<float> values_f32;
values_f32.reserve(values.size());
for (auto v : values) {
values_f32.push_back(static_cast<float>(v));
}
auto vs = tosa::getConstTensor<float>(rewriter, op, values_f32,
values_f32.size());
return vs;
}();

if (!maybeResult.has_value()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `step` should be a torch constant int");

// The result will always be a 1-d tensor.
// The size of the result is calculated as follows:
// ceil((end - start)/step)
int64_t resultShape = ceil((float)(end - start) / (float)step);
SmallVector<int64_t> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += i * step;
Value result =
tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
op, "failed to generate constant tensor for arange");
}
auto result = maybeResult.value();

rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
return success();
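
Putting the pieces together, a hypothetical worked example (not code from the patch) of the float-arange case this pull request adds: torch.arange with float start/end/step operands is now lowered by materializing the range on the host and casting the resulting constant.

#include <cstdint>
#include <vector>

// Illustration only: how the lowering above would materialize
// torch.arange(0.0, 1.0, step=0.25) before the final tosa.cast.
std::vector<float> arangeFloatExample() {
  std::vector<float> values;
  for (uint64_t i = 0; i * 0.25 < 1.0; ++i)
    values.push_back(static_cast<float>(i * 0.25));
  return values; // {0.0f, 0.25f, 0.5f, 0.75f}, emitted via tosa::getConstTensor<float>
}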
36 changes: 36 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AtenNormScalarOp
//===----------------------------------------------------------------------===//

LogicalResult AtenNormScalarOp::verify() {

// Verification of the input type for torch.aten.norm.Scalar.
// Per PyTorch docs, only float and complex types are valid for the norm
// operation.

auto inTensor = getSelf().getType().cast<BaseTensorType>();

// If no dtype is specified, it will default to a float one.
if (!inTensor.hasDtype()) {
return success();
}

auto inTensorDtype = inTensor.getDtype();

// Check if dtype is one of those supported by norm operation.
// ComplexType will match any torch complex types, but each float must be
// checked individually.
if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
mlir::Float32Type, mlir::Float64Type>()) {
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< inTensorDtype;
}

return success();
}

//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//

LogicalResult AtenPermuteOp::verify() {

// Verification of the permute op for input & output dimensions with
