Fix GroupNorm to support Opset21 (#2928)
* Group norm for opset 21

* Testing phase

* Fix GroupNorm to support Opset21

---------

Signed-off-by: hamptonm1 <79232909+hamptonm1@users.noreply.github.com>
Co-authored-by: Megan Hampton <hamptonm@us.ibm.com>
hamptonm1 and MegoHam21 authored Sep 13, 2024
1 parent 97d497f commit 2f2ccc5
Showing 14 changed files with 383 additions and 109 deletions.
57 changes: 57 additions & 0 deletions docs/Dialects/onnx.md
@@ -3589,6 +3589,63 @@ where the mean and variance are computed per instance per group of channels, and
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

The overall computation has two stages: the first stage normalizes the elements to
have zero mean and unit variance for each instance in each group, and the second
stage scales and shifts the results of the first stage. The floating-point precision
used in the first stage is determined by the `stash_type` attribute. For example,
if `stash_type` is 1, the operator casts all input variables to 32-bit float,
performs the computation, and finally casts the normalized results back to the
original type of `X`. The second stage does not depend on `stash_type`.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
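
To make the two stages concrete, here is a minimal standalone sketch of the computation for a single instance (illustrative only, not the onnx-mlir implementation). It assumes a contiguous `(C, S)` layout with per-channel `scale` and `bias` of length `C`, and uses double accumulation in stage one to stand in for `stash_type = 1`:

```
#include <cmath>
#include <vector>

// Sketch of GroupNormalization for one instance with C channels of S
// elements each, row-major (C, S) layout; scale/bias have length C.
std::vector<float> groupNorm(const std::vector<float> &x, int C, int S,
    int numGroups, const std::vector<float> &scale,
    const std::vector<float> &bias, float epsilon = 1e-5f) {
  std::vector<float> y(x.size());
  const int channelsPerGroup = C / numGroups; // C must divide evenly
  const int groupSize = channelsPerGroup * S;
  for (int g = 0; g < numGroups; ++g) {
    // Stage 1: normalize to zero mean, unit variance within the group,
    // accumulating in a wider type (the stash_type = 1 behavior).
    double mean = 0.0, var = 0.0;
    for (int i = 0; i < groupSize; ++i)
      mean += x[g * groupSize + i];
    mean /= groupSize;
    for (int i = 0; i < groupSize; ++i) {
      double d = x[g * groupSize + i] - mean;
      var += d * d;
    }
    var /= groupSize;
    const double invStd = 1.0 / std::sqrt(var + epsilon);
    // Stage 2: per-channel scale and shift, independent of stash_type.
    for (int c = 0; c < channelsPerGroup; ++c) {
      const int ch = g * channelsPerGroup + c;
      for (int s = 0; s < S; ++s) {
        const int idx = ch * S + s;
        y[idx] = scale[ch] * static_cast<float>((x[idx] - mean) * invStd) +
                 bias[ch];
      }
    }
  }
  return y;
}
```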

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>epsilon</code></td><td>::mlir::FloatAttr</td><td>32-bit float attribute</td></tr>
<tr><td><code>num_groups</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
<tr><td><code>stash_type</code></td><td>::mlir::IntegerAttr</td><td>64-bit signed integer attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `X` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `scale` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values
| `bias` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values

#### Results:

| Result | Description |
| :----: | ----------- |
| `Y` | tensor of bfloat16 type values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values

### `onnx.GroupNormalizationV18` (ONNXGroupNormalizationV18Op)

_ONNX GroupNormalization operation_

A GroupNormalization function. Carries out group normalization as described in
the paper https://arxiv.org/abs/1803.08494

This operator transforms input according to
```
y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
```
where the mean and variance are computed per instance per group of channels, and
`scale` and `bias` should be specified for each group of channels. The number of
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
4 changes: 2 additions & 2 deletions docs/SupportedONNXOps-NNPA.md
@@ -3,11 +3,11 @@

# Supported ONNX Operation for Target *NNPA*.

Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.
Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.

* Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md).
* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator.
* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20.
* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21.


NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA.
4 changes: 2 additions & 2 deletions docs/SupportedONNXOps-cpu.md
@@ -3,11 +3,11 @@

# Supported ONNX Operation for Target *cpu*.

Onnx-mlir currently supports ONNX operations targeting up to opset 20. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.
Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes.

* Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md).
* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator.
* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 20.
* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21.


| Op |Supported Opsets (inclusive) |Limitations |Notes |
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,7 +1,8 @@
lit~=15.0
# numpy 1.24 deprecates np.object, np.bool, np.float, np.complex, np.str,
# and np.int which are used heavily in onnx-mlir.
numpy~=1.22.2, <=1.23.5
numpy==2.0.1
onnx==1.16.2
protobuf==4.21.12
pytest~=7.2
pytest-xdist~=3.0
pytest==8.3.2
pytest-xdist==3.6.1
4 changes: 3 additions & 1 deletion src/Builder/OpBuildTable.inc
@@ -80,7 +80,7 @@ op_dialect_version_map_["Gradient"] = {1};
op_dialect_version_map_["Greater"] = {13};
op_dialect_version_map_["GreaterOrEqual"] = {16};
op_dialect_version_map_["GridSample"] = {16};
op_dialect_version_map_["GroupNormalization"] = {18};
op_dialect_version_map_["GroupNormalization"] = {21, 18};
op_dialect_version_map_["HammingWindow"] = {17};
op_dialect_version_map_["HannWindow"] = {17};
op_dialect_version_map_["HardSigmoid"] = {6};
@@ -358,6 +358,8 @@ import_handler_map_["GridSample"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGridSampleOp>;
import_handler_map_["GroupNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGroupNormalizationOp>;
import_handler_map_["GroupNormalizationV18"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGroupNormalizationV18Op>;
import_handler_map_["HammingWindow"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXHammingWindowOp>;
import_handler_map_["HannWindow"] =
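
The `{21, 18}` entry registers both opsets for `GroupNormalization`: the newest version keeps the plain op name, while the older one imports as the suffixed `GroupNormalizationV18` op wired up via `import_handler_map_` above. As a rough illustration (an assumption about the selection rule, not the actual importer code), the importer can be thought of as picking the newest registered version that the model's opset already includes:

```
#include <map>
#include <string>
#include <vector>

// Illustrative stand-in for op_dialect_version_map_: versions are listed
// newest-first. Non-newest versions get a "V" suffix so that both ops can
// coexist in the ONNX dialect.
std::string selectOpVariant(
    const std::map<std::string, std::vector<int>> &versionMap,
    const std::string &opName, int modelOpset) {
  const std::vector<int> &versions = versionMap.at(opName); // e.g. {21, 18}
  for (size_t i = 0; i < versions.size(); ++i)
    if (modelOpset >= versions[i])
      return i == 0 ? opName : opName + "V" + std::to_string(versions[i]);
  return opName; // older than all registered versions: fall back to base op
}
// selectOpVariant(m, "GroupNormalization", 21) -> "GroupNormalization"
// selectOpVariant(m, "GroupNormalization", 18) -> "GroupNormalizationV18"
```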
15 changes: 14 additions & 1 deletion src/Dialect/ONNX/DialectBuilder.cpp
@@ -4,7 +4,7 @@

//===----- DialectBuilder.cpp - Helper functions for ONNX dialects -------===//
//
// Copyright 2019-2023 The IBM Research Authors.
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
@@ -164,6 +164,19 @@ Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
toTensor(bias), axisAttr, epsilon, stashTypeAttr);
return layerNormOp.getY();
}
// Variant of layerNorm for GroupNormalization, where stashType can be specified.
Value OnnxBuilder::layerNorm(Type outputType, Value input, Value scale,
Value bias, int64_t axis, FloatAttr epsilon, IntegerAttr stashType) const {
IntegerAttr axisAttr = getSignedInt64Attr(axis);
Value noneVal = none();
Type noneType = noneVal.getType();
ONNXLayerNormalizationOp layerNormOp =
createOpAndInferShapes<ONNXLayerNormalizationOp>(
/*Y type*/ toTensor(outputType), /*mean type*/ noneType,
/*std dev Type*/ noneType, toTensor(input), toTensor(scale),
toTensor(bias), axisAttr, epsilon, stashType);
return layerNormOp.getY();
}

Value OnnxBuilder::qlinearMatMul(Type outputType, Value a, Value aScale,
Value aZeroPoint, Value b, Value bScale, Value bZeroPoint, Value yScale,
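
For context, a hedged sketch of how a GroupNormalization decomposition might call the new overload, forwarding the op's `stash_type` so that stage one of the computation keeps the requested precision. The function name, grouped-layout reshape, and broadcast scale/bias values here are illustrative, not the actual onnx-mlir rewrite:

```
// Illustrative only: assumes the usual onnx-mlir namespaces, ODS-generated
// accessors, and a MultiDialectBuilder with an OnnxBuilder member `onnx`.
mlir::Value emitGroupNormAsLayerNorm(
    onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> &create,
    mlir::ONNXGroupNormalizationOp gnOp, mlir::Type groupedType,
    mlir::Value groupedInput, mlir::Value onesScale, mlir::Value zeroBias) {
  // Normalize over the per-group dimensions of the grouped layout (axis 2
  // and up), passing stash_type through to control stage-one precision.
  return create.onnx.layerNorm(groupedType, groupedInput, onesScale, zeroBias,
      /*axis=*/2, gnOp.getEpsilonAttr(), gnOp.getStashTypeAttr());
}
```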
4 changes: 4 additions & 0 deletions src/Dialect/ONNX/DialectBuilder.hpp
@@ -91,6 +91,10 @@ struct OnnxBuilder : DialectBuilder {
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
mlir::Value scale, mlir::Value bias, int64_t axis,
mlir::FloatAttr epsilon) const;
// Variant of layerNorm for GroupNormalization, where stashType can be specified.
mlir::Value layerNorm(mlir::Type outputType, mlir::Value input,
mlir::Value scale, mlir::Value bias, int64_t axis,
mlir::FloatAttr epsilon, mlir::IntegerAttr stashType) const;

// ONNXQLinearMatMulOp
mlir::Value qlinearMatMul(mlir::Type outputType, mlir::Value a,
59 changes: 58 additions & 1 deletion src/Dialect/ONNX/ONNXOps.td.inc
@@ -3122,6 +3122,62 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization",
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

The overall computation has two stages: the first stage normalizes the elements to
have zero mean and unit variance for each instance in each group, and the second
stage scales and shifts the results of the first stage. The floating-point precision
used in the first stage is determined by the `stash_type` attribute. For example,
if `stash_type` is 1, the operator casts all input variables to 32-bit float,
performs the computation, and finally casts the normalized results back to the
original type of `X`. The second stage does not depend on `stash_type`.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
}];
let arguments = (ins AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$X,
AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$scale,
AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$bias,
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
SI64Attr:$num_groups,
DefaultValuedAttr<SI64Attr, "1">:$stash_type);
let results = (outs AnyTypeOf<[TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>]>:$Y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 3;
}
static int getNumberOfResults() {
return 1;
}
static std::vector<int> getTypeMap() {
return {30};
}
}];
let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
}

def ONNXGroupNormalizationV18Op:ONNX_Op<"GroupNormalizationV18",
[Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let summary = "ONNX GroupNormalization operation";
let description = [{
A GroupNormalization function. Carries out group normalization as described in
the paper https://arxiv.org/abs/1803.08494

This operator transforms input according to
```
y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
```
where the mean and variance are computed per instance per group of channels, and
`scale` and `bias` should be specified for each group of channels. The number of
channels should be divisible by the number of groups `num_groups` so that there are
an equal number of channels per group.

When the number of groups is the same as the number of channels, this operator is
equivalent to InstanceNormalization. When there is only one group, this operator
is equivalent to LayerNormalization.
@@ -3146,11 +3202,12 @@ def ONNXGroupNormalizationOp:ONNX_Op<"GroupNormalization",
let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * $cppClass::getShapeHelper(mlir::Operation *op, llvm::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationOpShapeHelper(op, oper, ieb, scope);
onnx_mlir::ONNXOpShapeHelper *sh = new onnx_mlir::ONNXGroupNormalizationV18OpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
let hasVerifier = 1;
}

def ONNXHammingWindowOp:ONNX_Op<"HammingWindow",
15 changes: 15 additions & 0 deletions src/Dialect/ONNX/ONNXOps/NN/Normalization.cpp
@@ -149,6 +149,21 @@ LogicalResult ONNXInstanceNormalizationOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// GroupNormalizationV18
//===----------------------------------------------------------------------===//
LogicalResult ONNXGroupNormalizationV18Op::verify() {
ONNXGroupNormalizationV18OpAdaptor(*this);
llvm::outs()
    << "Warning: The previous understanding of Opset 18 for "
       "GroupNormalization is incorrect, as shown in the following issue: "
       "https://github.com/onnx/onnx/issues/5466. Rather, use Opset 21 for "
       "GroupNormalization instead."
    << "\n";
return success();
}

// TODO: should there be a shape inference for this one?

//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXUnsupportedOps.hpp
@@ -77,6 +77,7 @@ CONVERTED_TO_SUPPORTED_OPS(ONNXClipV12Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXClipV6Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXDFTV17Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationOp)
CONVERTED_TO_SUPPORTED_OPS(ONNXGroupNormalizationV18Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV18Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV13Op)
CONVERTED_TO_SUPPORTED_OPS(ONNXPadV11Op)