
Commit

Merge pull request #169 from leanderweber/master
Added an LRP analyzer class to allow applying LRP_flat to all (parametrized) layers up to a given index
enryH authored Aug 27, 2020
2 parents 49f83af + 0bd3388 commit 7fa298b
Showing 3 changed files with 34 additions and 0 deletions.
3 changes: 3 additions & 0 deletions innvestigate/analyzer/__init__.py
@@ -40,6 +40,7 @@
from .relevance_based.relevance_analyzer import LRPSequentialPresetB
from .relevance_based.relevance_analyzer import LRPSequentialPresetAFlat
from .relevance_based.relevance_analyzer import LRPSequentialPresetBFlat
from .relevance_based.relevance_analyzer import LRPSequentialPresetBFlatUntilIdx
from .deeptaylor import DeepTaylor
from .deeptaylor import BoundedDeepTaylor
from .wrapper import WrapperBase
@@ -101,6 +102,8 @@
"lrp.sequential_preset_b": LRPSequentialPresetB,
"lrp.sequential_preset_a_flat": LRPSequentialPresetAFlat,
"lrp.sequential_preset_b_flat": LRPSequentialPresetBFlat,
"lrp.sequential_preset_b_flat_until_idx": LRPSequentialPresetBFlatUntilIdx,


# Deep Taylor
"deep_taylor": DeepTaylor,
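A minimal usage sketch (not part of the diff), assuming a Keras model `model` with its softmax removed and a numpy input batch `x`; the registry entry above exposes the new preset through innvestigate.create_analyzer, which forwards keyword arguments to the analyzer's constructor:

import innvestigate

# Build the analyzer via the key registered in this diff; until_layer_idx=3 is an
# assumed example value meaning: apply the Flat rule to parametrized layers 0..3.
analyzer = innvestigate.create_analyzer(
    "lrp.sequential_preset_b_flat_until_idx",
    model,
    until_layer_idx=3,
)
relevance = analyzer.analyze(x)  # relevance map with the same shape as x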
27 changes: 27 additions & 0 deletions innvestigate/analyzer/relevance_based/relevance_analyzer.py
@@ -61,6 +61,7 @@

"LRPSequentialPresetAFlat",
"LRPSequentialPresetBFlat",
"LRPSequentialPresetBFlatUntilIdx",
]


@@ -343,6 +344,9 @@ class LRP(base.ReverseAnalyzerBase):
    def __init__(self, model, *args, **kwargs):
        rule = kwargs.pop("rule", None)
        input_layer_rule = kwargs.pop("input_layer_rule", None)
        until_layer_idx = kwargs.pop("until_layer_idx", None)
        until_layer_rule = kwargs.pop("until_layer_rule", None)

        bn_layer_rule = kwargs.pop("bn_layer_rule", None)
        bn_layer_fuse_mode = kwargs.pop("bn_layer_fuse_mode", "one_linear")
        assert bn_layer_fuse_mode in ["one_linear", "two_linear"]
@@ -367,6 +371,9 @@ def __init__(self, model, *args, **kwargs):
        else:
            self._rule = rule
        self._input_layer_rule = input_layer_rule
        self._until_layer_rule = until_layer_rule
        self._until_layer_idx = until_layer_idx

        self._bn_layer_rule = bn_layer_rule
        self._bn_layer_fuse_mode = bn_layer_fuse_mode

@@ -387,6 +394,12 @@ def __init__(self, model, *args, **kwargs):
            use_conditions = True
            rules = rule

        # Apply self._until_layer_rule to layers 0 .. self._until_layer_idx (inclusive)
        if self._until_layer_rule is not None and self._until_layer_idx is not None:
            for i in range(self._until_layer_idx + 1):
                rules.insert(0,
                             (lambda layer, foo, bound_i=i: kchecks.is_layer_at_idx(layer, bound_i),
                              self._until_layer_rule))

        # create a BoundedRule for input layer handling from given tuple
        if self._input_layer_rule is not None:
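An aside on the lambda above (illustration only, not part of the diff): the default argument bound_i=i freezes the current loop index when each condition is created, so every inserted (condition, rule) pair tests its own layer index rather than the loop's final value.

# Default arguments capture i at definition time ...
conds = [(lambda idx, bound_i=i: idx == bound_i) for i in range(3)]
print([c(1) for c in conds])   # [False, True, False]

# ... whereas plain closures all see the final value of i.
late = [(lambda idx: idx == i) for i in range(3)]
print([c(1) for c in late])    # [False, False, False]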
@@ -840,3 +853,17 @@ def __init__(self, model, *args, **kwargs):
*args,
input_layer_rule="Flat",
**kwargs)

class LRPSequentialPresetBFlatUntilIdx(LRPSequentialPresetBFlat):
    """
    Special LRP configuration for ConvNets. Applies the Flat rule from layer
    until_layer_idx (inclusive) down to the input layer. Weightless layers are
    ignored when counting the index for now.
    """

    def __init__(self, model, *args, **kwargs):
        layer_flat_idx = kwargs.pop("until_layer_idx", None)
        super(LRPSequentialPresetBFlatUntilIdx, self).__init__(
            model,
            *args,
            until_layer_idx=layer_flat_idx,
            until_layer_rule=rrule.FlatRule,
            **kwargs)
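Equivalently (sketch only, same assumptions as above), the preset can be imported and instantiated directly; it simply forwards the index together with until_layer_rule=rrule.FlatRule to the LRP base class:

from innvestigate.analyzer import LRPSequentialPresetBFlatUntilIdx

analyzer = LRPSequentialPresetBFlatUntilIdx(model, until_layer_idx=2)  # Flat rule on layers 0..2
relevance = analyzer.analyze(x)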
4 changes: 4 additions & 0 deletions innvestigate/utils/keras/checks.py
@@ -429,3 +429,7 @@ def is_input_layer(layer, ignore_reshape_layers=True):
        return True
    else:
        return False


def is_layer_at_idx(layer, index, ignore_reshape_layers=True):
    """Checks whether `layer` is the layer at position `index`, by repeatedly applying is_input_layer()."""
    kgraph = get_kgraph()

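As committed, is_layer_at_idx only fetches the graph helper and does not yet return a result. A self-contained sketch of the intended check (hypothetical, not the library's code: instead of repeatedly applying is_input_layer() as the docstring suggests, it takes the model explicitly and enumerates its parametrized layers, matching the "weightless layers are ignored" counting described in the new preset):

def is_layer_at_idx_sketch(model, layer, index):
    # Hypothetical helper: count only layers that carry weights and check
    # whether `layer` sits at position `index` in that filtered list.
    parametrized = [l for l in model.layers if l.get_weights()]
    return index < len(parametrized) and parametrized[index] is layer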