From e9b7418350b4450d118b88954bfd50efa84f150a Mon Sep 17 00:00:00 2001
From: Daya Khudia
Date: Fri, 24 Jun 2022 16:34:33 +0000
Subject: [PATCH] Automatic Stochastic depth on residual blocks

---
 composer/utils/fx_utils.py   | 145 +++++++++++++++++++++++++++++++++--
 tests/utils/test_fx_utils.py |  27 ++++++-
 2 files changed, 164 insertions(+), 8 deletions(-)

diff --git a/composer/utils/fx_utils.py b/composer/utils/fx_utils.py
index 85faac80268..81d1b65c322 100644
--- a/composer/utils/fx_utils.py
+++ b/composer/utils/fx_utils.py
@@ -8,18 +8,20 @@
 
 import logging
 import operator
-from typing import Any, Callable, Dict, List, Mapping, Tuple, Union
+import re
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.fx import Node
 from torch.fx.graph_module import GraphModule
+from torch.fx.passes.split_utils import split_by_tags
 
 from composer.utils import ensure_tuple
 
 log = logging.getLogger(__name__)
 
-__all__ = ['count_op_instances', 'replace_op', 'fuse_parallel_linears']
+__all__ = ['count_op_instances', 'replace_op', 'fuse_parallel_linears', 'apply_stochastic_residual']
 
 
 def count_op_instances(gm: GraphModule, ops: Union[Callable, str, List[Union[Callable, str]]]) -> int:
