diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp index 84c3c6fb6c29f..2e5d0dcd4114c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp @@ -228,6 +228,8 @@ struct GenericOpTypePropagation } auto inputOperandType = llvm::cast(genericOp->getOperandTypes()[index]); + assert(inputOperandType.getElementType() == argType && + "expected same element type"); std::optional legalizedArgType = legalizeStorageElementType(inputOperandType); if (!legalizedArgType) { @@ -259,6 +261,8 @@ struct GenericOpTypePropagation modifyYield = true; OpOperand *yieldOperand = modifiedOp.getMatchingYieldValue(modifiedOpOperand); + assert(llvm::cast(modifiedOpOperand->get().getType()).getElementType() == + yieldOperand->get().getType() && "expected same element type"); std::optional legalizedType = legalizeStorageElementType(modifiedOpOperand->get().getType()); if (!legalizedType) { @@ -289,8 +293,11 @@ struct LinalgFillTypePropagation matchAndRewrite(linalg::FillOp fillOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Value value = adaptor.getInputs().front(); + TensorType outputType = cast(adaptor.getOutputs()[0].getType()); + assert(outputType.getElementType() == value.getType() && + "expected same element type"); std::optional legalizedElementType = - legalizeStorageElementType(adaptor.getOutputs()[0].getType()); + legalizeStorageElementType(outputType); if (!legalizedElementType) { return fillOp.emitOpError("failed to get legalized type for value"); }