Skip to content

Commit

Permalink
Changes to refine types
Browse files Browse the repository at this point in the history
- Add `!torch.optional` knowledge tracking
- Changes to improve type propagation for branches and terminators. See
examples in `refine-types-branch.mlir`
- Refator to separate handling of different ops from `visitOperation`
- Add refine types for a few new ops
  • Loading branch information
cathyzhyi committed Aug 27, 2021
1 parent bc5eae4 commit d6b9709
Show file tree
Hide file tree
Showing 5 changed files with 1,002 additions and 235 deletions.
1 change: 1 addition & 0 deletions include/npcomp/Dialect/Torch/IR/TorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def Torch_PrimLoopOp : Torch_Op<"prim.Loop", [
}

def Torch_PrimLoopConditionOp : Torch_Op<"prim.Loop.condition", [
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
Terminator,
HasParent<"::mlir::NPCOMP::Torch::PrimLoopOp">]> {
let summary = "yield-like terminator for torch.prim.Loop";
Expand Down
20 changes: 16 additions & 4 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ void PrimLoopOp::getSuccessorRegions(
regions.emplace_back(getResults());
}

//===----------------------------------------------------------------------===//
// PrimLoopConditionOp
//===----------------------------------------------------------------------===//

MutableOperandRange
PrimLoopConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
// Pass all operands except the condition to the successor which is the
// parent loop op.
return iterArgsMutable();
}

//===----------------------------------------------------------------------===//
// PrimIfOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -357,10 +368,11 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
// TODO: Extend RefineTypes for this case and delete this canonicalization,
// since we don't want control flow or calls to randomly block this fold
// (this canonicalization pattern makes the compiler brittle to control flow
// and calls).
// TODO: This pattern should be removed because type refine does a better
// job dealing with control flow. However, removing this would expose an
// issue with ReduceOpVariants. DerefineOp doesn't have value semantics and
// if not removed eagerly by canonicalizer would prevent ReduceOpVariants
// from converting certain tensors value semantics.
bool allAllowRefinement =
llvm::all_of(op.getResult().getUsers(), allowsTypeRefinement);
if (!allAllowRefinement)
Expand Down
Loading

0 comments on commit d6b9709

Please sign in to comment.