Automatic Stochastic depth on residual blocks
dskhudia committed Jul 6, 2022
1 parent 56afa81 commit e9b7418
Showing 2 changed files with 164 additions and 8 deletions.
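The new `apply_stochastic_residual` utility converts the residual blocks of an FX-traced model into stochastic-depth blocks. A minimal usage sketch, not part of the commit itself; the resnet18 model and drop rate are illustrative and mirror the new test below:

import torch
from torch.fx import symbolic_trace
from torchvision import models

from composer.utils.fx_utils import apply_stochastic_residual

# Symbolically trace the model so residual adds appear as nodes in the FX graph.
model = models.resnet18()
traced = symbolic_trace(model)

# Wrap each detected residual block in a BlockStochasticModule.
stochastic_model, num_blocks = apply_stochastic_residual(traced, drop_rate=0.2)

stochastic_model.train()  # blocks are dropped stochastically only in training mode
out = stochastic_model(torch.randn(1, 3, 224, 224))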
145 changes: 138 additions & 7 deletions composer/utils/fx_utils.py
@@ -8,18 +8,20 @@

import logging
import operator
import re
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.fx import Node
from torch.fx.graph_module import GraphModule
from torch.fx.passes.split_utils import split_by_tags

from composer.utils import ensure_tuple

log = logging.getLogger(__name__)

__all__ = ['count_op_instances', 'replace_op', 'fuse_parallel_linears', 'apply_stochastic_residual']


def count_op_instances(gm: GraphModule, ops: Union[Callable, str, List[Union[Callable, str]]]) -> int:
@@ -111,6 +113,40 @@ def replace_op(gm: GraphModule, src_ops: Union[Callable, str, List[Union[Callabl
return gm


class BlockStochasticModule(nn.Module):
"""A convenience class that stochastically executes the provided non-residual path.
Args:
lhs (GraphModule): Operators in the non-residual path of a residual block.
rhs (GraphModule | None): Operators, if any, in the residual path of a residual block.
drop_rate: The base probability of dropping this layer. Must be between 0.0 (inclusive) and 1.0 (inclusive).
Returns:
BlockStochasticModule: An instance of :class:`.BlockStochasticModule`.
"""

def __init__(self, lhs: GraphModule, rhs: Optional[GraphModule] = None, drop_rate: float = 0.2):
super().__init__()
self.drop_rate = torch.tensor(drop_rate)
self.lhs = lhs
self.rhs = rhs

def forward(self, x):
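        # Execute the non-residual path always at inference time; during training,
        # keep it with probability (1 - drop_rate).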
sample = (not self.training) or bool(torch.bernoulli(1 - self.drop_rate))
        # lhs is the non-residual path; rhs is the residual (shortcut) path and
        # may or may not contain any operations
        rhs_result = x
if self.rhs:
rhs_result = self.rhs(x)

if sample:
lhs_result = self.lhs(x)
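            # At inference the non-residual path always runs, so scale its output
            # by the keep probability (1 - drop_rate) to match the training-time
            # expectation.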
if not self.training:
lhs_result = lhs_result * (1 - self.drop_rate)
rhs_result = torch.add(lhs_result, rhs_result)
return rhs_result


def detect_residual_pattern(gm: GraphModule):
"""Search and replace the pattern with another.
@@ -123,16 +159,111 @@ def detect_residual_pattern(gm: GraphModule):
raise NotImplementedError('detect_residual_pattern is currently not implemented.')


def _get_ancestors(node: Node) -> List[Node]:
ancestorNodes = []
while node.op != 'placeholder':
ancestorNodes.append(node)
node = node.all_input_nodes[0]
return ancestorNodes


def _get_residual_path(nodeLHS: Node, nodeRHS: Node) -> Tuple[List[Node], List[Node]]:
"""Walk backwards from nodeLHS and nodeRSH to the root and construct lists of their parents.
Arguments:
gm (GraphModule): The source FX-traced graph.
nodeLHS (Node): left-hand side node for a binary operator
nodeRHS (Node): right-hand side node for a binary operator
Returns:
GraphModule: Modified GraphModule.
(lhsAncestors, rhsAncestors): Two lists of nodes containing ancestors for ``nodeLHS`` and ``nodeRHS`` with
their common ancestors removed.
"""
lhsAncestors = _get_ancestors(nodeLHS)
rhsAncestors = _get_ancestors(nodeRHS)

# Iterate from back and eliminate common nodes
while lhsAncestors and rhsAncestors and lhsAncestors[-1] == rhsAncestors[-1]: # type: ignore
del lhsAncestors[-1]
del rhsAncestors[-1]
lhsAncestors.reverse()
rhsAncestors.reverse()
return lhsAncestors, rhsAncestors


def _attach_tag(nodes: List[Node], tag: str):
"""Attach tag to the given nodes for the splitter."""
for node in nodes:
node.tag = tag # type: ignore[attr-defined]


def _tag_nodes(gm: GraphModule) -> Tuple[List[str], int]:
"""Tag nodes for splitting."""
    # All nodes that are not part of a residual block are tagged with 'mainN_{count}'.
    # split_by_tags requires a tag on every node, and an earlier tag can be repeated
    # for later nodes.
count = 0
all_tags = []
# In this pass over all nodes, we just tag them
for node in gm.graph.nodes:
default_tag = f'mainN_{count}'
node.tag = default_tag
if default_tag not in all_tags:
all_tags.append(default_tag)
if node.op == 'call_function' and node.target in [torch.add, operator.add]:
assert len(node.all_input_nodes) == 2
node0, node1 = node.all_input_nodes[0], node.all_input_nodes[1]
lhs_nodes, rhs_nodes = _get_residual_path(node0, node1)
if lhs_nodes or rhs_nodes:
if len(lhs_nodes):
_attach_tag(lhs_nodes, f'non_res_{count}')
all_tags.append(f'non_res_{count}')
if len(rhs_nodes):
_attach_tag(rhs_nodes, f'residual_{count}')
all_tags.append(f'residual_{count}')
add_tag = f'addN_{count}'
if add_tag not in all_tags:
all_tags.append(add_tag)
node.tag = add_tag
count += 1
return all_tags, count
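# Example (illustrative, not part of the tagging pass itself): for a torchvision
# BasicBlock with no downsample path,
#   x -> conv1 -> bn1 -> relu -> conv2 -> bn2 -> add(bn2, x)
# the pass above tags conv1..bn2 as 'non_res_0' and the add as 'addN_0', while
# surrounding nodes keep their 'mainN_{count}' tags; when a downsample path exists,
# its nodes are tagged 'residual_0'. split_by_tags then lifts each tag into its own
# submodule, and apply_stochastic_residual below wraps each 'non_res_*'/'residual_*'
# pair in a BlockStochasticModule.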


def apply_stochastic_residual(gm: GraphModule, drop_rate: float = 0.2) -> Tuple[GraphModule, int]:
"""Detect and replace residual pattern with their stochastic equivalent.
Arguments:
gm (GraphModule): The source FX-traced graph. It can be the whole model symbolically traced.
Returns:
GraphModule: Modified GraphModule that has stochastic residual connections.
"""
raise NotImplementedError('replace_residual_with_stochastic is currently not implemented.')
assert isinstance(gm, GraphModule), 'Input to apply_stochastic_residual should be an instance of GraphModule'
all_tags, count = _tag_nodes(gm)
split_gm = split_by_tags(gm, all_tags)
pattern = re.compile(r'non_res_(\d+)|residual_(\d+)')
for node in split_gm.graph.nodes:
if node.op == 'call_module':
matches = pattern.match(node.target)
if matches:
idx = int(matches[1]) if matches[1] else int(matches[2])
lhs_submod = getattr(split_gm, f'non_res_{idx}')
rhs_submod = getattr(split_gm, f'residual_{idx}', None)
bl_st_instance = BlockStochasticModule(lhs_submod, rhs_submod, drop_rate)
split_gm.add_submodule(f'bl_st_{idx}', bl_st_instance) # type: ignore
insert_node = node.prev
add_node = node.next
if rhs_submod:
add_node = node.next.next
with split_gm.graph.inserting_after(insert_node):
new_node = split_gm.graph.call_module(f'bl_st_{idx}', args=(insert_node,)) # type: ignore
add_node.replace_all_uses_with(new_node)
split_gm.graph.erase_node(add_node)
if rhs_submod:
split_gm.graph.erase_node(node.next)
split_gm.graph.erase_node(node)
split_gm.graph.lint()
split_gm.recompile()
return split_gm, count


def _can_linears_be_fused(linear_nodes: List[Node], all_modules: Mapping[str, nn.Module]) -> bool:
27 changes: 26 additions & 1 deletion tests/utils/test_fx_utils.py
@@ -8,8 +8,9 @@
from torch import nn
from torch.fx import symbolic_trace
from torch.fx.graph_module import GraphModule
from torchvision import models

from composer.utils.fx_utils import apply_stochastic_residual, count_op_instances, fuse_parallel_linears, replace_op


class MyTestModel(nn.Module):
@@ -153,3 +154,27 @@ def test_fuse_parallel_linears(model_cls, before_count, after_count):
fuse_parallel_linears(traced)

assert count_op_instances(traced, nn.Linear) == after_count


@pytest.mark.parametrize(
'model_cls, block_count',
[(models.resnet18, 8)],
)
@pytest.mark.filterwarnings(
r'ignore:Attempted to insert a call_module Node with no underlying reference in the owning GraphModule!.*:UserWarning'
)
def test_stochastic_depth(model_cls, block_count):
model = model_cls()
traced = symbolic_trace(model)

assert isinstance(traced, GraphModule)

inp = torch.randn(1, 3, 224, 224)

traced_st_depth_no_drop, residual_count = apply_stochastic_residual(traced, 0.0)

out_traced = traced(inp)
out_traced_st_depth_no_drop = traced_st_depth_no_drop(inp)
assert torch.allclose(out_traced,
out_traced_st_depth_no_drop), 'mismatch in outputs with 0 drop rate for stochastic modules'
assert residual_count == block_count
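A follow-on sketch, not part of this commit and reusing the imports from the test file above: with a nonzero drop rate the transformed model drops blocks only in training mode, while eval-mode forward passes stay deterministic and rescale the non-residual path by the keep probability. The resnet18 model and drop rate are illustrative.

def check_stochastic_behavior():
    # Illustrative only; mirrors the setup of test_stochastic_depth above.
    traced = symbolic_trace(models.resnet18())
    stochastic, _ = apply_stochastic_residual(traced, drop_rate=0.5)
    inp = torch.randn(1, 3, 224, 224)

    stochastic.eval()
    with torch.no_grad():
        # Deterministic at eval: every non-residual path runs and is scaled by (1 - drop_rate).
        assert torch.allclose(stochastic(inp), stochastic(inp))

    stochastic.train()
    # In training mode each BlockStochasticModule independently samples whether to
    # run its non-residual path, so repeated forward passes generally differ.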
