Skip to content

Commit

Permalink
Changes to refine types
Browse files Browse the repository at this point in the history
- Add `!torch.optional` knowledge tracking
- Changes to improve type propagation for branches and terminators. See
examples in `refine-types-branch.mlir`
- Refator to separate handling of different ops from `visitOperation`
- Add refine types for a few new ops
  • Loading branch information
cathyzhyi committed Aug 26, 2021
1 parent d8db41b commit 8445e35
Show file tree
Hide file tree
Showing 5 changed files with 1,000 additions and 249 deletions.
3 changes: 1 addition & 2 deletions include/npcomp/Dialect/Torch/IR/TorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
}

def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
Terminator,
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
let summary = "yield-like terminator for torch.prim.Loop";
Expand Down Expand Up @@ -641,8 +642,6 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];

let hasCanonicalizer = 1;
}

def Torch_OperatorOp : Torch_Op<"operator", [
Expand Down
27 changes: 11 additions & 16 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,17 @@ void PrimLoopOp::getSuccessorRegions(
regions.emplace_back(getResults());
}

//===----------------------------------------------------------------------===//
// PrimLoopConditionOp
//===----------------------------------------------------------------------===//

MutableOperandRange
PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
// Pass all operands except the condition to the successor which is the
// parent loop op.
return iterArgsMutable();
}

//===----------------------------------------------------------------------===//
// PrimIfOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -350,22 +361,6 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(inputs[0], outputs[0]);
}

void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
// TODO: Extend RefineTypes for this case and delete this canonicalization,
// since we don't want control flow or calls to randomly block this fold
// (this canonicalization pattern makes the compiler brittle to control flow
// and calls).
bool allAllowRefinement =
llvm::all_of(op.getResult().getUsers(), allowsTypeRefinement);
if (!allAllowRefinement)
return failure();
rewriter.replaceOp(op, op.getOperand());
return success();
});
}

template <typename OpTy>
static OpFoldResult atenIsOrIsNotFoldHelper(OpTy op, bool equalIsTrue) {
Type lhsType = op.self().getType();
Expand Down
Loading

0 comments on commit 8445e35

Please sign in to comment.