Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Co-authored-by: Rachel Han <hanrach@google.com>
  • Loading branch information
abhigunj and hanrach9 authored Jan 29, 2025
1 parent 7c50d4e commit 48a1e14
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 27 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "e2402615a5a76d46a433dfcc1de10b38a1263c9d"
LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24"

LLVM_SHA256 = "9c22349e1d38555b2f223e49951655f60c04c0c3467e0150aaf6c9f50484cc9f"
LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
e2402615a5a76d46a433dfcc1de10b38a1263c9d
aa65f93b71dee8cacb22be1957673c8be6a3ec24
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,5 +780,22 @@ bool isValidQuantizedDimension(Type type) {
numScales == rankedType.getDimSize(quantDim));
}

// Returns true if `type` is a ranked tensor whose encoding declares exactly
// one concrete bound, and whose shape has exactly one dynamic dimension
// (i.e. the single bounded dimension is the only dynamic one).
bool hasSingleBoundedDimension(Type type) {
  // Bounds can only be carried by a ranked tensor's encoding attribute.
  // Guard the cast: calling getEncoding() on a null RankedTensorType is UB
  // when `type` is unranked or not a tensor at all.
  RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
  if (!rankedType) return false;

  auto boundedAttr =
      dyn_cast_or_null<BoundedAttrInterface>(rankedType.getEncoding());
  if (!boundedAttr) return false;

  // Count dimensions whose bound is concrete (not kDynamic).
  int64_t numBoundedDims = llvm::count_if(
      boundedAttr.getBounds(),
      [](int64_t bound) { return !ShapedType::isDynamic(bound); });
  // Also count the dynamic dims in the shape itself: requiring exactly one
  // rules out any additional unbounded dynamic dimensions.
  int64_t numDynamicDims = llvm::count_if(
      rankedType.getShape(),
      [](int64_t bound) { return ShapedType::isDynamic(bound); });
  return numBoundedDims == 1 && numDynamicDims == 1;
}

} // namespace hlo
} // namespace mlir
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ bool isValidStablehloQuantizedElementType(Type elementType);
// mentioned in the StableHLO specification.
bool isValidQuantizedDimension(Type type);

// Returns true if the given type has a single bounded dimension.
bool hasSingleBoundedDimension(Type type);

// TODO(zhouxin) Move type inference related methods to TypeInference.cpp

std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
Expand Down
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def I32RankedTensor : RankedTensorOf<[I32]>;

def UI32RankedTensor : RankedTensorOf<[UI32]>;

//===----------------------------------------------------------------------===//
// HLO type constraints.
//===----------------------------------------------------------------------===//

// Note: Bounded dynamism is largely unspecced and this feature needs more
// thought as it is adopted by modern frameworks. The current support is
// designed to allow existing TF programs to be representable in StableHLO and
// is subject to change as a formal design for bounded dynamism is developed.
// Holds when the type has exactly one bounded dimension; delegates to
// mlir::hlo::hasSingleBoundedDimension (declared in Base.h).
def HLO_HasSingleBoundedDimensionPred
  : CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">;

// Holds when the shape is either fully static or has a single bounded
// dynamic dimension.
def HLO_HasStaticOrSingleBoundedShapePred
  : Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>;

//===----------------------------------------------------------------------===//
// HLO type definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -267,6 +281,9 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;

// Like HLO_StaticShapeTensorOrPerAxisQuantizedTensor, but additionally admits
// tensors with a single bounded dynamic dimension (bounded dynamism support).
def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
    [IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">;

def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;

def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>;
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1980,7 +1980,7 @@ def StableHLO_BroadcastInDimOp : StableHLO_Op<"broadcast_in_dim",
DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/
);

let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);

let hasVerifier = 1;

Expand Down Expand Up @@ -2732,7 +2732,7 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape",

let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand);

