add contribution scores calculation and plotting
LukasMahieu committed Jun 12, 2024
1 parent 7748062 commit 4d0f11a
Showing 6 changed files with 517 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/crested/pl/__init__.py
@@ -0,0 +1 @@
from ._contribution_scores import contribution_scores
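With this re-export, the plotting helper should be importable from the package-level `crested.pl` namespace (a hedged sketch; only the import path shown in the diff above is confirmed):

# Hedged sketch: the re-export above suggests the package-level import works.
from crested.pl import contribution_scores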
103 changes: 103 additions & 0 deletions src/crested/pl/_contribution_scores.py
@@ -0,0 +1,103 @@
"""Plot contribution scores."""

import logomaker
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger

from crested._logging import log_and_raise
from crested.pl._utils import grad_times_input_to_df


def _plot_attribution_map(saliency_df, ax=None, figsize=(20, 1)):
"""Plot an attribution map using logomaker"""
logomaker.Logo(saliency_df, figsize=figsize, ax=ax)
if ax is None:
ax = plt.gca()
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
# ax.yaxis.set_ticks_position("none")
ax.xaxis.set_ticks_position("none")
plt.xticks([])


@log_and_raise(ValueError)
def _check_contrib_params(
zoom_n_bases: int | None,
scores: np.ndarray,
):
"""Check contribution scores parameters."""
if zoom_n_bases is not None and zoom_n_bases > scores.shape[2]:
raise ValueError(
f"zoom_n_bases ({zoom_n_bases}) must be less than or equal to the number of bases in the sequence ({scores.shape[2]})"
)


def contribution_scores(
scores: np.ndarray,
seqs_one_hot: np.ndarray,
class_names: list,
zoom_n_bases: int | None = None,
highlight_positions: list[tuple[int, int]] | None = None,
ylim: tuple | None = None,
save_path: str | None = None,
):
"""Visualize interpretation scores with optional highlighted positions."""
# Center and zoom
_check_contrib_params(zoom_n_bases, scores)

if zoom_n_bases is None:
zoom_n_bases = scores.shape[2]
center = int(scores.shape[2] / 2)
start_idx = center - int(zoom_n_bases / 2)
scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :]

global_min = scores.min()
global_max = scores.max()

# Plot
logger.info(f"Plotting contribution scores for {seqs_one_hot.shape[0]} sequence(s)")
for seq in range(seqs_one_hot.shape[0]):
fig_height_per_class = 2
fig = plt.figure(figsize=(50, fig_height_per_class * len(class_names)))
for i, class_name in enumerate(class_names):
seq_class_scores = scores[seq, i, :, :]
seq_class_x = seqs_one_hot[seq, :, :]
intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores)
ax = plt.subplot(len(class_names), 1, i + 1)
_plot_attribution_map(intgrad_df, ax=ax)
text_to_add = class_name
if ylim:
ax.set_ylim(ylim[0], ylim[1])
x_pos = 5
y_pos = 0.75 * ylim[1]
else:
ax.set_ylim([global_min, global_max])
x_pos = 5
y_pos = 0.75 * global_max
ax.text(x_pos, y_pos, text_to_add, fontsize=16, ha="left", va="center")

# Draw rectangles to highlight positions
if highlight_positions:
for start, end in highlight_positions:
ax.add_patch(
plt.Rectangle(
(
start - start_idx - 0.5,
global_min,
),
end - start,
global_max - global_min,
edgecolor="red",
facecolor="none",
linewidth=0.5,
)
)

plt.xlabel("Position")
plt.xticks(np.arange(0, zoom_n_bases, 50))
if save_path:
plt.savefig(save_path)
plt.close(fig)
else:
plt.show()
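A minimal usage sketch of the new plotting function, using random toy inputs; the class names, highlight window, and output path below are placeholders. Array shapes follow the convention in the code above: scores is (N, C, W, 4) and seqs_one_hot is (N, W, 4).

# Minimal sketch with toy data (not from this commit).
import numpy as np
from crested.pl import contribution_scores

rng = np.random.default_rng(0)
seq_len = 500
seqs_one_hot = np.eye(4)[rng.integers(0, 4, size=seq_len)][None, :, :]  # (1, W, 4)
scores = rng.normal(size=(1, 2, seq_len, 4))                            # (1, 2, W, 4)

contribution_scores(
    scores=scores,
    seqs_one_hot=seqs_one_hot,
    class_names=["cell_type_A", "cell_type_B"],  # placeholder labels
    zoom_n_bases=250,
    highlight_positions=[(240, 260)],            # placeholder window
    save_path="contribution_scores.png",         # placeholder output path
)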
46 changes: 46 additions & 0 deletions src/crested/pl/_utils.py
@@ -0,0 +1,46 @@
import logomaker
import numpy as np


def grad_times_input_to_df(x, grad, alphabet="ACGT"):
"""Generate pandas dataframe for saliency plot based on grad x inputs"""
x_index = np.argmax(np.squeeze(x), axis=1)
grad = np.squeeze(grad)
L, A = grad.shape

seq = ""
saliency = np.zeros((L))
for i in range(L):
seq += alphabet[x_index[i]]
saliency[i] = grad[i, x_index[i]]

# create saliency matrix
saliency_df = logomaker.saliency_to_matrix(seq=seq, values=saliency)
return saliency_df


def grad_times_input_to_df_mutagenesis(x, grad, alphabet="ACGT"):
    """Generate a pandas DataFrame for a mutagenesis plot based on grad times input."""
    import pandas as pd

    x = np.squeeze(x)  # Ensure x is correctly squeezed
    grad = np.squeeze(grad)
    L, A = x.shape

# Get original nucleotides' indices, ensure it's 1D
x_index = np.argmax(x, axis=1)

# Convert index array to nucleotide letters
original_nucleotides = np.array([alphabet[idx] for idx in x_index])

# Data preparation for DataFrame
data = {
"Position": np.repeat(np.arange(L), A),
"Nucleotide": np.tile(list(alphabet), L),
"Effect": grad.reshape(
-1
), # Flatten grad assuming it matches the reshaped size
"Original": np.repeat(original_nucleotides, A),
}
df = pd.DataFrame(data)
return df
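A small sketch of what `grad_times_input_to_df` produces, using toy arrays; note the import is from the internal `_utils` module added in this commit, and the values are illustrative only.

# Toy example: one-hot sequence "AGC" and a small gradient array, shape (L, 4).
import numpy as np
from crested.pl._utils import grad_times_input_to_df

x = np.array([[1, 0, 0, 0],
              [0, 0, 1, 0],
              [0, 1, 0, 0]], dtype=float)   # one-hot for "AGC"
grad = np.array([[0.2, 0.0, 0.1, 0.0],
                 [0.0, 0.0, 0.5, 0.0],
                 [0.0, -0.3, 0.0, 0.0]])    # toy gradients

saliency_df = grad_times_input_to_df(x, grad)
print(saliency_df)  # one row per position, value placed in the column of the observed base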
71 changes: 69 additions & 2 deletions src/crested/tl/_crested.py
@@ -12,6 +12,8 @@

from crested._logging import log_and_raise
from crested.tl import TaskConfig
from crested.tl._explainer import Explainer
from crested.tl._utils import one_hot_encode_sequence
from crested.tl.data import AnnDataModule


@@ -46,7 +48,7 @@ def __init__(
def _initialize_callbacks(
save_dir: os.PathLike,
model_checkpointing: bool,
model_checkpointing_best_only: bool | None,
model_checkpointing_best_only: bool,
early_stopping: bool,
early_stopping_patience: int | None,
learning_rate_reduce: bool,
@@ -213,7 +215,8 @@ def predict(
self._check_predict_params(anndata, model_name)
self._check_gpu_availability()

self.anndatamodule.setup("predict")
if self.anndatamodule.predict_dataset is None:
self.anndatamodule.setup("predict")
predict_loader = self.anndatamodule.predict_dataloader

n_predict_steps = len(predict_loader)
@@ -227,6 +230,70 @@

return predictions

def calculate_contribution_scores(
self,
        region_idx: str | list[str],
class_indices: list | None = None,
method: str = "integrated_grad",
return_one_hot: bool = True,
    ) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
        """Calculate contribution scores based on the given method for the specified region(s)."""
if self.anndatamodule.predict_dataset is None:
self.anndatamodule.setup("predict")

if isinstance(region_idx, str):
region_idx = [region_idx]

all_scores = []
all_one_hot_sequences = []

for region in region_idx:
sequence = self.anndatamodule.predict_dataset.sequence_loader.get_sequence(
region
)
x = one_hot_encode_sequence(sequence)
all_one_hot_sequences.append(x)

if class_indices is not None:
n_classes = len(class_indices)
else:
n_classes = 1 # 'combined' class
class_indices = [None]

scores = np.zeros(
(x.shape[0], n_classes, x.shape[1], x.shape[2])
) # (N, C, W, 4)

for i, class_index in enumerate(class_indices):
explainer = Explainer(self.model, class_index=class_index)
if method == "integrated_grad":
scores[:, i, :, :] = explainer.integrated_grad(
x, baseline_type="zeros"
)
elif method == "smooth_grad":
scores[:, i, :, :] = explainer.smoothgrad(
x, num_samples=50, mean=0.0, stddev=0.1
)
elif method == "mutagenesis":
scores[:, i, :, :] = explainer.mutagenesis(
x, class_index=class_index
)
elif method == "saliency":
scores[:, i, :, :] = explainer.saliency_maps(x)
elif method == "expected_integrated_grad":
scores[:, i, :, :] = explainer.expected_integrated_grad(
x, num_baseline=25
)

all_scores.append(scores)

if return_one_hot:
return np.concatenate(all_scores, axis=0), np.concatenate(
all_one_hot_sequences, axis=0
)
else:
return np.concatenate(all_scores, axis=0)

@staticmethod
def _check_gpu_availability():
"""Check if GPUs are available."""

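Putting the new pieces together, a hedged end-to-end sketch of how the scoring and plotting could be combined; the trained `Crested` object, region name, class indices, and output path are placeholders rather than values from this commit.

# Hedged end-to-end sketch: compute contribution scores, then plot them.
from crested.pl import contribution_scores

trained_crested = ...  # a fitted crested.tl.Crested instance (placeholder; class name assumed from _crested.py)

scores, one_hot = trained_crested.calculate_contribution_scores(
    region_idx="chr1:1000-3114",   # hypothetical region name from the predict dataset
    class_indices=[0, 1],          # hypothetical class indices
    method="integrated_grad",
)
contribution_scores(
    scores=scores,
    seqs_one_hot=one_hot,
    class_names=["class_0", "class_1"],  # one label per chosen class
    save_path="region_contrib.png",      # hypothetical output path
)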