
TemporalGradCam

Temporal interpretation of medical image data.

Dataset

The OASIS_2D dataset contains 100 brain MRI images from 53 unique patients.

  • 50 images with disease (label 1), from 43 unique patients
  • 50 healthy images (label 0), from 10 unique patients
  • The images are 256 x 256, in color.
  • For the experiments we split the dataset 80:20 into train and test by unique patient, so the same patient never appears in both training and test. The split is stratified, so a balanced amount of healthy and disease examples is present in both train and test (see the sketch after this list).
  • No data augmentation is used at this point.
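
Below is a minimal sketch of how such a patient-level stratified split can be done with scikit-learn's StratifiedGroupKFold. The array names and the use of scikit-learn are assumptions; the repo may implement the split differently.

```python
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

# Hypothetical stand-ins for the real data: one row per image,
# a 0/1 label, and the patient id each image belongs to.
images = np.random.rand(100, 256, 256, 3)
labels = np.array([1] * 50 + [0] * 50)            # 50 disease, 50 healthy
patient_ids = np.random.randint(0, 53, size=100)  # 53 unique patients

# Holding out one of five folds gives an ~80:20 split that is stratified
# by label and grouped by patient, so no patient appears in both splits.
splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0)
train_idx, test_idx = next(splitter.split(images, labels, groups=patient_ids))

train_images, train_labels = images[train_idx], labels[train_idx]
test_images, test_labels = images[test_idx], labels[test_idx]
```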

Figure: Sequence-length distribution when each patient's images are converted to a time series. A patient can have multiple scans taken on different days; most patients have only one image.


Model

Currently we have the following two models implemented:

  • ResNet
  • ViT (Vision Transformer)

For training we freeze all layers except the final output Linear layer; a minimal sketch of this setup follows the list below.

  • Epochs: 25
  • Learning rate: 1e-3
  • Early stopping patience: 5
  • Experiment iterations: 5; the whole experiment is repeated five times with a different random seed each time. The test results and the best model checkpoints are saved.
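
A minimal sketch of the transfer-learning setup, assuming a torchvision ResNet-18 backbone and an Adam optimizer (both assumptions; the repo may use a different ResNet variant or optimizer):

```python
import torch
import torch.nn as nn
from torchvision import models

# Load a pretrained backbone and freeze every layer.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False

# Replace the output Linear layer with a trainable two-class head.
# ResNet-18's penultimate feature dimension is 512, matching the
# feature size used later by the temporal model.
model.fc = nn.Linear(model.fc.in_features, 2)

# Only the new head's parameters are passed to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
```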

Figure: Training history of one iteration of the ResNet model.


Temporal Model

To create the temporal version of the OASIS model, we:

  • Extracted features from the images using the pretrained models (ResNet or ViT). The extracted feature dimension equals the dimension of the layer just before the output layer:
    • 512 for ResNet
    • 768 for ViT
  • For each sample:
    • find the previous images of the same patient, up to the maximum sequence length
    • create a time-series example of shape [seq_len, feature_dim]
    • shorter sequences are padded at the beginning to the maximum sequence length; longer sequences are truncated from the beginning (the oldest images are dropped)
    • we currently use seq_len = 3, which covers around 70% of the examples; the rest are padded or truncated as above
  • The model is batch-first; PyTorch doesn't easily support variable-length time sequences.
  • Currently we use a simple DNN model on the temporal dataset (see the sketch after this list):
    • max epochs: 100
    • learning rate: 1e-3
    • dropout: 0.1
    • hidden_size: 64
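
A minimal sketch of the sequence construction and a plausible version of the simple DNN head. Zero padding, flattening the sequence, and the exact layer layout are assumptions; the repo's implementation may differ.

```python
import torch
import torch.nn as nn

SEQ_LEN, FEATURE_DIM = 3, 512  # 512 for ResNet features, 768 for ViT

def build_sequence(patient_features: list) -> torch.Tensor:
    """Turn one patient's feature vectors (oldest first, current image last)
    into a fixed-length [SEQ_LEN, FEATURE_DIM] example."""
    feats = torch.stack(patient_features[-SEQ_LEN:])  # drop the oldest images
    if feats.shape[0] < SEQ_LEN:                      # pad at the beginning
        pad = torch.zeros(SEQ_LEN - feats.shape[0], FEATURE_DIM)
        feats = torch.cat([pad, feats], dim=0)
    return feats

class SimpleDNN(nn.Module):
    """Batch-first DNN over the flattened sequence of image features."""
    def __init__(self, hidden_size=64, dropout=0.1, n_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),                          # [B, SEQ_LEN * FEATURE_DIM]
            nn.Linear(SEQ_LEN * FEATURE_DIM, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, n_classes),
        )

    def forward(self, x):  # x: [B, SEQ_LEN, FEATURE_DIM]
        return self.net(x)

# Example: a patient with a single scan gets two rows of zero padding.
example = build_sequence([torch.randn(FEATURE_DIM)])
logits = SimpleDNN()(example.unsqueeze(0))  # batch of one -> [1, 2]
```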

Results

The following table shows the average test results across all five iterations.

| Model          | Loss | Accuracy (%) | F1-score (%) | AUC (%) |
|----------------|------|--------------|--------------|---------|
| ResNet         | 1.32 | 83.87        | 81.32        | 91.68   |
| ResNet (Seq 3) | 1.45 | 79.87        | 81.24        | 82.24   |
| ViT            | 1.22 | 85.77        | 86.50        | 92.08   |
| ViT (Seq 3)    | 0.94 | 87.72        | 88.25        | 95.13   |

The temporal model (sequence length 3) with the Vision Transformer performs best so far.

Interpretation

Interpreting sample patient images for ResNet

Note that this is not for the temporal model.

Figure: Three sample patient images, each shown with Gradient SHAP, GradCAM, Guided GradCAM, and Guided Backprop attributions.

Interpreting sample patient images for ViT

Note that this is not for the temporal model.

Figure: Three sample patient images, each shown with Gradient SHAP, GradCAM, Guided GradCAM, and Guided Backprop attributions.
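
A minimal sketch of how one of these attributions can be computed with Captum, assuming a torchvision ResNet-18 and Captum's GuidedGradCam (the repo's actual interpretation code may differ):

```python
import torch
from torchvision import models
from captum.attr import GuidedGradCam

# GuidedGradCam combines GradCAM on a convolutional layer with guided
# backpropagation; for ResNet-18 the last conv block is model.layer4.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()
explainer = GuidedGradCam(model, model.layer4)

image = torch.rand(1, 3, 256, 256, requires_grad=True)  # placeholder input
attribution = explainer.attribute(image, target=1)      # class 1 = disease
print(attribution.shape)                                # [1, 3, 256, 256]
```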

Tools

Files

The following files are available for now; each runs transfer learning on the medical dataset with a pre-trained vision model.

  • oasis_resnet: Runs the OASIS dataset with the ResNet model.
  • oasis_ViT: Runs the OASIS dataset with the Vision Transformer.

Literature