Skip to content

Commit

Permalink
Fix Perturbate on RGB images
Browse files Browse the repository at this point in the history
Closes #299
  • Loading branch information
adrhill committed Jan 31, 2023
1 parent 5cf7a3e commit 99149e2
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 48 deletions.
30 changes: 15 additions & 15 deletions examples/mnist_perturbation.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ select = [
]
ignore = [
"E741", # Ignore warning "ambiguous variable name 'l'" as it is useful to iterate over layers
"B905", # zip's `strict` parameter was only added in Python 3.10
]

[tool.pylint."messages control"]
Expand Down
64 changes: 31 additions & 33 deletions src/innvestigate/tools/perturbate.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ def __init__(
self.value_range = value_range

@staticmethod
def compute_perturbation_mask(ranks, num_perturbated_regions):
def _compute_perturbation_mask(ranks, num_perturbated_regions):
perturbation_mask_regions = ranks <= num_perturbated_regions - 1
return perturbation_mask_regions

@staticmethod
def compute_region_ordering(aggregated_regions):
def _compute_region_ordering(aggregated_regions):
# 0 means highest scoring region
new_shape = tuple(aggregated_regions.shape[:2]) + (-1,)
order = np.argsort(-aggregated_regions.reshape(new_shape), axis=-1)
ranks = order.argsort().reshape(aggregated_regions.shape)
return ranks

def expand_regions_to_pixels(self, regions):
def _expand_regions_to_pixels(self, regions):
# Resize to pixels (repeat values).
# (n, c, h_aggregated_region, w_aggregated_region) ->
# (n, c, h_aggregated_region, h_region, w_aggregated_region, w_region)
Expand All @@ -123,7 +123,7 @@ def expand_regions_to_pixels(self, regions):

return region_pixels

def reshape_region_pixels(self, region_pixels, target_shape):
def _reshape_region_pixels(self, region_pixels, target_shape):
# Reshape to output shape
pixels = region_pixels.reshape(target_shape)
assert (
Expand All @@ -134,7 +134,7 @@ def reshape_region_pixels(self, region_pixels, target_shape):
)
return pixels

def pad(self, analysis):
def _pad(self, analysis):
pad_shape = self.region_shape - np.array(analysis.shape[2:]) % self.region_shape
assert np.all(pad_shape < self.region_shape)

Expand All @@ -153,7 +153,7 @@ def pad(self, analysis):
), analysis.shape[2:]
return analysis, pad_shape_before

def reshape_to_regions(self, analysis):
def _reshape_to_regions(self, analysis):
aggregated_shape = tuple(
(np.array(analysis.shape[2:]) / self.region_shape).astype(int)
)
Expand All @@ -169,40 +169,37 @@ def reshape_to_regions(self, analysis):
)
return regions

def aggregate_regions(self, analysis):
regions = self.reshape_to_regions(analysis)
def _aggregate_regions(self, analysis):
# Make sure color channel has been reduced:
assert analysis.shape[1] == 1, analysis.shape

regions = self._reshape_to_regions(analysis)
aggregated_regions = self.aggregation_function(regions, axis=(3, 5))
return aggregated_regions

def perturbate_regions(self, x, perturbation_mask_regions):
def _perturbate_regions(self, x, perturbation_mask_regions):
# Perturbate every region in tensor.
# A single region (at region_x, region_y in sample)
# should be in mask[sample, channel, region_x, :, region_y, :]

x_perturbated = self.reshape_to_regions(x)
n_channels = x.shape[1]
xp = self._reshape_to_regions(x) # perturbed output
for sample_idx, channel_idx, region_row, region_col in np.ndindex(
perturbation_mask_regions.shape
):
region = x_perturbated[
sample_idx, channel_idx, region_row, :, region_col, :
]
region_mask = perturbation_mask_regions[
sample_idx, channel_idx, region_row, region_col
]
if region_mask:
x_perturbated[
sample_idx, channel_idx, region_row, :, region_col, :
] = self.perturbation_function(region)

if self.value_range is not None:
np.clip(
x_perturbated,
self.value_range[0],
self.value_range[1],
x_perturbated,
)
x_perturbated = self.reshape_region_pixels(x_perturbated, x.shape)
return x_perturbated
for c in range(n_channels):
region = xp[sample_idx, c, region_row, :, region_col, :]
xp[
sample_idx, c, region_row, :, region_col, :
] = self.perturbation_function(region)

if self.value_range is not None:
np.clip(xp, self.value_range[0], self.value_range[1], xp)
xp = self._reshape_region_pixels(xp, x.shape)
return xp

def perturbate_on_batch(self, x, analysis):
"""
Expand All @@ -213,6 +210,7 @@ def perturbate_on_batch(self, x, analysis):
:return: Batch of perturbated images
:rtype: numpy.ndarray
"""
# Internally use channels_first (BCHW) only:
if kbackend.image_data_format() == "channels_last":
x = np.moveaxis(x, 3, 1)
analysis = np.moveaxis(analysis, 3, 1)
Expand All @@ -226,18 +224,18 @@ def perturbate_on_batch(self, x, analysis):

padding = not np.all(np.array(analysis.shape[2:]) % self.region_shape == 0)
if padding:
analysis, _pad_shape_before_analysis = self.pad(analysis)
x, pad_shape_before_x = self.pad(x)
aggregated_regions = self.aggregate_regions(analysis)
analysis, _pad_shape_before_analysis = self._pad(analysis)
x, pad_shape_before_x = self._pad(x)
aggregated_regions = self._aggregate_regions(analysis)

# Compute perturbation mask
# (mask with ones where the input should be perturbated, zeros otherwise)
ranks = self.compute_region_ordering(aggregated_regions)
perturbation_mask_regions = self.compute_perturbation_mask(
ranks = self._compute_region_ordering(aggregated_regions)
perturbation_mask_regions = self._compute_perturbation_mask(
ranks, self.num_perturbed_regions
)
# Perturbate each region
x_perturbated = self.perturbate_regions(x, perturbation_mask_regions)
x_perturbated = self._perturbate_regions(x, perturbation_mask_regions)

# Crop the original image region to remove the padding
if padding:
Expand Down

0 comments on commit 99149e2

Please sign in to comment.