@@ -111,6 +113,40 @@ def replace_op(gm: GraphModule, src_ops: Union[Callable, str, List[Union[Callabl
     return gm
 
 
+class BlockStochasticModule(nn.Module):
+    """A convenience class that stochastically executes the provided non-residual path.
+
+    Args:
+        lhs (GraphModule): Operators in the non-residual path of a residual block.
+        rhs (GraphModule | None): Operators, if any, in the residual path of a residual block.
+        drop_rate (float): The base probability of dropping this layer. Must be between 0.0 (inclusive) and 1.0 (inclusive).
+
+    Returns:
+        BlockStochasticModule: An instance of :class:`.BlockStochasticModule`.
+    """
+
+    def __init__(self, lhs: GraphModule, rhs: Optional[GraphModule] = None, drop_rate: float = 0.2):
+        super().__init__()
+        self.drop_rate = torch.tensor(drop_rate)
+        self.lhs = lhs
+        self.rhs = rhs
+
+    def forward(self, x):
+        sample = (not self.training) or bool(torch.bernoulli(1 - self.drop_rate))
+        # lhs is the non-residual (transform) path; rhs is the residual (shortcut) path
+        rhs_result = x
+        # the rhs side may or may not contain any operations
+        if self.rhs:
+            rhs_result = self.rhs(x)
+
+        if sample:
+            lhs_result = self.lhs(x)
+            if not self.training:
+                lhs_result = lhs_result * (1 - self.drop_rate)
+            rhs_result = torch.add(lhs_result, rhs_result)
+        return rhs_result
+
+
 def detect_residual_pattern(gm: GraphModule):
     """Search and replace the pattern with another.
 
@@ -123,16 +159,111 @@ def replace_residual_with_stochastic(gm: GraphModule):
     raise NotImplementedError('detect_residual_pattern is currently not implemented.')
 
 
-def replace_residual_with_stochastic(gm: GraphModule):
-    """Replaces residual pattern with their stoachstic equivalent.
+def _get_ancestors(node: Node) -> List[Node]:
+    ancestorNodes = []
+    while node.op != 'placeholder':
+        ancestorNodes.append(node)
+        node = node.all_input_nodes[0]
+    return ancestorNodes
+
+
+def _get_residual_path(nodeLHS: Node, nodeRHS: Node) -> Tuple[List[Node], List[Node]]:
+    """Walk backwards from nodeLHS and nodeRHS to the root and construct lists of their ancestors.
 
     Arguments:
-        gm (GraphModule): The source FX-traced graph.
+        nodeLHS (Node): left-hand side node for a binary operator
+        nodeRHS (Node): right-hand side node for a binary operator
 
     Returns:
-        GraphModule: Modified GraphModule.
+        (lhsAncestors, rhsAncestors): Two lists of nodes containing ancestors for ``nodeLHS`` and ``nodeRHS`` with
+            their common ancestors removed.
+    """
+    lhsAncestors = _get_ancestors(nodeLHS)
+    rhsAncestors = _get_ancestors(nodeRHS)
+
+    # Iterate from the back and eliminate common nodes
+    while lhsAncestors and rhsAncestors and lhsAncestors[-1] == rhsAncestors[-1]:  # type: ignore
+        del lhsAncestors[-1]
+        del rhsAncestors[-1]
+    lhsAncestors.reverse()
+    rhsAncestors.reverse()
+    return lhsAncestors, rhsAncestors
+
+
+def _attach_tag(nodes: List[Node], tag: str):
+    """Attach ``tag`` to each of the given nodes for the splitter."""
+    for node in nodes:
+        node.tag = tag  # type: ignore[attr-defined]
+
+
+def _tag_nodes(gm: GraphModule) -> Tuple[List[str], int]:
+    """Tag nodes for splitting."""
+    # All nodes that are not a part of a residual block are tagged with "mainN_<count>".
+    # A tag is required for all nodes by split_by_tags.
+    # Also, an earlier tag can be repeated for later nodes.
+    count = 0
+    all_tags = []
+    # In this pass over all nodes, we just tag them
+    for node in gm.graph.nodes:
+        default_tag = f'mainN_{count}'
+        node.tag = default_tag
+        if default_tag not in all_tags:
+            all_tags.append(default_tag)
+        if node.op == 'call_function' and node.target in [torch.add, operator.add]:
+            assert len(node.all_input_nodes) == 2
+            node0, node1 = node.all_input_nodes[0], node.all_input_nodes[1]
+            lhs_nodes, rhs_nodes = _get_residual_path(node0, node1)
+            if lhs_nodes or rhs_nodes:
+                if len(lhs_nodes):
+                    _attach_tag(lhs_nodes, f'non_res_{count}')
+                    all_tags.append(f'non_res_{count}')
+                if len(rhs_nodes):
+                    _attach_tag(rhs_nodes, f'residual_{count}')
+                    all_tags.append(f'residual_{count}')
+                add_tag = f'addN_{count}'
+                if add_tag not in all_tags:
+                    all_tags.append(add_tag)
+                node.tag = add_tag
+                count += 1
+    return all_tags, count
+
+
+def apply_stochastic_residual(gm: GraphModule, drop_rate: float = 0.2) -> Tuple[GraphModule, int]:
+    """Detect and replace residual patterns with their stochastic equivalents.
+
+    Arguments:
+        gm (GraphModule): The source FX-traced graph. This can be the whole symbolically traced model.
+
+    Returns:
+        Tuple[GraphModule, int]: Modified GraphModule with stochastic residual connections and the number of residual blocks replaced.
""" - raise NotImplementedError('replace_residual_with_stochastic is currently not implemented.') + assert isinstance(gm, GraphModule), 'Input to apply_stochastic_residual should be an instance of GraphModule' + all_tags, count = _tag_nodes(gm) + split_gm = split_by_tags(gm, all_tags) + pattern = re.compile(r'non_res_(\d+)|residual_(\d+)') + for node in split_gm.graph.nodes: + if node.op == 'call_module': + matches = pattern.match(node.target) + if matches: + idx = int(matches[1]) if matches[1] else int(matches[2]) + lhs_submod = getattr(split_gm, f'non_res_{idx}') + rhs_submod = getattr(split_gm, f'residual_{idx}', None) + bl_st_instance = BlockStochasticModule(lhs_submod, rhs_submod, drop_rate) + split_gm.add_submodule(f'bl_st_{idx}', bl_st_instance) # type: ignore + insert_node = node.prev + add_node = node.next + if rhs_submod: + add_node = node.next.next + with split_gm.graph.inserting_after(insert_node): + new_node = split_gm.graph.call_module(f'bl_st_{idx}', args=(insert_node,)) # type: ignore + add_node.replace_all_uses_with(new_node) + split_gm.graph.erase_node(add_node) + if rhs_submod: + split_gm.graph.erase_node(node.next) + split_gm.graph.erase_node(node) + split_gm.graph.lint() + split_gm.recompile() + return split_gm, count def _can_linears_be_fused(linear_nodes: List[Node], all_modules: Mapping[str, nn.Module]) -> bool: diff --git a/tests/utils/test_fx_utils.py b/tests/utils/test_fx_utils.py index 136bc6dfa06..594c035c136 100644 --- a/tests/utils/test_fx_utils.py +++ b/tests/utils/test_fx_utils.py @@ -8,8 +8,9 @@ from torch import nn from torch.fx import symbolic_trace from torch.fx.graph_module import GraphModule +from torchvision import models -from composer.utils.fx_utils import count_op_instances, fuse_parallel_linears, replace_op +from composer.utils.fx_utils import apply_stochastic_residual, count_op_instances, fuse_parallel_linears, replace_op class MyTestModel(nn.Module): @@ -153,3 +154,27 @@ def test_fuse_parallel_linears(model_cls, before_count, after_count): fuse_parallel_linears(traced) assert count_op_instances(traced, nn.Linear) == after_count + + +@pytest.mark.parametrize( + 'model_cls, block_count', + [(models.resnet18, 8)], +) +@pytest.mark.filterwarnings( + r'ignore:Attempted to insert a call_module Node with no underlying reference in the owning GraphModule!.*:UserWarning' +) +def test_stochastic_depth(model_cls, block_count): + model = model_cls() + traced = symbolic_trace(model) + + assert isinstance(traced, GraphModule) + + inp = torch.randn(1, 3, 224, 224) + + traced_st_depth_no_drop, residual_count = apply_stochastic_residual(traced, 0.0) + + out_traced = traced(inp) + out_traced_st_depth_no_drop = traced_st_depth_no_drop(inp) + assert torch.allclose(out_traced, + out_traced_st_depth_no_drop), 'mismatch in outputs with 0 drop rate for stochastic modules' + assert residual_count == block_count