Merge pull request #212 (#115)
embedding layers
enryH authored Aug 24, 2020
2 parents 1f3ed62 + 15a9d19 commit 49f83af
Showing 3 changed files with 46 additions and 11 deletions.
21 changes: 21 additions & 0 deletions examples/embedding_minimal_example.py
@@ -0,0 +1,21 @@
from keras import Sequential
from keras.layers import Dense, Conv1D, Embedding, GlobalMaxPooling1D
import numpy as np
import innvestigate

# Build a minimal model: an Embedding layer followed by a small 1D-conv classifier.
model = Sequential()
model.add(Embedding(input_dim=219, output_dim=8))
model.add(Conv1D(filters=64, kernel_size=8, padding='valid', activation='relu'))
model.add(GlobalMaxPooling1D())
model.add(Dense(16, activation='relu'))
model.add(Dense(2, activation=None))

# Sanity check: run a forward pass on a batch of random token indices.
model.predict(np.random.randint(1, 219, (1, 100)))  # e.g. [[0.04913538 0.04234646]]

analyzer = innvestigate.create_analyzer('lrp.epsilon', model,
                                        neuron_selection_mode='max_activation',
                                        epsilon=1)
a = analyzer.analyze(np.random.randint(1, 219, (1, 100)))
print(a[0], a[0].shape)
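With the embedding mapping introduced in this commit, relevance is pooled over the embedding axis, so the analysis assigns one score per input token. A minimal sanity check one could append to the example (a hedged sketch, not part of the commit; it assumes `analyze` returns a single array for a single-input model):

tokens = np.random.randint(1, 219, (1, 100))  # [batch_size, sequence_length]
a = analyzer.analyze(tokens)
# One relevance value per token index: shape (1, 100) matches the input.
assert a.shape == tokens.shape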
33 changes: 22 additions & 11 deletions innvestigate/analyzer/relevance_based/relevance_analyzer.py
@@ -185,11 +185,27 @@ def _create_analysis(self, model, stop_analysis_at_tensors=[]):
}


class EmbeddingReverseLayer(kgraph.ReverseMappingBase):
    def __init__(self, layer, state):
        # TODO: implement rule support.
        return

    def apply(self, Xs, Ys, Rs, reverse_state):
        # The embedding layer maps each (integer) input index to a vector.
        # In the relevance backward pass, the embedding layer therefore
        # receives relevances Rs shaped like those output vectors.
        # Because each input index maps 1:1 to one output vector, the
        # backward pass can be realized by pooling the relevances over
        # the embedding axis.

        # Relevances arrive shaped [batch_size, sequence_length, embedding_dims].
        pool_relevance = keras.layers.Lambda(lambda x: keras.backend.sum(x, axis=-1))
        return [pool_relevance(r) for r in Rs]


class BatchNormalizationReverseLayer(kgraph.ReverseMappingBase):
    """Special BN handler that applies the Z-Rule"""

    def __init__(self, layer, state):
        config = layer.get_config()

        self._center = config['center']
@@ -209,8 +225,6 @@ def __init__(self, layer, state):
        # to BatchNormEpsilonRule. Not pretty, but should work.

    def apply(self, Xs, Ys, Rs, reverse_state):
        input_shape = [K.int_shape(x) for x in Xs]
        if len(input_shape) != 1:
            # Extend the lambda layers below to handle multiple parameters.
@@ -254,7 +268,6 @@ class AddReverseLayer(kgraph.ReverseMappingBase):
"""Special Add layer handler that applies the Z-Rule"""

def __init__(self, layer, state):
##print("in AddReverseLayer.init:", layer.__class__.__name__,"-> Dedicated ReverseLayer class" ) #debug
self._layer_wo_act = kgraph.copy_layer_wo_activation(layer,
name_template="reversed_kernel_%s")

@@ -285,7 +298,6 @@ class AveragePoolingReverseLayer(kgraph.ReverseMappingBase):
"""Special AveragePooling handler that applies the Z-Rule"""

def __init__(self, layer, state):
##print("in AveragePoolingRerseLayer.init:", layer.__class__.__name__,"-> Dedicated ReverseLayer class" ) #debug
self._layer_wo_act = kgraph.copy_layer_wo_activation(layer,
name_template="reversed_kernel_%s")

@@ -404,16 +416,13 @@ def __init__(self, *args, **kwargs):
        super(LRP, self).__init__(model, *args, **kwargs)

    def create_rule_mapping(self, layer, reverse_state):
        rule_class = None
        if self._rules_use_conditions is True:
            for condition, rule in self._rules:
                if condition(layer, reverse_state):
                    rule_class = rule
                    break
        else:
            rule_class = self._rules.pop()

        if rule_class is None:
@@ -467,21 +476,24 @@ def _create_analysis(self, *args, **kwargs):
            AddReverseLayer,
            name="lrp_add_layer_mapping",
        )
        self._add_conditional_reverse_mapping(
            kchecks.is_embedding_layer,
            EmbeddingReverseLayer,
            name="lrp_embedding_mapping",
        )

        # FINALIZED constructor.
        return super(LRP, self)._create_analysis(*args, **kwargs)


    def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state):
        # default_return_layers = [keras.layers.Activation]  # TODO: extend
        if (len(Xs) == len(Ys) and
                isinstance(reverse_state['layer'], (keras.layers.Activation,)) and
                all([K.int_shape(x) == K.int_shape(y) for x, y in zip(Xs, Ys)])):
            # Expect Xs and Ys to have the same shapes.
            # There is no mixing of relevances as there is no kernel,
            # therefore we pass them on as they are.
            return reversed_Ys
        else:
            # This branch covers:
@@ -491,7 +503,6 @@ def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state):
            # Reshape
            # Concatenate
            # Cropping
            return self._gradient_reverse_mapping(
                Xs, Ys, reversed_Ys, reverse_state)

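To see what `EmbeddingReverseLayer.apply` does, here is a standalone NumPy sketch of the pooling step (illustrative only, not part of the commit): summing relevances over the last axis collapses [batch_size, sequence_length, embedding_dims] into one score per token.

import numpy as np

Rs = np.random.rand(1, 100, 8)  # [batch_size, sequence_length, embedding_dims]
pooled = Rs.sum(axis=-1)        # [batch_size, sequence_length]
print(pooled.shape)             # (1, 100): one relevance score per input token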
3 changes: 3 additions & 0 deletions innvestigate/utils/keras/checks.py
@@ -47,6 +47,7 @@ def get_kgraph():
"is_max_pooling",
"is_input_layer",
"is_batch_normalization_layer",
"is_embedding_layer"
]


@@ -277,6 +278,8 @@ def is_conv_layer(layer, *args, **kwargs):
    )
    return isinstance(layer, CONV_LAYERS)


def is_embedding_layer(layer, *args, **kwargs):
    """Checks if layer is an embedding layer."""
    return isinstance(layer, keras.layers.Embedding)


def is_batch_normalization_layer(layer, *args, **kwargs):
    """Checks if layer is a batchnorm layer."""
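A quick illustration of the new check (hypothetical usage, not part of the commit):

from keras.layers import Dense, Embedding
from innvestigate.utils.keras.checks import is_embedding_layer

print(is_embedding_layer(Embedding(input_dim=10, output_dim=4)))  # True
print(is_embedding_layer(Dense(4)))                               # False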
