
speed improvement for calculating (expected_)integrated_grad for multiple outputs at the same time. #84

Open · wants to merge 2 commits into main

Conversation

SeppeDeWinter
Collaborator

No description provided.

@SeppeDeWinter
Collaborator Author

Testing code:

import numpy as np
import tensorflow as tf
import os
import random
import time

from crested.tl._explainer_tf import (
    saliency_map,
    integrated_grad,
    expected_integrated_grad,
)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# load human model
path_to_human_model = "../../De_Winter_hNTorg/EMBRYO_ANALYSIS/DEEPTOPIC/model_20241213/"

with open(os.path.join(path_to_human_model, "model.json")) as f:
    model = tf.keras.models.model_from_json(
        f.read(), custom_objects={"Functional": tf.keras.models.Model}
    )

model.load_weights(os.path.join(path_to_human_model, "model_epoch_36.hdf5"))


def create_test_data(n_seq, seq_size, seed):
    # Generate random one-hot encoded sequences of shape (n_seq, seq_size, 4).
    random.seed(seed)
    X = np.zeros((n_seq, seq_size, 4))
    for i in range(n_seq):
        for j in range(seq_size):
            X[i, j, random.randint(0, 3)] = 1
    return X

def saliency_map_multi_output(X, model, class_index):
    # Compute input gradients for several output classes using a single
    # forward pass; the persistent tape can then be queried once per class.
    if not tf.is_tensor(X):
        X = tf.Variable(X)
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(X)
        outputs = model(X)
        outputs_per_c = [outputs[:, c] for c in class_index]
    grads = np.empty((len(class_index), *X.shape))
    for i in range(len(class_index)):
        grads[i] = tape.gradient(outputs_per_c[i], X)
    del tape  # release the persistent tape's resources
    return grads


def test_1(data, class_index):
    t1 = time.perf_counter()
    r = np.zeros((len(class_index), *data.shape))
    for i, c in enumerate(class_index):
        for j in range(data.shape[0]):
            r[i, j] = saliency_map(data[j : j + 1], model, class_index=c)
    t2 = time.perf_counter()
    print(f"Took {t2 - t1} seconds")
    return r


def test_2(data, class_index):
    t1 = time.perf_counter()
    r = np.zeros((len(class_index), *data.shape))
    for i in range(data.shape[0]):
        r[:, i, :] = saliency_map_multi_output(
            data[i : i + 1], model, class_index=class_index
        ).squeeze()
    t2 = time.perf_counter()
    print(f"Took {t2 - t1} seconds")
    return r


r1 = test_1(
    data=create_test_data(100, 500, 123), class_index=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
)
# 17 seconds

r2 = test_2(
    data=create_test_data(100, 500, 123), class_index=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
)
# 7 seconds

np.allclose(r1, r2)
# True

def integrated_grad_multi_output(
    x, model, baseline, num_steps=25, class_index=None, func=tf.math.reduce_mean
):
    def integral_approximation(gradients):
        # Riemann trapezoidal rule: average adjacent gradients along the
        # interpolation path, then take the mean over all steps.
        grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)
        integrated_gradients = tf.math.reduce_mean(grads, axis=0)
        return integrated_gradients

    def interpolate_data(baseline, x, steps):
        # Linearly interpolate between the baseline and the input.
        steps_x = steps[:, tf.newaxis, tf.newaxis]
        delta = x - baseline
        x = baseline + steps_x * delta
        return x

    steps = tf.linspace(start=0.0, stop=1.0, num=num_steps + 1)
    x_interp = interpolate_data(baseline, x, steps)
    # grad has shape (n_classes, num_steps + 1, seq_len, 4); swap the first
    # two axes so the integral is taken over the interpolation steps.
    grad = saliency_map_multi_output(x_interp, model, class_index=class_index)
    avg_grad = integral_approximation(grad.swapaxes(0, 1))
    return avg_grad


def test_1(data, class_index):
    t1 = time.perf_counter()
    r = np.zeros((len(class_index), *data.shape))
    for i, c in enumerate(class_index):
        for j in range(data.shape[0]):
            r[i, j] = integrated_grad(
                data[j : j + 1], model, baseline=np.zeros((1, 500, 4)), class_index=c
            )
    t2 = time.perf_counter()
    print(f"Took {t2 - t1} seconds")
    return r


def test_2(data, class_index):
    t1 = time.perf_counter()
    r = np.zeros((len(class_index), *data.shape))
    for j in range(data.shape[0]):
        r[:, j, :] = integrated_grad_multi_output(
            data[j : j + 1],
            model,
            baseline=np.zeros((1, 500, 4)),
            class_index=class_index,
        )
    t2 = time.perf_counter()
    print(f"Took {t2 - t1} seconds")
    return r


r1 = test_1(create_test_data(100, 500, 123), class_index=[1, 2, 3])
# 6 seconds

r2 = test_2(create_test_data(100, 500, 123), class_index=[1, 2, 3])
# 3 seconds

np.allclose(r1, r2)
# True

r1 = integrated_grad(
    create_test_data(1, 500, 123), model, np.zeros((1, 500, 4)), class_index=1
)
r1.shape
# (1, 500, 4)

r2 = integrated_grad(
    create_test_data(1, 500, 123), model, np.zeros((1, 500, 4)), class_index=[1, 2, 3]
)
r2.shape
# TensorShape([3, 500, 4])

r3 = integrated_grad(
    create_test_data(1, 500, 123), model, np.zeros((1, 500, 4)), class_index=None
)
r3.shape
# (1, 500, 4)

np.allclose(r1[0], r2[0].numpy())
# True

baselines = create_test_data(10, 500, 111)

r1 = expected_integrated_grad(
    create_test_data(1, 500, 123), model, baselines, class_index=1
)
r1.shape
# (1, 500, 4)

r2 = expected_integrated_grad(
    create_test_data(1, 500, 123), model, baselines, class_index=[1, 2, 3]
)
r2.shape
# TensorShape([3, 500, 4])

np.allclose(r1[0], r2[0])
# True

@SeppeDeWinter
Collaborator Author

Could be further improved by batching over sequences:

import numpy as np
import tensorflow as tf
import os
import random
import time

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# load human model
path_to_human_model = "../../De_Winter_hNTorg/EMBRYO_ANALYSIS/DEEPTOPIC/model_20241213/"

with open(os.path.join(path_to_human_model, "model.json")) as f:
    model = tf.keras.models.model_from_json(
        f.read(), custom_objects={"Functional": tf.keras.models.Model}
    )

model.load_weights(os.path.join(path_to_human_model, "model_epoch_36.hdf5"))


def create_test_data(n_seq, seq_size, seed):
    # Generate random one-hot encoded sequences of shape (n_seq, seq_size, 4).
    random.seed(seed)
    X = np.zeros((n_seq, seq_size, 4))
    for i in range(n_seq):
        for j in range(seq_size):
            X[i, j, random.randint(0, 3)] = 1
    return X


from crested.tl._explainer_tf import (
    saliency_map,
    function_batch,
    integrated_grad,
    expected_integrated_grad,
)


def test_1(data, batch_size=128):
    t1 = time.perf_counter()
    r = function_batch(data, saliency_map, batch_size, model=model, class_index=0)
    t2 = time.perf_counter()
    print(f"Took {t2 - t1} seconds")
    return r


def test_2(data):
    t1 = time.perf_counter()
    r = []
    for i in range(data.shape[0]):
        r.append(saliency_map(data[i : i + 1], model, class_index=0))
    t2 = time.perf_counter()
    print(f"Took {t2 - t1} seconds")
    return np.array(r).squeeze()


r1 = test_1(create_test_data(500, 500, 100), batch_size=128)
# 0.1 seconds

r2 = test_2(create_test_data(500, 500, 100))
# 9 seconds

However

np.allclose(r1, r2)
# False

Passing a batch of sequences through the model gives slightly different results than passing the same sequences one at a time.

It is a 100x speed-up though ...
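
For what it's worth, one quick diagnostic (not part of this PR; it reuses the model and create_test_data defined above) is to compare the forward pass itself batched vs. one sequence at a time, to see whether the difference originates before the gradient computation:

X = create_test_data(8, 500, 100).astype(np.float32)

# Forward pass on the whole batch vs. per sequence.
out_batched = model(X).numpy()
out_single = np.concatenate(
    [model(X[i : i + 1]).numpy() for i in range(X.shape[0])], axis=0
)

print(np.max(np.abs(out_batched - out_single)))
# A small non-zero value here means the batched and per-sequence forward
# passes already differ (e.g. float accumulation order), and the gradients
# simply inherit that difference.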

@LukasMahieu
Collaborator

LukasMahieu commented Dec 26, 2024

Thanks, Seppe, nice change. Tested this too and it looks good; I get similar speed-ups. There are probably a bunch of changes like this we can still make that should result in speed-ups, especially on the torch side.

I don't think the batching itself is an issue. In function_batch(...) the dataset.batch(..) call takes an optional 'deterministic' parameter. If you set that to True (which we should), you should get the same outputs. I believe the discrepancy arises because running non-deterministically in a parallel setting changes which sequences end up in which batch, and a couple of our operations (e.g. taking the mean) are sensitive to that, which causes (minor) differences.

You can implement and verify that and merge this PR.
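
A minimal sketch of what that could look like, assuming function_batch wraps the input in a tf.data.Dataset (hypothetical; the actual crested implementation may differ):

import numpy as np
import tensorflow as tf


def function_batch(X, fun, batch_size=128, **kwargs):
    # Apply `fun` to fixed-size batches of X and concatenate the results.
    # When num_parallel_calls is set, deterministic=True forces tf.data to
    # preserve element order, so the batched outputs line up with the
    # per-sequence ones.
    dataset = tf.data.Dataset.from_tensor_slices(X)
    batched = dataset.batch(
        batch_size,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True,
    )
    outputs = [fun(x, **kwargs) for x in batched]
    return np.concatenate(outputs, axis=0)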

@LukasMahieu left a comment

consider adding 'deterministic=True' to function_batch
