Skip to content

Commit

Permalink
Create tensor from index array (cruise-automation#24)
Browse files Browse the repository at this point in the history
Create tensor from index array handle
  • Loading branch information
Muhammad Abubakar authored and GitHub Enterprise committed May 21, 2024
2 parents aa8740f + a7dde48 commit a5b2c20
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 14 deletions.
3 changes: 2 additions & 1 deletion BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ td_library(
"include/mlir-tcp/Dialect/IR/TcpOps.td",
"include/mlir-tcp/Dialect/IR/TcpTypes.td",
"include/mlir-tcp/Dialect/IR/TcpOpsCruiseInternal.td",
],
"include/mlir-tcp/Dialect/IR/TcpTypesCruiseInternal.td",
],
includes = ["include"],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
22 changes: 21 additions & 1 deletion include/mlir-tcp/Dialect/IR/TcpOpsCruiseInternal.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#ifndef TCP_OPS_CRUISE_INTERNAL
#define TCP_OPS_CRUISE_INTERNAL

def Tcp_CreateIndexArrayOp : Tcp_Op<"create_index_array", []> {
def Tcp_CreateIndexArrayOp : Tcp_Op<"create_index_array"> {
let summary = "Creates index_arry type given variadic number of input indices";

let arguments = (ins
Expand All @@ -20,6 +20,10 @@ def Tcp_CreateIndexArrayOp : Tcp_Op<"create_index_array", []> {
let results = (outs
Tcp_IndexArrayType : $outputs
);

let assemblyFormat = [{
`(` operands `)` `:` attr-dict-with-keyword functional-type(operands, results)
}];
}

def Tcp_BindTensorShape : Tcp_Op<"bind_tensor_shape", []> {
Expand All @@ -46,6 +50,21 @@ def Tcp_CasprCreateTensorFromIndexOp : Tcp_Op<"caspr_create_tensor_from_index">
}];
}

// caspr CasprCreateTensorFromIndexArrayOp : used for encapsulating result of tensor_dim_op in tensor
def Tcp_CasprCreateTensorFromIndexArrayOp : Tcp_Op<"caspr_create_tensor_from_index_array"> {
let arguments = (ins
Tcp_IndexArrayType:$shape_array
);

let results = (outs
AnyType:$result
);

let assemblyFormat = [{
`(` operands `)` `:` attr-dict-with-keyword functional-type(operands, results)
}];
}

def Tcp_CasprIndexFromTensorOp : Tcp_Op<"caspr_index_from_tensor"> {
let arguments = (ins
AnyType:$input
Expand All @@ -60,4 +79,5 @@ def Tcp_CasprIndexFromTensorOp : Tcp_Op<"caspr_index_from_tensor"> {
}];
}


#endif // TCP_OPS_CRUISE_INTERNAL
11 changes: 1 addition & 10 deletions include/mlir-tcp/Dialect/IR/TcpTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/DialectBase.td"

include "mlir-tcp/Dialect/IR/TcpBase.td"

include "mlir-tcp/Dialect/IR/TcpTypesCruiseInternal.td"

//===----------------------------------------------------------------------===//
// Tcp Quantized Types
Expand Down Expand Up @@ -58,13 +58,4 @@ def Tcp_FloatTensor : RankedTensorOf<[AnyFloat]>;
def Tcp_IntTensor : RankedTensorOf<[AnySignlessInteger]>;
def Tcp_FloatOrIntTensor : RankedTensorOf<[AnyFloat, AnySignlessInteger]>;


//===----------------------------------------------------------------------===//
// Tcp Custom Types
//===----------------------------------------------------------------------===//

def Tcp_IndexArrayType : Tcp_Type<"IndexArray", "index_array"> {
let summary = "IndexArray TCP type, to holds a list of index builtin type to represent shape";
}

#endif // TCP_TYPES
27 changes: 27 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpTypesCruiseInternal.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TCP_TYPES_CRUISE_INTERNAL
#define TCP_TYPES_CRUISE_INTERNAL

//===----------------------------------------------------------------------===//
// Tcp Custom Types
//===----------------------------------------------------------------------===//

def Tcp_IndexArrayType : Tcp_Type<"IndexArray", "index_array"> {
let summary = "IndexArray TCP type, to holds a list of index builtin type to represent shape";

let parameters = (ins
"int":$elements
);

let assemblyFormat = "`<` `[` $elements `]` `>`";
}

#endif // TCP_TYPES_CRUISE_INTERNAL
20 changes: 20 additions & 0 deletions lib/Conversion/TorchToTcp/CruiseInternalPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,25 @@ class ConvertCasprCreateTensorFromIndexOp : public OpConversionPattern<Torch::Ca
}
};