let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor);
let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor);
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3724,9 +3724,8 @@ LogicalResult verifyBroadcastInDimOp(std::optional<Location> location,
Value operand,
ArrayRef<int64_t> broadcastDimensions,
Value result) {
auto operandType = cast<RankedTensorType>(operand.getType());

// broadcast_in_dim_c1
auto operandType = cast<RankedTensorType>(operand.getType());
if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType,
result.getType())))
return failure();
Expand Down Expand Up @@ -4658,11 +4657,12 @@ LogicalResult verifyReshapeOp(std::optional<Location> location, Value operand,
Value result) {
// If the operand type is dynamically shaped there is nothing to verify.
auto operandTy = cast<RankedTensorType>(operand.getType());
if (!operandTy.hasStaticShape()) return success();
auto resultTy = cast<RankedTensorType>(result.getType());
if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
return success();

// If the operand type is statically shaped (not required) the number of
// elements must match that of the result type.
auto resultTy = cast<RankedTensorType>(result.getType());
int64_t numResultElements = resultTy.getNumElements();
int64_t numOperandElements = operandTy.getNumElements();
if (numResultElements != numOperandElements)
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ FailureOr<int64_t> Version::getBytecodeVersion() const {
Version Version::fromCompatibilityRequirement(
CompatibilityRequirement requirement) {
// Compatibility requirement versions can be updated as needed, as long as the
// version satisifies the requirement.
// version satisfies the requirement.
// The time frames used are from the date that the release was tagged on, not
// merged. The tag date is when the version has been verified and exported to
// XLA. See: https://github.com/openxla/stablehlo/tags
Expand Down
17 changes: 9 additions & 8 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Attributes.h"
Expand All @@ -40,7 +40,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"
#include "stablehlo/dialect/AssemblyFormat.h"
#include "stablehlo/dialect/AssemblyFormat.h" // IWYU pragma: keep
#include "stablehlo/dialect/Version.h"
#include "stablehlo/dialect/VhloBytecode.h"
#include "stablehlo/dialect/VhloTypes.h"
Expand Down Expand Up @@ -184,12 +184,13 @@ ParseResult parseFunctionBody(OpAsmParser& parser, Attribute& name,
return success();
}

void TensorV1Attr::print(mlir::AsmPrinter& p) const {
p << '<'
<< DenseIntOrFPElementsAttr::getFromRawBuffer(
llvm::cast<ShapedType>(convertTypeToBuiltinForPrint(getType())),
getData())
<< '>';
// Prints the attribute as `<...>` by reinterpreting the raw data buffer as a
// builtin DenseIntOrFPElementsAttr of the corresponding builtin shaped type,
// so the elements render in MLIR's standard dense-elements syntax.
void TensorV1Attr::print(mlir::AsmPrinter& odsPrinter) const {
  odsPrinter << '<'
             << DenseIntOrFPElementsAttr::getFromRawBuffer(
                    llvm::cast<ShapedType>(
                        convertTypeToBuiltinForPrint(getType())),
                    getData())
             << '>';
}

// Parse tensor elements using DenseIntOrFPElementsAttr printing.
Expand Down
13 changes: 6 additions & 7 deletions stablehlo/reference/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ bool isSupportedIntegerType(Type type) {
}

bool isSupportedFloatType(Type type) {
return type.isFloat4E2M1FN() || type.isFloat6E2M3FN() ||
type.isFloat6E3M2FN() || type.isFloat8E3M4() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3() ||
type.isFloat8E4M3FN() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E5M2() || type.isFloat8E5M2FNUZ() ||
type.isFloat8E8M0FNU() || type.isF16() || type.isBF16() ||
type.isF32() || type.isF64();
return llvm::isa<
mlir::Float4E2M1FNType, mlir::Float6E2M3FNType, mlir::Float6E3M2FNType,
mlir::Float8E3M4Type, mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3Type,
mlir::Float8E4M3FNType, mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
mlir::Float8E5M2FNUZType, mlir::Float8E8M0FNUType, mlir::Float16Type,
mlir::BFloat16Type, mlir::Float32Type, mlir::Float64Type>(type);
}

bool isSupportedComplexType(Type type) {
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,22 @@ func.func @broadcast_in_dim_c5(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> {

// -----

// CHECK-LABEL: func @broadcast_in_dim_dynamic_i1
func.func @broadcast_in_dim_dynamic_i1(%arg0: tensor<?xi32>) -> tensor<1x3xi32> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<?xi32>) -> tensor<1x3xi32>
return %0 : tensor<1x3xi32>
}

// -----

func.func @broadcast_in_dim_dynamic_result(%arg0: tensor<3xi32>) -> tensor<?x3xi32> {
// expected-error@+1 {{must be statically shaped or single bounded dimension tensor}}
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1>} : (tensor<3xi32>) -> tensor<?x3xi32>
func.return %0 : tensor<?x3xi32>
}

// -----

// Regression test for b/180052624, where this was improperly marked as an
// invalid stablehlo.broadcast_in_dim op.
// CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand
Expand Down
63 changes: 63 additions & 0 deletions stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s

// This file captures some quirks of bounded dynamism in StableHLO that are
// included to allow StableHLO to represent existing TF programs.

// CHECK-LABEL: reshape_with_single_bounded_dimension
func.func @reshape_with_single_bounded_dimension(%arg0: tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>> {
%0 = stablehlo.reshape %arg0 : (tensor<?x2xf32, #stablehlo.bounds<5, ?>>) -> tensor<2x?xf32, #stablehlo.bounds<?, 5>>
// CHECK: return {{.*}} #stablehlo.bounds<?, 5>
return %0 : tensor<2x?xf32, #stablehlo.bounds<?, 5>>
}

// -----

// CHECK-LABEL: reshape_scalar_with_single_bounded_dimension
func.func @reshape_scalar_with_single_bounded_dimension(%arg0: tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
%0 = stablehlo.reshape %arg0 : (tensor<?xf32, #stablehlo.bounds<5>>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
// CHECK: return {{.*}} #stablehlo.bounds<?, 5>
return %0 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
}

// -----

func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>> {
// expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
%0 = stablehlo.reshape %arg0 : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<?x?xf32, #stablehlo.bounds<5, 5>>
return %0 : tensor<?x?xf32, #stablehlo.bounds<5, 5>>
}

// -----

// CHECK-LABEL: broadcast_in_dim_with_single_bounded_dimension
func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>> {
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds<?, 5>>) -> tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
// CHECK: return {{.*}} #stablehlo.bounds<?, ?, 5>
return %0 : tensor<2x1x?xf32, #stablehlo.bounds<?, ?, 5>>
}

// -----

func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>> {
// expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}}
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<?x?xf32, #stablehlo.bounds<5, 5>>) -> tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
return %0 : tensor<2x?x?xf32, #stablehlo.bounds<?, 5, 5>>
}

// -----

// CHECK-LABEL: constant_splat_broadcast
func.func @constant_splat_broadcast() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
%0 = stablehlo.constant dense<1.0> : tensor<f32>
%1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor<f32>) -> tensor<1x?xf32, #stablehlo.bounds<?, 5>>
// CHECK: tensor<1x?xf32, #stablehlo.bounds<?, 5>>
return %1 : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
}

// -----

func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #stablehlo.bounds<?, 5>> {
// expected-error@+2 {{elements literal type must have static shape}}
%c = stablehlo.constant dense<1> : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
return %c : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
}
Original file line number Diff line number Diff line change
Expand Up @@ -1940,6 +1940,17 @@ func.func @reorder_with_type_change(%arg0 : tensor<3x4xi32>) -> tensor<12xi64> {
return %1 : tensor<12xi64>
}

// -----

// CHECK-LABEL: @reorder_invalid_with_dynamic_shape
func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor<?x4xf32>) {
// CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
// CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor<?x4xf32>
// CHECK: return %[[CONVERT]]
%0 = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32>
%1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor<?x4xf32>
return %1 : tensor<?x4xf32>
}

// -----

Expand Down
5 changes: 5 additions & 0 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@ limitations under the License.
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/Version.h"

namespace mlir {
namespace stablehlo {

#define GEN_PASS_DECL

std::unique_ptr<::mlir::Pass> createStablehloAggressiveSimplificationPass(
GreedyRewriteConfig config);

#define GEN_PASS_REGISTRATION
#include "stablehlo/transforms/Passes.h.inc"

Expand Down
22 changes: 20 additions & 2 deletions stablehlo/transforms/StablehloAggressiveSimplification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
#include <utility>

Expand All @@ -21,6 +22,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
Expand All @@ -38,6 +40,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
Expand Down Expand Up @@ -1447,12 +1450,18 @@ struct ReorderElementwiseAndShapeOp final
return rewriter.notifyMatchFailure(
op, "defining operation of unexpected type");

// Reshape and broadcast are not allowed to have dynamic shape.
Value result = op->getResult(0);
if (isa<ReshapeOp, BroadcastOp>(definingOp) &&
!cast<ShapedType>(result.getType()).hasStaticShape())
return rewriter.notifyMatchFailure(
op, "cannot reorder around reshape/broadcast with dynamic shape");

// Only reorder if the defining op has no other uses.
if (!llvm::hasSingleElement(definingOp->getResult(0).getUses()))
return rewriter.notifyMatchFailure(op, "operation has more than one use");

Value input = definingOp->getOperand(0);
Value result = op->getResult(0);
auto intermediateType = cast<ShapedType>(input.getType())
.clone(getElementTypeOrSelf(result.getType()));

Expand All @@ -1470,6 +1479,9 @@ struct ReorderElementwiseAndShapeOp final
struct StablehloAggressiveSimplificationPass final
: impl::StablehloAggressiveSimplificationPassBase<
StablehloAggressiveSimplificationPass> {
StablehloAggressiveSimplificationPass() = default;
StablehloAggressiveSimplificationPass(GreedyRewriteConfig config)
: config(config) {}
LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet patterns_(context);
populateStablehloCanonicalizationPatterns(context, &patterns_);
Expand All @@ -1478,11 +1490,12 @@ struct StablehloAggressiveSimplificationPass final
}

void runOnOperation() override {
if (failed(applyPatternsGreedily(getOperation(), patterns)))
if (failed(applyPatternsGreedily(getOperation(), patterns, config)))
signalPassFailure();
}

private:
GreedyRewriteConfig config;
FrozenRewritePatternSet patterns;
};

Expand Down Expand Up @@ -1515,5 +1528,10 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
}

// Factory for the aggressive-simplification pass, constructed with a
// caller-supplied GreedyRewriteConfig (forwarded to applyPatternsGreedily).
std::unique_ptr<Pass> createStablehloAggressiveSimplificationPass(
    GreedyRewriteConfig config) {
  return std::make_unique<StablehloAggressiveSimplificationPass>(config);
}

} // namespace stablehlo
} // namespace mlir

0 comments on commit 48a1e14

Please sign in to comment.