Skip to content

Commit

Permalink
Shape inference: GatherBlockQuantized dispatcher (#23748)
Browse files Browse the repository at this point in the history
### Description
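Add a shape-inference dispatcher for the `GatherBlockQuantized` contrib op. It
reuses the existing dispatcher for the `Gather` op (`_infer_Gather`), since the
first two inputs (data and indices) have the same specs. For
`GatherBlockQuantized`, the output element type is taken from input 2 (scales,
zero-based index) instead of from input 0 (data), which `Gather` uses.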

### Motivation and Context
Support symbolic shape inference for models that contain the `GatherBlockQuantized` op.
  • Loading branch information
jambayk authored Feb 21, 2025
1 parent 75cf166 commit 6715d4c
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
"FastGelu": self._infer_FastGelu,
"GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
"GatherBlockQuantized": self._infer_Gather,
"Gelu": self._infer_Gelu,
"GemmFastGelu": self._infer_GemmFastGelu,
"GemmFloat8": self._infer_GemmFloat8,
Expand Down Expand Up @@ -459,6 +460,7 @@ def _onnx_infer_single_node(self, node):
"BiasGelu",
"EmbedLayerNormalization",
"FastGelu",
"GatherBlockQuantized",
"Gelu",
"GemmFastGelu",
"LayerNormalization",
Expand Down Expand Up @@ -1118,10 +1120,17 @@ def _infer_Gather(self, node): # noqa: N802
axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
indices_shape = self._get_shape(node, 1)
vi = self.known_vi_[node.output[0]]
if node.op_type == "Gather":
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
elif node.op_type == "GatherBlockQuantized":
# scales
elem_type = self.known_vi_[node.input[2]].type.tensor_type.elem_type
else:
raise ValueError(f"Unsupported Gather op_type: {node.op_type}")
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[0],
self.known_vi_[node.input[0]].type.tensor_type.elem_type,
elem_type,
data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
)
)
Expand Down

0 comments on commit 6715d4c

Please sign in to comment.