class ConvertCasprCreateTensorFromIndexArrayOp : public OpConversionPattern<Torch::CasprCreateTensorFromIndexArrayOp> {
public:
using OpConversionPattern<Torch::CasprCreateTensorFromIndexArrayOp>::OpConversionPattern;
using OpAdaptor = typename Torch::CasprCreateTensorFromIndexArrayOp::Adaptor;

LogicalResult
matchAndRewrite(Torch::CasprCreateTensorFromIndexArrayOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(OpConversionPattern<Torch::CasprCreateTensorFromIndexArrayOp>::getTypeConverter()->convertTypes(
op->getResultTypes(), resultTypes))) {
return failure();
}
rewriter.replaceOpWithNewOp<tcp::CasprCreateTensorFromIndexArrayOp>(
op, resultTypes, adaptor.getShapeArray());
return success();
}
};

class ConvertCasprIndexFromTensorOp : public OpConversionPattern<Torch::CasprIndexFromTensorOp> {
public:
using OpConversionPattern<Torch::CasprIndexFromTensorOp>::OpConversionPattern;
Expand Down Expand Up @@ -370,6 +389,7 @@ void torch_to_tcp::cruise::populateCruiseInternalPatternsAndLegality(TypeConvert
patterns.add<ConvertAxisAlignedHardNMS2dOp>(typeConverter, context);
patterns.add<ConvertCreateIndexArrayOp>(typeConverter, context);
patterns.add<ConvertCasprCreateTensorFromIndexOp>(typeConverter, context);
patterns.add<ConvertCasprCreateTensorFromIndexArrayOp>(typeConverter, context);
patterns.add<ConvertCasprIndexFromTensorOp>(typeConverter, context);
patterns.add<ConvertBindTensorShapeOp>(typeConverter, context);
patterns.add<ConvertCasprShapeTensorDimOp>(typeConverter, context);
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToTcp/TorchToTcpCruiseInternal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class ConvertTorchToTcpCruiseInternal

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion([&](Torch::IndexArrayType) {
return tcp::IndexArrayType::get(context);
typeConverter.addConversion([&](Torch::IndexArrayType ty) {
return tcp::IndexArrayType::get(context, ty.getElements());
});
typeConverter.addConversion(
[&](Torch::TorchIndexType) { return ::mlir::IndexType::get(context); });
Expand Down
23 changes: 23 additions & 0 deletions test/Conversion/TorchToTcp/custom_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,26 @@ func.func @index_from_tensor_tensor_from_index(%arg0: !torch.vtensor<[1],si64>)
%1 = torch.caspr_create_tensor_from_index(%0) : (!torch.index) -> !torch.vtensor<[1],si64>
return %1 : !torch.vtensor<[1],si64>
}

// -----

// Since function arguments are not yet converted, casts will be added which will be
// resolved in the next pass
// CHECK: @create_tensor_from_index_array(
// CHECK: %[[ARG0:.+]]: !torch.vtensor<[?,?],f32>
// CHECK: %[[CAST:.+]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[DIM0:.+]] = tensor.dim %[[CAST]], %[[C0:.+]] : tensor<?x?xf32>
// CHECK: %[[DIM1:.+]] = tensor.dim %[[CAST]], %[[C1:.+]] : tensor<?x?xf32>
// CHECK: %[[ARRAY:.+]] = tcp.create_index_array(%[[DIM0]], %[[DIM1]]) : (index, index) -> !tcp.index_array<[2]>
// CHECK: %[[T:.+]] = tcp.caspr_create_tensor_from_index_array(%[[ARRAY]]) : (!tcp.index_array<[2]>) -> tensor<2xi64>
// CHECK: %[[CAST2:.+]] = torch_c.from_builtin_tensor %[[T]] : tensor<2xi64> -> !torch.vtensor<[2],si64>
// CHECK: return %[[CAST2]] : !torch.vtensor<[2],si64>
func.func @create_tensor_from_index_array(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2],si64> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.caspr_shapes.tensor_dim_op(%arg0, %int0) : (!torch.vtensor<[?,?],f32>, !torch.int) -> !torch.index
%1 = torch.caspr_shapes.tensor_dim_op(%arg0, %int1) : (!torch.vtensor<[?,?],f32>, !torch.int) -> !torch.index
%2 = torch.create_index_array(%0, %1) : (!torch.index, !torch.index) -> !torch.index_array<[2]>
%3 = torch.caspr_create_tensor_from_index_array(%2) : (!torch.index_array<[2]>) -> !torch.vtensor<[2],si64>
return %3: !torch.vtensor<[2],si64>
}

0 comments on commit a5b2c20

Please sign in to comment.