Fix onnx importer to treat Constant values as static #2780

Merged · 14 commits · Jan 22, 2024
45 changes: 37 additions & 8 deletions python/torch_mlir/extras/onnx_importer.py
@@ -276,10 +276,13 @@ def import_node(self, node: onnx.NodeProto):
         with InsertionPoint(self._b), Location.name(node.name):
             op_type = node.op_type
             # Handle special op types that materialize to non-op IR constructs.
+            # Handlers return True if the op was handled, else this function
+            # should process it as a general node.
             special_key = f"_handle_node_{op_type}"
             if hasattr(self, special_key):
-                getattr(self, special_key)(node)
-                return
+                was_handled = getattr(self, special_key)(node)
+                if was_handled:
+                    return

             # General node import.
             input_values = []
@@ -333,16 +336,19 @@ def import_attributes(
             )
             attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc)

-    def import_initializer(self, initializer: onnx.TensorProto) -> Value:
-        with InsertionPoint(self._b), Location.name(initializer.name):
+    def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value:
+        # If an explicitly specified name is given, use that; otherwise, pick
+        # up the name from the tensor proto itself.
+        iname = extern_name if extern_name else initializer.name
+        with InsertionPoint(self._b), Location.name(iname):
             value_attr = self._cc.tensor_proto_to_attr(initializer)
             vtensor_type = self._cc.tensor_proto_to_type(initializer)
             literal_op = Operation.create(
                 name="torch.vtensor.literal",
                 results=[vtensor_type],
                 attributes={"value": value_attr},
             )
-            self._nv_map[initializer.name] = literal_op.result
+            self._nv_map[iname] = literal_op.result
             return literal_op.result

     def _get_immediate_tensor(self, name: str) -> np.array:
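
Editor's note, not part of the diff: the tensor embedded in a Constant node often carries an empty or arbitrary internal name, so the importer needs to be able to register it under the node's output name instead — which is what the new extern_name parameter enables. A runnable sketch using the onnx helper API (names here are illustrative):

import onnx
from onnx import helper, TensorProto

node = helper.make_node(
    "Constant", inputs=[], outputs=["const_out_0"],
    value=helper.make_tensor("", TensorProto.INT64, [2], [3, 4]),
)
tp = node.attribute[0].t
print(repr(tp.name))   # '' -- useless as a name-to-value map key
print(node.output[0])  # 'const_out_0' -- what downstream nodes reference
# import_initializer(tp, extern_name=node.output[0]) keys the literal by the
# output name, mirroring iname = extern_name if extern_name else initializer.name.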
@@ -366,7 +372,23 @@ def _get_immediate_tensor(self, name: str) -> np.array:
f"Unhandled ONNX TensorProto immediate data: {initializer}"
)

def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
def _handle_node_Constant(self, node: onnx.NodeProto) -> bool:
# Special case only for constants specified by value attribute (for now)
value_proto = _get_attr(node, "value", False)
if not value_proto:
return False

# Produce an initializer for the constant, so that it can be used in
# combination with other ops, such as ConstantOfShape, requiring
# a constant input
assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR
assert len(node.output) == 1
const_name = node.output[0]
self.import_initializer(value_proto.t, const_name)
self._gi.initializer_map[const_name] = value_proto.t
return True

def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool:
# This op is special: It has an input of the shape, and in full generality
# could involve eager production of constants of variable size. In
# practice, the DNN profile for ONNX makes this very difficult to do
@@ -394,6 +416,7 @@ def _handle_node_ConstantOfShape(self, node: onnx.NodeProto):
                 attributes={"value": value_attr},
             )
             self._nv_map[node.output[0]] = literal_op.result
+            return True


 class ContextCache:
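
Editor's note: to illustrate the graph pattern this handler targets (a sketch, not taken from the PR), here is a Constant feeding ConstantOfShape, built with the onnx helper API. With this change the importer registers the Constant's output as a static initializer, so the shape input can be read as immediate data instead of remaining an unresolved runtime value:

import onnx
from onnx import helper, TensorProto

shape_const = helper.make_node(
    "Constant", inputs=[], outputs=["shape"],
    value=helper.make_tensor("shape_val", TensorProto.INT64, [2], [3, 4]),
)
fill = helper.make_node(
    "ConstantOfShape", inputs=["shape"], outputs=["out"],
    value=helper.make_tensor("fill_val", TensorProto.FLOAT, [1], [0.0]),
)
graph = helper.make_graph(
    [shape_const, fill], "constant_feeds_constantofshape",
    inputs=[],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [3, 4])],
)
onnx.checker.check_model(helper.make_model(graph))
# _handle_node_Constant registers "shape" in initializer_map, so
# _handle_node_ConstantOfShape can resolve [3, 4] statically.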
Expand Down Expand Up @@ -515,6 +538,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
     onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat(
         RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0])
     ),
+    onnx.TensorProto.DataType.INT64: lambda tp, shape: DenseElementsAttr.get_splat(
+        RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get(
+            IntegerType.get_signed(64), int.from_bytes(tp.raw_data, "little",
+            signed=True) if tp.HasField("raw_data") else tp.int64_data[0])
+    ),
     # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB
 }
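
Editor's note: the new INT64 entry must handle both serializations a TensorProto permits — packed little-endian bytes in raw_data, or the typed int64_data field. A standalone sketch of that decode branch (illustrative values, not from the PR):

import onnx
from onnx import helper, TensorProto

raw = helper.make_tensor(
    "splat_raw", TensorProto.INT64, [1],
    (-7).to_bytes(8, "little", signed=True), raw=True,
)
typed = helper.make_tensor("splat_typed", TensorProto.INT64, [1], [-7])

for tp in (raw, typed):
    if tp.HasField("raw_data"):
        v = int.from_bytes(tp.raw_data, "little", signed=True)
    else:
        v = tp.int64_data[0]
    print(tp.name, v)  # both decode to -7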

@@ -605,9 +633,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
 }


-def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto:
+def _get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True) -> onnx.AttributeProto:
     for attr in node.attribute:
         if attr.name == attr_name:
             return attr
-    else:
+    if is_required:
         raise OnnxImportError(f"Required attribute {attr_name} not found in {node}")
+    return None
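
Editor's note: with is_required=False the lookup becomes a probe that returns None when the attribute is absent, which is what lets _handle_node_Constant decline Constants expressed through other attributes (value_float, sparse_value, etc.) and fall through to generic node import. A small usage sketch, assuming this module's _get_attr and OnnxImportError are in scope:

from onnx import helper

node = helper.make_node("Constant", inputs=[], outputs=["c"], value_float=1.0)
assert _get_attr(node, "value", False) is None  # absent -> None, no raise
assert _get_attr(node, "value_float").f == 1.0  # present -> the AttributeProto
try:
    _get_attr(node, "value")  # required and absent -> OnnxImportError
except OnnxImportError:
    pass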