Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formatting and stylistic changes #1

25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenNormScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
53 changes: 45 additions & 8 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
elementType.getIntOrFloatBitWidth())));
}

if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

if (isa<AtenAllDimOp>(op)) {
Expand Down Expand Up @@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
if (intType.isSigned())
return b.create<arith::MinSIOp>(loc, self, result);
}
} else if (isa<AtenNormScalarOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `p` is zero or one.
Value elem = payloadArgs[0];
Value result = payloadArgs[1];

// TODO: Fix this part to support complex elements.
if (elem.getType().isa<mlir::ComplexType>()) {
op->emitError("lowering of complex input type for torch.aten.norm.Scalar "
"is currently unimplemented");
return nullptr;
}

Value self = convertScalarToDtype(b, loc, elem, resultElementType);

auto abs = b.create<math::AbsFOp>(loc, self);
AtenNormScalarOp::Adaptor adaptor(operands);
Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, p);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenLinalgVectorNormOp>(op)) {
// This creates payload for only the first of the two linalg.generic ops.
// TODO: Short-circuit operations if `ord` is zero or one.
Expand Down Expand Up @@ -433,7 +454,7 @@ class ConvertReductionOp : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};

if (isa<AtenMaxOp, AtenMinOp, AtenSumOp>(op)) {
if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenNormScalarOp>(op)) {
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();

Expand Down Expand Up @@ -484,10 +505,12 @@ class ConvertReductionOp : public ConversionPattern {
return err ? Value{} : powOp;
}

FailureOr<Value> createSecondReductionForVectorNormOp(
Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp,
Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
template <typename TOp>
FailureOr<Value>
createSecondReductionForNormOp(Location loc, Type elemType, TOp op,
Value ordOp, Value firstReduction,
const torch_to_linalg::ReductionOpInfo &opInfo,
ConversionPatternRewriter &rewriter) const {
// Cast `ord` to float so that we can readily pass it math.powf.
Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType);

Expand Down Expand Up @@ -544,13 +567,15 @@ class ConvertReductionOp : public ConversionPattern {
LogicalResult
validateReductionElementType(Operation *op, Type elemType,
ConversionPatternRewriter &rewriter) const {
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op)) &&
if ((isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op) ||
isa<AtenNormScalarOp>(op)) &&
!elemType.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
if (isa<AtenAllDimOp>(op) && elemType.isa<mlir::IntegerType>() &&
elemType.getIntOrFloatBitWidth() == 8)
return rewriter.notifyMatchFailure(op, "uint8 is not supported");

// No checks for all other reduction operations
return success();
}
Expand Down Expand Up @@ -587,11 +612,22 @@ class ConvertReductionOp : public ConversionPattern {
return rewriter.notifyMatchFailure(
op, "failed to create linalg.generic operation for reduction");

// If this is aten.norm.Scalar op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (isa<AtenNormScalarOp>(op)) {
AtenNormScalarOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
reduceOp = *secondReduceOp;
}

// If this is aten.linalg_vector_norm op, then we need to generate another
// linalg.generic op that references the first linalg.generic op.
if (auto normOp = dyn_cast<AtenLinalgVectorNormOp>(op)) {
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
FailureOr<Value> secondReduceOp = createSecondReductionForVectorNormOp(
FailureOr<Value> secondReduceOp = createSecondReductionForNormOp(
loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter);
if (failed(secondReduceOp))
return secondReduceOp;
Expand Down Expand Up @@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenMaxOp>();
target.addIllegalOp<AtenMinOp>();
target.addIllegalOp<AtenAllDimOp>();
target.addIllegalOp<AtenNormScalarOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
target.addIllegalOp<AtenFrobeniusNormDimOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
Expand Down
153 changes: 131 additions & 22 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,23 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include <optional>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -4067,28 +4066,138 @@ LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
op, "unimplemented: pin_memory must be either None or false");
}

int64_t start, step, end;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
// Stores a range value (a start, end, or step value) and whether or not it
// was initiated with a constant integer, an constant float or neither.
class ConstRangeValue {
public:
explicit ConstRangeValue(double v)
: vDouble(v), fromDouble(true), vInt(static_cast<int64_t>(v)),
fromInt(false) {}

explicit ConstRangeValue(int64_t v)
: vDouble(static_cast<double>(v)), fromDouble(false), vInt(v),
fromInt(true) {}

// Constructor for the case where there is no constant value to use.
ConstRangeValue()
: vDouble(0), fromDouble(false), vInt(0), fromInt(false) {}

static ConstRangeValue fromValue(Value v) {
int64_t intVal{0};
double floatVal{0.0};
if (matchPattern(v, m_TorchConstantFloat(&floatVal))) {
return ConstRangeValue(floatVal);
} else if (matchPattern(v, m_TorchConstantInt(&intVal))) {
return ConstRangeValue(intVal);
}
return ConstRangeValue();
}

bool hasConstInt() const { return fromInt; }
bool hasConstDouble() const { return fromDouble; }
bool hasConst() const { return fromInt || fromDouble; }
double getDouble() const { return vDouble; }
int64_t getInt() const { return vInt; }

private:
double vDouble;
bool fromDouble;
int64_t vInt;
bool fromInt;
};

auto start = ConstRangeValue::fromValue(op.getStart());
if (!start.hasConst()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `start` should be a torch constant int");
op, "unimplemented: case where `start` is not a constant int or float");
}

