[onnx] Drop ConstantOfShape logic from importer, fix torch lowering (llvm#2930)

There is no reason to treat `ConstantOfShape` as a specialized import
anymore, as there exists an onnx-to-torch equivalent. Dropping the importer
code and adding support for resource conversion substantially
increases test coverage for dynamically shaped tests.
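
For context, a minimal sketch of the path this change standardizes on, condensed from the lit tests updated below (the function name, shapes, and fill value here are illustrative): the importer now emits a plain onnx.ConstantOfShape op, and the TorchOnnxToTorch pattern expands it into torch.aten.full.

func.func @constant_of_shape_sketch() -> !torch.vtensor<[2,3,4],f32> {
  // Shape operand resolved to a literal; the splat fill value rides along
  // as the torch.onnx.value attribute read by the conversion pattern.
  %shape = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
  %0 = torch.operator "onnx.ConstantOfShape"(%shape) {torch.onnx.value = dense<3.4> : tensor<1xf32>} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,4],f32>
  return %0 : !torch.vtensor<[2,3,4],f32>
}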
rsuderman authored Feb 22, 2024
1 parent df2aa1a commit 53f6d06
Showing 5 changed files with 65 additions and 142 deletions.
71 changes: 47 additions & 24 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1571,7 +1571,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"ConstantOfShape", 20,
"ConstantOfShape", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value shape;
@@ -1582,15 +1582,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
auto shapeSizes =
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
SmallVector<Value> dimList;
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Torch::BaseTensorType shapeType =
shape.getType().cast<Torch::BaseTensorType>();
Type selectResultType = shapeType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

for (int i = 0; i < shapeSizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -1601,6 +1600,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.getLoc(), rewriter.getType<Torch::IntType>(), extract);
dimList.push_back(dim);
}

Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
@@ -1609,7 +1609,6 @@

// Get fill_value if it is present.
// Assumption : resultDType and value attr type match.
Value value_const;
auto attr = binder.op->getAttr("torch.onnx.value");
auto resultDType = resultType.getDtype();

@@ -1620,34 +1619,58 @@
resultType.toBuiltinTensor().clone(resultDType),
rewriter.getFloatAttr(resultDType, 0.0));
}
if (!isa<DenseElementsAttr>(attr)) {
return rewriter.notifyMatchFailure(
binder.op, "`value` attr needs to be a tensor.");

// If its a dense resource attr we need to convert to a dense type:
if (DenseResourceElementsAttr rattr =
attr.dyn_cast_or_null<DenseResourceElementsAttr>()) {
// Bytes are stored in little endian order. Big endian support will
// require swizzling.
if (!Endian::little) {
binder.op->emitError(
"unimplemented: importing on big endian systems");
return failure();
}

auto ty = cast<ShapedType>(rattr.getType());
auto ptr = rattr.getRawHandle().getBlob()->getData();
auto denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr);
attr = dyn_cast_or_null<SplatElementsAttr>(denseAttr);
}

Attribute splattr;
if (isa<SplatElementsAttr>(attr)) {
auto denseAttr = attr.cast<DenseElementsAttr>();
splattr = denseAttr.getSplatValue<Attribute>();
}

auto denseAttr = attr.cast<DenseElementsAttr>();
auto denseAttrEleType = denseAttr.getElementType();
if (!isa<FloatType, IntegerType>(denseAttrEleType)) {
if (!isa<FloatAttr, IntegerAttr>(splattr)) {
return rewriter.notifyMatchFailure(
binder.op,
"`value` attr tensor only supports types int and float for now.");
}

// Create constant op for value
if (denseAttrEleType.isa<IntegerType>()) {
int64_t intVal = denseAttr.getSplatValue<IntegerAttr>().getSInt();
value_const = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(intVal));
}
if (denseAttrEleType.isa<FloatType>()) {
float floatVal =
denseAttr.getSplatValue<FloatAttr>().getValue().convertToFloat();
value_const = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(floatVal));
Value splatvalue;
if (auto intattr = dyn_cast<IntegerAttr>(splattr)) {
IntegerType intty = cast<IntegerType>(intattr.getType());
int64_t value;
if (intty.isUnsignedInteger()) {
value = intattr.getUInt();
} else if (intty.isSignedInteger()) {
value = intattr.getSInt();
} else {
value = intattr.getInt();
}
splatvalue =
rewriter.create<Torch::ConstantIntOp>(binder.getLoc(), value);
}

if (auto fpattr = dyn_cast<FloatAttr>(splattr))
splatvalue = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(),
rewriter.getF64FloatAttr(fpattr.getValueAsDouble()));

rewriter.replaceOpWithNewOp<Torch::AtenFullOp>(
binder.op, resultType, dimValueList, value_const, /*dtype=*/noneVal,
binder.op, resultType, dimValueList, splatvalue, /*dtype=*/noneVal,
/*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal);
return success();
});
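A note on the resource path added above: when the ONNX importer stores the splat `value` as a dense resource rather than an inline dense attribute, the pattern now unpacks the blob into a DenseElementsAttr before reading the splat. A hypothetical fragment of that form (the blob name is invented for illustration, and %shape is assumed to be defined as in the sketch above) would look like:

// value arrives as a dense resource blob instead of an inline dense attr.
%0 = torch.operator "onnx.ConstantOfShape"(%shape) {torch.onnx.value = dense_resource<value_blob> : tensor<1xf32>} : (!torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,4],f32>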
28 changes: 0 additions & 28 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1825,23 +1825,6 @@
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseClampTensorIntModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
"EmptyLikeModule_float",
"EmptyLikeModule_int",
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64_basic",
"FullLikeModuleDefaultDtype_basic",
"FullLikeModuleFalsePinMemory_basic",
"FullLikeModuleFloat2D_basic",
"FullLikeModuleFloat3D_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"HBC_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
@@ -1856,10 +1839,6 @@
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
"IndexPutHackedTwin3DIntAccumulateModule_basic",
"NormalizeModule_basic",
"OnesLikeModule_defaultDtype",
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
"OnesLikeModule_int",
"PadWithNoneValModule_basic",
"QuantizedMLP_basic",
"RandModule_basic",
@@ -1875,13 +1854,6 @@
"TileSmallDimsSizeModule_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticSize_basic",
"ZeroFloat32Module_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
"ZerosLikeModule_defaultDtype",
"ZerosLikeModule_falsePinMemory",
"ZerosLikeModule_float",
"ZerosLikeModule_int",

# Failure - onnx_lowering
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
31 changes: 0 additions & 31 deletions python/torch_mlir/extras/onnx_importer.py
@@ -408,37 +408,6 @@ def _handle_node_Constant(self, node: onnx.NodeProto) -> bool:
self._gi.initializer_map[const_name] = value_proto.t
return True

def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool:
# This op is special: It has an input of the shape, and in full generality
# could involve eager production of constants of variable size. In
# practice, the DNN profile for ONNX makes this very difficult to do
# and we hard-assert that the input can be resolved to an immediate
# value.
assert len(node.input) == 1
assert len(node.output) == 1
shape = self._get_immediate_tensor(node.input[0]).astype(np.int64)
value_proto = _get_attr(node, "value")
assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR
tensor_proto = value_proto.t
element_type = self._cc.tensor_element_type(tensor_proto.data_type)
vtensor_type = self._cc.get_vtensor_type(tuple(shape), element_type)
assert len(tensor_proto.dims) == 1 and tensor_proto.dims[0] == 1
try:
cb = ELEM_TYPE_SPLAT_TENSOR_PROTO_CB[tensor_proto.data_type]
except KeyError:
raise OnnxImportError(
f"Unhandled splat type for ConstantOfShape: {node} (possible missing mapping in ELEM_TYPE_SPLAT_TENSOR_PROTO_CB)"
)
value_attr = cb(tensor_proto, tuple(shape))
literal_op = Operation.create(
name="torch.vtensor.literal",
results=[vtensor_type],
attributes={"value": value_attr},
)
self._nv_map[node.output[0]] = literal_op.result
return True


class ContextCache:
"""Caches per-context lookups of various things."""

36 changes: 18 additions & 18 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1539,14 +1539,14 @@ func.func @test_constant_of_shape_dense_float_default() -> !torch.vtensor<[2,3,4
// CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00
@@ -1563,14 +1563,14 @@ func.func @test_constant_of_shape_dense_float_cst() -> !torch.vtensor<[2,3,4], f
// CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FILL_VAL:.*]] = torch.constant.float 3.4000000953674316
@@ -1587,14 +1587,14 @@ func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si6
// CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
// CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int
// CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
// CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FILL_VAL:.*]] = torch.constant.int 3
41 changes: 0 additions & 41 deletions test/python/onnx_importer/import_smoke_test.py
@@ -102,9 +102,6 @@
"node_test_castlike_FLOAT_to_STRING_model",
"node_test_castlike_STRING_to_FLOAT_expanded_model",
"node_test_castlike_STRING_to_FLOAT_model",
"node_test_constantofshape_float_ones_model",
"node_test_constantofshape_int_shape_zero_model",
"node_test_constantofshape_int_zeros_model",
"node_test_dequantizelinear_e4m3fn_model",
"node_test_dequantizelinear_e4m3fn_zero_point_model",
"node_test_dequantizelinear_e5m2_model",
@@ -118,44 +115,6 @@
"node_test_if_model",
"node_test_if_opt_model",
"node_test_if_seq_model",
"node_test_layer_normalization_2d_axis0_expanded_model",
"node_test_layer_normalization_2d_axis0_expanded_ver18_model",
"node_test_layer_normalization_2d_axis1_expanded_model",
"node_test_layer_normalization_2d_axis1_expanded_ver18_model",
"node_test_layer_normalization_2d_axis_negative_1_expanded_model",
"node_test_layer_normalization_2d_axis_negative_1_expanded_ver18_model",
"node_test_layer_normalization_2d_axis_negative_2_expanded_model",
"node_test_layer_normalization_2d_axis_negative_2_expanded_ver18_model",
"node_test_layer_normalization_3d_axis0_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis0_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis1_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis1_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis2_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis2_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_ver18_model",
"node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_model",
"node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_ver18_model",
"node_test_layer_normalization_4d_axis0_expanded_model",
"node_test_layer_normalization_4d_axis0_expanded_ver18_model",
"node_test_layer_normalization_4d_axis1_expanded_model",
"node_test_layer_normalization_4d_axis1_expanded_ver18_model",
"node_test_layer_normalization_4d_axis2_expanded_model",
"node_test_layer_normalization_4d_axis2_expanded_ver18_model",
"node_test_layer_normalization_4d_axis3_expanded_model",
"node_test_layer_normalization_4d_axis3_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_1_expanded_model",
"node_test_layer_normalization_4d_axis_negative_1_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_2_expanded_model",
"node_test_layer_normalization_4d_axis_negative_2_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_3_expanded_model",
"node_test_layer_normalization_4d_axis_negative_3_expanded_ver18_model",
"node_test_layer_normalization_4d_axis_negative_4_expanded_model",
"node_test_layer_normalization_4d_axis_negative_4_expanded_ver18_model",
"node_test_layer_normalization_default_axis_expanded_model",
"node_test_layer_normalization_default_axis_expanded_ver18_model",
"node_test_loop11_model",
"node_test_loop13_seq_model",
"node_test_loop16_seq_none_model",
