diff --git a/src/crested/pl/__init__.py b/src/crested/pl/__init__.py
new file mode 100644
index 00000000..bb2625d5
--- /dev/null
+++ b/src/crested/pl/__init__.py
@@ -0,0 +1 @@
+from ._contribution_scores import contribution_scores
diff --git a/src/crested/pl/_contribution_scores.py b/src/crested/pl/_contribution_scores.py
new file mode 100644
index 00000000..6195e20f
--- /dev/null
+++ b/src/crested/pl/_contribution_scores.py
@@ -0,0 +1,103 @@
+"""Plot contribution scores."""
+
+import logomaker
+import matplotlib.pyplot as plt
+import numpy as np
+from loguru import logger
+
+from crested._logging import log_and_raise
+from crested.pl._utils import grad_times_input_to_df
+
+
+def _plot_attribution_map(saliency_df, ax=None, figsize=(20, 1)):
+    """Plot an attribution map using logomaker."""
+    logomaker.Logo(saliency_df, figsize=figsize, ax=ax)
+    if ax is None:
+        ax = plt.gca()
+    ax.spines["right"].set_visible(False)
+    ax.spines["top"].set_visible(False)
+    ax.xaxis.set_ticks_position("none")
+    plt.xticks([])
+
+
+@log_and_raise(ValueError)
+def _check_contrib_params(
+    zoom_n_bases: int | None,
+    scores: np.ndarray,
+):
+    """Check contribution scores parameters."""
+    if zoom_n_bases is not None and zoom_n_bases > scores.shape[2]:
+        raise ValueError(
+            f"zoom_n_bases ({zoom_n_bases}) must be less than or equal to the number of bases in the sequence ({scores.shape[2]})"
+        )
+
+
+def contribution_scores(
+    scores: np.ndarray,
+    seqs_one_hot: np.ndarray,
+    class_names: list,
+    zoom_n_bases: int | None = None,
+    highlight_positions: list[tuple[int, int]] | None = None,
+    ylim: tuple | None = None,
+    save_path: str | None = None,
+):
+    """Visualize interpretation scores with optional highlighted positions."""
+    # Center and zoom
+    _check_contrib_params(zoom_n_bases, scores)
+
+    if zoom_n_bases is None:
+        zoom_n_bases = scores.shape[2]
+    center = int(scores.shape[2] / 2)
+    start_idx = center - int(zoom_n_bases / 2)
+    scores = scores[:, :, start_idx : start_idx + zoom_n_bases, :]
+    # Slice the one-hot sequences to the same window, so scores and bases stay aligned
+    seqs_one_hot = seqs_one_hot[:, start_idx : start_idx + zoom_n_bases, :]
+
+    global_min = scores.min()
+    global_max = scores.max()
+
+    # Plot
+    logger.info(f"Plotting contribution scores for {seqs_one_hot.shape[0]} sequence(s)")
+    for seq in range(seqs_one_hot.shape[0]):
+        fig_height_per_class = 2
+        fig = plt.figure(figsize=(50, fig_height_per_class * len(class_names)))
+        for i, class_name in enumerate(class_names):
+            seq_class_scores = scores[seq, i, :, :]
+            seq_class_x = seqs_one_hot[seq, :, :]
+            intgrad_df = grad_times_input_to_df(seq_class_x, seq_class_scores)
+            ax = plt.subplot(len(class_names), 1, i + 1)
+            _plot_attribution_map(intgrad_df, ax=ax)
+            text_to_add = class_name
+            if ylim:
+                ax.set_ylim(ylim[0], ylim[1])
+                x_pos = 5
+                y_pos = 0.75 * ylim[1]
+            else:
+                ax.set_ylim([global_min, global_max])
+                x_pos = 5
+                y_pos = 0.75 * global_max
+            ax.text(x_pos, y_pos, text_to_add, fontsize=16, ha="left", va="center")
+
+            # Draw rectangles to highlight positions
+            if highlight_positions:
+                for start, end in highlight_positions:
+                    ax.add_patch(
+                        plt.Rectangle(
+                            (start - start_idx - 0.5, global_min),
+                            end - start,
+                            global_max - global_min,
+                            edgecolor="red",
+                            facecolor="none",
+                            linewidth=0.5,
+                        )
+                    )
+
+        plt.xlabel("Position")
+        plt.xticks(np.arange(0, zoom_n_bases, 50))
+        if save_path:
+            plt.savefig(save_path)
+            plt.close(fig)
+        else:
+            plt.show()
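For review context, a minimal usage sketch of the new plotting entry point (not part of the diff): the arrays below are random stand-ins following the (N, C, W, 4) score and (N, W, 4) one-hot layout the function expects, and the class names and output path are hypothetical.

    import numpy as np
    from crested.pl import contribution_scores

    scores = np.random.rand(1, 2, 500, 4)                               # (N, C, W, 4)
    seqs_one_hot = np.eye(4)[np.random.randint(0, 4, 500)][None, :, :]  # (N, W, 4)

    contribution_scores(
        scores,
        seqs_one_hot,
        class_names=["class_a", "class_b"],   # one subplot row per class
        zoom_n_bases=250,                     # zoom into the central 250 bp
        highlight_positions=[(240, 260)],     # red box around a region of interest
        save_path="contrib.png",
    )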
diff --git a/src/crested/pl/_utils.py b/src/crested/pl/_utils.py
new file mode 100644
index 00000000..14b8f275
--- /dev/null
+++ b/src/crested/pl/_utils.py
@@ -0,0 +1,46 @@
+import logomaker
+import numpy as np
+
+
+def grad_times_input_to_df(x, grad, alphabet="ACGT"):
+    """Generate pandas dataframe for saliency plot based on grad x inputs."""
+    x_index = np.argmax(np.squeeze(x), axis=1)
+    grad = np.squeeze(grad)
+    L, A = grad.shape
+
+    seq = ""
+    saliency = np.zeros(L)
+    for i in range(L):
+        seq += alphabet[x_index[i]]
+        saliency[i] = grad[i, x_index[i]]
+
+    # Create saliency matrix
+    saliency_df = logomaker.saliency_to_matrix(seq=seq, values=saliency)
+    return saliency_df
+
+
+def grad_times_input_to_df_mutagenesis(x, grad, alphabet="ACGT"):
+    """Generate pandas dataframe for mutagenesis plot based on grad x inputs."""
+    import pandas as pd
+
+    x = np.squeeze(x)  # ensure x is (L, A)
+    grad = np.squeeze(grad)
+    L, A = x.shape
+
+    # Get the original nucleotides' indices as a 1D array
+    x_index = np.argmax(x, axis=1)
+
+    # Convert the index array to nucleotide letters
+    original_nucleotides = np.array([alphabet[idx] for idx in x_index])
+
+    # Prepare data in long format: one row per (position, nucleotide) pair
+    data = {
+        "Position": np.repeat(np.arange(L), A),
+        "Nucleotide": np.tile(list(alphabet), L),
+        "Effect": grad.reshape(-1),  # flatten grad; assumes it matches the (L * A) layout
+        "Original": np.repeat(original_nucleotides, A),
+    }
+    df = pd.DataFrame(data)
+    return df
diff --git a/src/crested/tl/_crested.py b/src/crested/tl/_crested.py
index 6a0870f8..272f8ca1 100644
--- a/src/crested/tl/_crested.py
+++ b/src/crested/tl/_crested.py
@@ -12,6 +12,8 @@
 from crested._logging import log_and_raise
 from crested.tl import TaskConfig
+from crested.tl._explainer import Explainer
+from crested.tl._utils import one_hot_encode_sequence
 from crested.tl.data import AnnDataModule
@@ -46,7 +48,7 @@ def __init__(
     def _initialize_callbacks(
         save_dir: os.PathLike,
         model_checkpointing: bool,
-        model_checkpointing_best_only: bool | None,
+        model_checkpointing_best_only: bool,
         early_stopping: bool,
         early_stopping_patience: int | None,
         learning_rate_reduce: bool,
@@ -213,7 +215,8 @@ def predict(
         self._check_predict_params(anndata, model_name)
         self._check_gpu_availability()
 
-        self.anndatamodule.setup("predict")
+        if self.anndatamodule.predict_dataset is None:
+            self.anndatamodule.setup("predict")
         predict_loader = self.anndatamodule.predict_dataloader
 
         n_predict_steps = len(predict_loader)
@@ -227,6 +230,70 @@
 
         return predictions
 
+    def calculate_contribution_scores(
+        self,
+        region_idx: str | list[str],
+        class_indices: list | None = None,
+        method: str = "integrated_grad",
+        return_one_hot: bool = True,
+    ) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
+        """Calculate contribution scores based on the given method for the specified region(s)."""
+        if self.anndatamodule.predict_dataset is None:
+            self.anndatamodule.setup("predict")
+
+        if isinstance(region_idx, str):
+            region_idx = [region_idx]
+
+        all_scores = []
+        all_one_hot_sequences = []
+
+        for region in region_idx:
+            sequence = self.anndatamodule.predict_dataset.sequence_loader.get_sequence(
+                region
+            )
+            x = one_hot_encode_sequence(sequence)
+            all_one_hot_sequences.append(x)
+
+            if class_indices is not None:
+                n_classes = len(class_indices)
+            else:
+                n_classes = 1  # 'combined' class
+                class_indices = [None]
+
+            scores = np.zeros(
+                (x.shape[0], n_classes, x.shape[1], x.shape[2])
+            )  # (N, C, W, 4)
+
+            for i, class_index in enumerate(class_indices):
+                explainer = Explainer(self.model, class_index=class_index)
+                if method == "integrated_grad":
+                    scores[:, i, :, :] = explainer.integrated_grad(
+                        x, baseline_type="zeros"
+                    )
+                elif method == "smooth_grad":
+                    scores[:, i, :, :] = explainer.smoothgrad(
+                        x, num_samples=50, mean=0.0, stddev=0.1
+                    )
+                elif method == "mutagenesis":
+                    scores[:, i, :, :] = explainer.mutagenesis(
+                        x, class_index=class_index
+                    )
+                elif method == "saliency":
+                    scores[:, i, :, :] = explainer.saliency_maps(x)
+                elif method == "expected_integrated_grad":
+                    scores[:, i, :, :] = explainer.expected_integrated_grad(
+                        x, num_baseline=25
+                    )
+                else:
+                    raise ValueError(f"Unsupported contribution score method: {method}")
+
+            all_scores.append(scores)
+
+        if return_one_hot:
+            return np.concatenate(all_scores, axis=0), np.concatenate(
+                all_one_hot_sequences, axis=0
+            )
+        else:
+            return np.concatenate(all_scores, axis=0)
+
     @staticmethod
     def _check_gpu_availability():
         """Check if GPUs are available."""
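A hypothetical end-to-end sketch of how the two new APIs compose; the fitted `trained_crested` object, region string, and class indices are invented for illustration, and the region format only needs to match the keys used by the predict dataset's sequence loader.

    import crested

    scores, one_hot = trained_crested.calculate_contribution_scores(
        region_idx="chr1:1000-3114",   # resolved via sequence_loader.get_sequence
        class_indices=[0, 5],          # one explanation per class; None -> 'combined'
        method="integrated_grad",      # zeros baseline, per the branch above
    )
    # scores: (N, C, W, 4); one_hot: (N, W, 4)
    crested.pl.contribution_scores(scores, one_hot, class_names=["class_0", "class_5"])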
diff --git a/src/crested/tl/_explainer.py b/src/crested/tl/_explainer.py
new file mode 100644
index 00000000..5e8dd51d
--- /dev/null
+++ b/src/crested/tl/_explainer.py
@@ -0,0 +1,254 @@
+"""
+Model explanation functions using 'gradient x input'-based methods.
+
+Adapted from: https://github.com/p-koo/tfomics/blob/master/tfomics/
+"""
+
+import numpy as np
+import tensorflow as tf
+
+
+class Explainer:
+    """Wrapper class for attribution maps."""
+
+    def __init__(self, model, class_index=None, func=tf.math.reduce_mean):
+        self.model = model
+        self.class_index = class_index
+        self.func = func
+
+    def saliency_maps(self, X, batch_size=128):
+        return function_batch(
+            X,
+            saliency_map,
+            batch_size,
+            model=self.model,
+            class_index=self.class_index,
+            func=self.func,
+        )
+
+    def smoothgrad(self, X, num_samples=50, mean=0.0, stddev=0.1):
+        return function_batch(
+            X,
+            smoothgrad,
+            batch_size=1,
+            model=self.model,
+            num_samples=num_samples,
+            mean=mean,
+            stddev=stddev,
+            class_index=self.class_index,
+            func=self.func,
+        )
+
+    def integrated_grad(self, X, baseline_type="random", num_steps=25):
+        scores = []
+        for x in X:
+            x = np.expand_dims(x, axis=0)
+            baseline = self.set_baseline(x, baseline_type, num_samples=1)
+            intgrad_scores = integrated_grad(
+                x,
+                model=self.model,
+                baseline=baseline,
+                num_steps=num_steps,
+                class_index=self.class_index,
+                func=self.func,
+            )
+            scores.append(intgrad_scores)
+        return np.concatenate(scores, axis=0)
+
+    def expected_integrated_grad(
+        self, X, num_baseline=25, baseline_type="random", num_steps=25
+    ):
+        scores = []
+        for x in X:
+            x = np.expand_dims(x, axis=0)
+            baselines = self.set_baseline(x, baseline_type, num_samples=num_baseline)
+            intgrad_scores = expected_integrated_grad(
+                x,
+                model=self.model,
+                baselines=baselines,
+                num_steps=num_steps,
+                class_index=self.class_index,
+                func=self.func,
+            )
+            scores.append(intgrad_scores)
+        return np.concatenate(scores, axis=0)
+
+    def mutagenesis(self, X, class_index=None):
+        scores = []
+        for x in X:
+            x = np.expand_dims(x, axis=0)
+            scores.append(mutagenesis(x, self.model, class_index))
+        return np.concatenate(scores, axis=0)
+
+    def set_baseline(self, x, baseline_type, num_samples):
+        if baseline_type == "random":
+            baseline = random_shuffle(x, num_samples)
+        else:
+            baseline = np.zeros(x.shape)
+        return baseline
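The wrapper can also be used standalone. A minimal sketch, assuming a compiled Keras `model` that maps (N, W, 4) one-hot inputs to (N, n_classes) predictions; the class index and input sequence are arbitrary:

    from crested.tl._explainer import Explainer
    from crested.tl._utils import one_hot_encode_sequence

    explainer = Explainer(model, class_index=3)   # explain output neuron 3
    x = one_hot_encode_sequence("ACGT" * 125)     # (1, 500, 4)

    sal = explainer.saliency_maps(x)                                  # raw gradients
    ig = explainer.integrated_grad(x, baseline_type="zeros")          # zeros baseline
    eig = explainer.expected_integrated_grad(x, num_baseline=25)      # shuffled baselines
    ism = explainer.mutagenesis(x, class_index=3)                     # per-substitution deltas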
+
+    with tf.GradientTape() as t2:
+        t2.watch(X)
+        with tf.GradientTape() as t1:
+            t1.watch(X)
+            if class_index is not None:
+                outputs = model(X)[:, class_index]
+            else:
+                outputs = func(model(X))
+        g = t1.gradient(outputs, X)
+    return t2.jacobian(g, X)
+
+
+def smoothgrad(
+    x,
+    model,
+    num_samples=50,
+    mean=0.0,
+    stddev=0.1,
+    class_index=None,
+    func=tf.math.reduce_mean,
+):
+    _, L, A = x.shape
+    x_noise = tf.tile(x, (num_samples, 1, 1)) + tf.random.normal(
+        (num_samples, L, A), mean, stddev
+    )
+    grad = saliency_map(x_noise, model, class_index=class_index, func=func)
+    return tf.reduce_mean(grad, axis=0, keepdims=True)
+
+
+def integrated_grad(
+    x, model, baseline, num_steps=25, class_index=None, func=tf.math.reduce_mean
+):
+    def integral_approximation(gradients):
+        # Riemann trapezoidal rule: average adjacent step gradients, then mean over steps
+        grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)
+        integrated_gradients = tf.math.reduce_mean(grads, axis=0)
+        return integrated_gradients
+
+    def interpolate_data(baseline, x, steps):
+        steps_x = steps[:, tf.newaxis, tf.newaxis]
+        delta = x - baseline
+        x = baseline + steps_x * delta
+        return x
+
+    steps = tf.linspace(start=0.0, stop=1.0, num=num_steps + 1)
+    x_interp = interpolate_data(baseline, x, steps)
+    grad = saliency_map(x_interp, model, class_index=class_index, func=func)
+    avg_grad = integral_approximation(grad)
+    avg_grad = np.expand_dims(avg_grad, axis=0)
+    return avg_grad
+
+
+def expected_integrated_grad(
+    x, model, baselines, num_steps=25, class_index=None, func=tf.math.reduce_mean
+):
+    """Average integrated gradients across different background baselines."""
+    grads = []
+    for baseline in baselines:
+        grads.append(
+            integrated_grad(
+                x,
+                model,
+                baseline,
+                num_steps=num_steps,
+                class_index=class_index,
+                func=func,
+            )
+        )
+    return np.mean(np.array(grads), axis=0)
+
+
+def mutagenesis(x, model, class_index=None):
+    """In silico mutagenesis analysis for a given sequence."""
+
+    def generate_mutagenesis(x):
+        _, L, A = x.shape
+        x_mut = []
+        for length in range(L):
+            for a in range(A):
+                x_new = np.copy(x)
+                x_new[0, length, :] = 0
+                x_new[0, length, a] = 1
+                x_mut.append(x_new)
+        return np.concatenate(x_mut, axis=0)
+
+    def reconstruct_map(predictions):
+        _, L, A = x.shape
+
+        mut_score = np.zeros((1, L, A))
+        k = 0
+        for length in range(L):
+            for a in range(A):
+                mut_score[0, length, a] = predictions[k]
+                k += 1
+        return mut_score
+
+    def get_score(x, model, class_index):
+        score = model.predict(x, verbose=0)
+        if class_index is None:
+            score = np.sqrt(np.sum(score**2, axis=-1, keepdims=True))
+        else:
+            score = score[:, class_index]
+        return score
+
+    # Generate mutagenized sequences
+    x_mut = generate_mutagenesis(x)
+
+    # Get baseline wild-type score
+    wt_score = get_score(x, model, class_index)
+    predictions = get_score(x_mut, model, class_index)
+
+    # Reshape mutagenesis predictions
+    mut_score = reconstruct_map(predictions)
+
+    return mut_score - wt_score
+
+
+def grad_times_input(x, scores):
+    new_scores = []
+    for i, score in enumerate(scores):
+        new_scores.append(np.sum(x[i] * score, axis=1))
+    return np.array(new_scores)
+
+
+def l2_norm(scores):
+    return np.sqrt(np.sum(scores**2, axis=2))
+
+
+def function_batch(X, fun, batch_size=128, **kwargs):
+    """Run a function in batches."""
+    dataset = tf.data.Dataset.from_tensor_slices(X)
+    outputs = []
+    for x in dataset.batch(batch_size):
+        outputs.append(fun(x, **kwargs))
+    return np.concatenate(outputs, axis=0)
+
+
+def random_shuffle(x, num_samples=1):
+    """Randomly shuffle sequences; assumes x shape is (N, L, A)."""
+    x_shuffle = []
+    for _ in range(num_samples):
+        shuffle = np.random.permutation(x.shape[1])
+        x_shuffle.append(x[0, shuffle, :])
+    return np.array(x_shuffle)
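A note on `integral_approximation`: it applies the Riemann trapezoid rule over the interpolation-step gradients, and the usual (x - baseline) product of integrated gradients is applied downstream, when `grad_times_input_to_df` multiplies the scores by the one-hot input. A tiny self-contained check with made-up gradients:

    import numpy as np

    # 26 interpolation steps of fake (L=500, A=4) gradients ramping from 0 to 1
    grads = np.linspace(0.0, 1.0, 26)[:, None, None] * np.ones((26, 500, 4))
    trap = (grads[:-1] + grads[1:]) / 2.0   # pairwise means of adjacent steps
    avg_grad = trap.mean(axis=0)            # (500, 4), ~0.5 everywhere
    # attribution ~ (x - baseline) * avg_grad, realized here via grad x input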
diff --git a/src/crested/tl/_utils.py b/src/crested/tl/_utils.py
new file mode 100644
index 00000000..b85b2dd9
--- /dev/null
+++ b/src/crested/tl/_utils.py
@@ -0,0 +1,44 @@
+import numpy as np
+
+
+def get_hot_encoding_table(
+    alphabet: str = "ACGT",
+    neutral_alphabet: str = "N",
+    neutral_value: float = 0.0,
+    dtype=np.float32,
+) -> np.ndarray:
+    """Get hot encoding table to encode a DNA sequence to a numpy array with shape (len(sequence), len(alphabet)) using bytes."""
+
+    def str_to_uint8(string) -> np.ndarray:
+        """Convert string to byte representation."""
+        return np.frombuffer(string.encode("ascii"), dtype=np.uint8)
+
+    # 256 x 4 table: one row for every possible uint8 byte value
+    hot_encoding_table = np.zeros(
+        (np.iinfo(np.uint8).max + 1, len(alphabet)), dtype=dtype
+    )
+
+    # For each ASCII value of the nucleotides used in the alphabet
+    # (upper and lower case), set 1 in the correct column.
+    hot_encoding_table[str_to_uint8(alphabet.upper())] = np.eye(
+        len(alphabet), dtype=dtype
+    )
+    hot_encoding_table[str_to_uint8(alphabet.lower())] = np.eye(
+        len(alphabet), dtype=dtype
+    )
+
+    # For each ASCII value of the nucleotides used in the neutral alphabet
+    # (upper and lower case), set neutral_value in the correct column.
+    hot_encoding_table[str_to_uint8(neutral_alphabet.upper())] = neutral_value
+    hot_encoding_table[str_to_uint8(neutral_alphabet.lower())] = neutral_value
+
+    return hot_encoding_table
+
+
+HOT_ENCODING_TABLE = get_hot_encoding_table()
+
+
+def one_hot_encode_sequence(sequence: str) -> np.ndarray:
+    """One hot encode a DNA sequence."""
+    return np.expand_dims(
+        HOT_ENCODING_TABLE[np.frombuffer(sequence.encode("ascii"), dtype=np.uint8)],
+        axis=0,
+    )
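A quick illustration of the byte-lookup encoder (the input sequence is arbitrary); rows for characters outside the alphabet stay at zero, matching the neutral handling of N:

    from crested.tl._utils import one_hot_encode_sequence

    onehot = one_hot_encode_sequence("ACGTNacgtn")
    print(onehot.shape)   # (1, 10, 4): batch dim, sequence length, ACGT columns
    # 'A'/'a' -> [1, 0, 0, 0]; 'N'/'n' -> [0, 0, 0, 0] (neutral_value)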