if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
auto end = ConstRangeValue::fromValue(op.getEnd());
if (!end.hasConst()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `end` should be a torch constant int");
op,
"unimplemented: case where value `end` is not a constant int or float");
}

if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
auto step = ConstRangeValue::fromValue(op.getStep());
if (!step.hasConst()) {
return rewriter.notifyMatchFailure(op,
"unimplemented: case where value `step` "
"is not a constant int or float");
}

auto getRange = [](auto start, auto end, auto step) {
// Initialize a small vector of the same type as start:
using T = decltype(start);
SmallVector<T> values;

uint64_t counter{0};
if (start == end) {
return values;
}
assert(step != T(0));
values.reserve(
1 + static_cast<size_t>(std::abs((end - start) / std::abs(step))));
if (step > 0) {
while (start + T(counter) * step < end) {
values.push_back(start + counter * step);
counter++;
}
} else {
while (start + T(counter) * step > end) {
values.push_back(start + counter * step);
counter++;
}
}
return values;
};

const auto isIntType =
resultType.getElementType().dyn_cast_or_null<mlir::IntegerType>();

const auto isDoubleType =
resultType.getElementType().dyn_cast_or_null<mlir::FloatType>();

auto maybeResult = [&]() -> std::optional<Value> {
// Integer output type, and start / end / range are all integers.
if (isIntType && start.hasConstInt() && end.hasConstInt() &&
step.hasConstInt()) {
auto values = getRange(start.getInt(), end.getInt(), step.getInt());
return tosa::getConstTensor<int64_t>(rewriter, op, values, values.size());
}

// Get a double range.
auto values =
getRange(start.getDouble(), end.getDouble(), step.getDouble());
if (isIntType) {
SmallVector<int64_t> values_i64;
values_i64.reserve(values.size());
for (auto v : values) {
values_i64.push_back(static_cast<int64_t>(v));
}
return tosa::getConstTensor<int64_t>(rewriter, op, values_i64,
values.size());
}

if (!isDoubleType) {
return {};
}

SmallVector<float> values_f32;
values_f32.reserve(values.size());
for (auto v : values) {
values_f32.push_back(static_cast<float>(v));
}
auto vs = tosa::getConstTensor<float>(rewriter, op, values_f32,
values_f32.size());
return vs;
}();

if (!maybeResult.has_value()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: value `step` should be a torch constant int");

// The result will always be a 1-d tensor.
// The size of the result is calculated as follows:
// ceil((end - start)/step)
int64_t resultShape = ceil((float)(end - start) / (float)step);
SmallVector<int64_t> values(resultShape, start);
for (unsigned i = 1; i < resultShape; i++)
values[i] += i * step;
Value result =
tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();
op, "failed to generate constant tensor for arange");
}
auto result = maybeResult.value();

rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
return success();
Expand Down
36 changes: 36 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AtenNormScalarOp
//===----------------------------------------------------------------------===//

LogicalResult AtenNormScalarOp::verify() {

// Verificaion of input type for torch.aten.norm.Scalar.
// Per PyTorch docs, only float and complex types are valid for norm
// operation.

auto inTensor = getSelf().getType().cast<BaseTensorType>();

// If no dtype is specified, it will default to a float one.
if (!inTensor.hasDtype()) {
return success();
}

auto inTensorDtype = inTensor.getDtype();

// Check if dtype is one of those supported by norm operation.
// ComplexType will match any torch complex types, but each float must be
// checked individually.
if (!inTensorDtype.isa<mlir::ComplexType, mlir::Float16Type,
mlir::Float32Type, mlir::Float64Type>()) {
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< inTensorDtype;
}

return success();
}

//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//

LogicalResult AtenPermuteOp::verify() {

// Verification of the permute op for input & output dimensions with
Expand Down
Loading
Loading