diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index a68d03472e3..9f12749d98e 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1805,24 +1805,24 @@ void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs, llvm::map_range(adaptor.getInitValues().getTypes(), [](Type t) { return t.cast(); })}; - if (succeeded(hlo::verifyReduceOpInputsAndInferShape( + if (failed(hlo::verifyReduceOpInputsAndInferShape( odsState.location, inputArgTensorTypes, dimensions, newDimensions, - encoding))) { - SmallVector inferredReturnTypes; - for (auto [inputTy, elementTy] : - llvm::zip(inputArgTensorTypes, elementTypes)) { - if (inputTy.hasRank()) { - inferredReturnTypes.push_back( - RankedTensorType::get(newDimensions, elementTy, encoding)); - } else { - assert(encoding == nullptr && "attribute not supported"); - inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); - } - } - odsState.addTypes(inferredReturnTypes); - } else { + encoding))) llvm::report_fatal_error("Failed to infer result type(s)."); + + SmallVector inferredReturnTypes; + for (auto [inputTy, elementTy] : + llvm::zip(inputArgTensorTypes, elementTypes)) { + if (inputTy.hasRank()) { + inferredReturnTypes.push_back( + RankedTensorType::get(newDimensions, elementTy, encoding)); + } else { + if (encoding != nullptr) + llvm::report_fatal_error("attribute not supported."); + inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); + } } + odsState.addTypes(inferredReturnTypes); } LogicalResult ReduceOp::verify() {