ArXiv • Installation • Training • Cite
Note
🎉 CheXtriev has been accepted at MICCAI 2024!
To set up the environment, create the conda environment with:
conda env create -f environment.yml
The main dependencies are PyTorch, PyTorch Lightning, PyTorch Geometric, and Faiss.
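A quick way to confirm the environment resolved correctly is to import the core packages and print their versions; this is only a sanity-check sketch, not part of the original setup:
# sanity check: the main dependencies from environment.yml import cleanly
import torch
import pytorch_lightning as pl
import torch_geometric
import faiss

print("torch:", torch.__version__)
print("pytorch_lightning:", pl.__version__)
print("torch_geometric:", torch_geometric.__version__)
print("CUDA available:", torch.cuda.is_available())
print("faiss OK:", faiss.IndexFlatL2(8) is not None)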
- Run main.py to train any model. Modify paths to the dataset and model checkpoints as necessary.
- metrics contains evaluation scripts for the various models. These scripts use common_metrics.py to compute mAP (Mean Average Precision), mHR (Mean Hit Rate), and mRR (Mean Reciprocal Rank); a minimal sketch of these metrics is shown after this list. Modify paths to the dataset and model checkpoints within these scripts as needed.
- dataloader contains the dataloader implementations for each model in the respective formats.
- graph_transformer is adapted from the well-maintained GitHub repository of the graph transformer architecture, with added functionality to support the project requirements.
- model contains definitions and architectures for the various models used in the project.
- notebooks includes Jupyter notebooks used for analysis, visualizations, and initial experiments. These were later converted to Python scripts for streamlined execution.
- others contains scripts for data processing and for transferring data to the HPC cluster, specific to our setup.
- output is where results are stored in tabular format, detailing the top-3, top-5, and top-10 retrieved images.
- Res2Net contains the multi-scale ResNet50 model definition borrowed from this repository.
- scripts includes the command scripts to train any model, including hyperparameter tuning.
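The exact implementation lives in metrics/common_metrics.py; the sketch below only illustrates how mAP, mHR, and mRR are commonly computed for label-based retrieval, under the assumption that a retrieved image counts as relevant if it shares at least one label with the query (the repository's definition may differ):
import numpy as np

def retrieval_metrics(query_labels, gallery_labels, ranked_indices, k=10):
    """Compute mAP, mHR and mRR over the top-k retrievals for every query.

    query_labels:   (Q, C) binary label matrix for the queries
    gallery_labels: (G, C) binary label matrix for the gallery
    ranked_indices: (Q, >=k) gallery indices sorted by similarity to each query
    """
    aps, hits, rrs = [], [], []
    for q in range(len(query_labels)):
        # relevant if the retrieved image shares at least one label with the query
        rel = (gallery_labels[ranked_indices[q, :k]] @ query_labels[q]) > 0
        if rel.any():
            precision_at = np.cumsum(rel) / (np.arange(k) + 1)
            aps.append((precision_at * rel).sum() / rel.sum())   # AP over the top-k list
            hits.append(1.0)                                     # hit: any relevant item in top-k
            rrs.append(1.0 / (np.argmax(rel) + 1))               # reciprocal rank of first hit
        else:
            aps.append(0.0); hits.append(0.0); rrs.append(0.0)
    return np.mean(aps), np.mean(hits), np.mean(rrs)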
The Global CNN baseline uses ResNet50 to extract latent representations from chest radiographs. Only the classification head is fine-tuned; the rest of the network's weights are frozen.
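In code terms, freezing the backbone and fine-tuning only the classification head amounts to something like the following (a minimal torchvision sketch assuming a recent torchvision, not the repository's exact model definition):
import torch
import torch.nn as nn
from torchvision import models

def build_frozen_resnet50(num_classes=9):
    # assumes torchvision >= 0.13 for the weights enum
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    for param in model.parameters():
        param.requires_grad = False                               # freeze the backbone
    model.fc = nn.Linear(model.fc.in_features, num_classes)       # new, trainable head
    return model

model = build_frozen_resnet50()
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=5e-4)  # only the head is updated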
Use the following script to train the ResNet50 model:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0005 \
--grad_accum 4 \
--task resnet50 \
--run resnet50_fc \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the ResNet50 model:
python metrics/temp_metrics_resnet50.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0005 \
--task resnet50 \
--run resnet50_fc \
--num_workers 32
Attention-based Triplet Hashing (ATH) is a state-of-the-art chest radiograph retrieval method based on an attention mechanism and triplet hashing. More details can be found in the GitHub repository and the paper.
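The sketch below illustrates only the triplet-hashing part of the recipe (the attention module is omitted): features are projected to hash_bits dimensions, squashed towards {-1, +1}, and trained with a triplet margin loss. This is an illustration of the idea, not the ATH authors' implementation:
import torch
import torch.nn as nn

class HashingHead(nn.Module):
    """Toy hashing head: project features to `hash_bits` and squash with tanh."""
    def __init__(self, in_dim=2048, hash_bits=32):
        super().__init__()
        self.proj = nn.Linear(in_dim, hash_bits)

    def forward(self, feats):
        return torch.tanh(self.proj(feats))      # continuous codes in (-1, 1)

head = HashingHead()
triplet_loss = nn.TripletMarginLoss(margin=1.0)

# anchor/positive share a finding, negative does not (dummy features here)
anchor, positive, negative = (torch.randn(8, 2048) for _ in range(3))
loss = triplet_loss(head(anchor), head(positive), head(negative))
# at retrieval time the codes are binarised, e.g. torch.sign(head(x))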
Use the following script to train the ATH model:
python main.py \
--num_classes 9 \
--batch_size 24 \
--lr 0.001 \
--grad_accum 4 \
--dropout 0.0 \
--hash_bits 32 \
--task ath \
--run ath \
--gpu_ids 0 \
--num_workers 36 \
--train \
--log
Use the following script to evaluate the ATH model:
python metrics/temp_metrics_ath.py \
--num_classes 9 \
--batch_size 24 \
--lr 0.001 \
--grad_accum 4 \
--dropout 0.0 \
--hash_bits 32 \
--task ath \
--run ath \
--num_workers 36
AnaXNet is an anatomy-aware multi-label classification model for chest X-rays. For more details, refer to the paper.
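At a high level, AnaXNet classifies findings per anatomical region after contextualising region features with a graph network; the sketch below illustrates that general idea with a single torch_geometric GCN layer over a toy fully connected graph, and should not be read as the authors' exact architecture:
import torch
from torch_geometric.nn import GCNConv

num_regions, feat_dim, num_classes = 18, 1024, 9
region_feats = torch.randn(num_regions, feat_dim)     # one feature vector per anatomical region

# toy fully connected graph over the regions (AnaXNet derives its adjacency differently)
nodes = torch.arange(num_regions)
edge_index = torch.stack([nodes.repeat_interleave(num_regions), nodes.repeat(num_regions)])

gcn = GCNConv(feat_dim, 256)
classifier = torch.nn.Linear(256, num_classes)

contextualised = gcn(region_feats, edge_index).relu() # region features after message passing
logits = classifier(contextualised)                   # per-region, multi-label logits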
Use the following script to train the AnaXNet model:
python main.py \
--num_classes 9 \
--batch_size 32 \
--lr 0.0001 \
--grad_accum 4 \
--task anaxnet \
--run anaxnet_final \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the AnaXNet model:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--task anaxnet \
--run anaxnet_final \
--num_workers 32
CheXtriev is a novel graph-based, anatomy-aware framework designed for chest radiograph retrieval. It consists of several variants (V0 to V6), each adding a different combination of graph connectivity, positional embeddings, and edge features.
V0 extracts ResNet50 features from the 18 predefined anatomical regions and mean-pools them to obtain the latent representation of the chest radiograph.
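Conceptually, V0 reduces the 18 region embeddings to a single image embedding by averaging them; a minimal sketch with assumed tensor shapes:
import torch

batch, num_regions, feat_dim = 4, 18, 2048
region_feats = torch.randn(batch, num_regions, feat_dim)  # ResNet50 features per anatomical region

image_embedding = region_feats.mean(dim=1)                # mean pooling over the 18 regions
print(image_embedding.shape)                              # torch.Size([4, 2048])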
Use the following script to train the V0 model for global image level classification:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 1 \
--graph_importance 1.0 \
--pool mean \
--minimalistic \
--task xfactor \
--run mean_pool_global_image_classification_bz \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the V0 model for global image level classification:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 1 \
--graph_importance 1.0 \
--pool mean \
--minimalistic \
--task xfactor \
--run mean_pool_global_image_classification_bz \
--num_workers 32
Use the following script to train the V0 model for local anatomy level classification:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 1 \
--graph_importance 0.0 \
--pool mean \
--minimalistic \
--task xfactor \
--run mean_pool_node_classification_bz \
--gpu_ids 0 1 \
--num_workers 10 \
--train \
--log
Use the following script to evaluate the V0 model for local anatomy level classification:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 1 \
--graph_importance 0.0 \
--pool mean \
--minimalistic \
--task xfactor \
--run mean_pool_node_classification_bz \
--num_workers 20
In V1, anatomical features processed through ResNet50 are further contextualized using a graph transformer, with binary edge connections based on label co-occurrence. This model is supervised globally at the image level.
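A binary co-occurrence adjacency of the kind V1 uses can be derived from training labels roughly as follows; this is a sketch under the assumption that two regions are connected if their findings ever co-occur (the repository may threshold or normalise differently), and region_labels is a hypothetical input:
import numpy as np

# region_labels: (num_images, num_regions) binary matrix, 1 if a region has any positive finding
rng = np.random.default_rng(0)
region_labels = (rng.random((1000, 18)) > 0.7).astype(np.float32)  # dummy data for illustration

cooccurrence = region_labels.T @ region_labels            # (18, 18) co-occurrence counts
adjacency = (cooccurrence > 0).astype(np.float32)          # binarise: edge if regions ever co-occur
np.fill_diagonal(adjacency, 1.0)                           # keep self-loops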
Use the following script to train the V1 model:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 2 \
--graph_importance 1.0 \
--task graph_transformer \
--run best_config_adj_mat \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the V1 model:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--num_layers 2 \
--graph_importance 1.0 \
--task graph_transformer \
--run best_config_adj_mat \
--num_workers 32
In V2, anatomical features processed through ResNet50 are further contextualized using a graph transformer, with fully connected uniform edge connections to model relationships among the anatomical structures. This model is supervised globally at the image level.
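Fully connected uniform edges simply connect every pair of the 18 nodes with the same weight; in edge-list form (as used by PyTorch Geometric) this is just the following sketch:
import torch

num_regions = 18
nodes = torch.arange(num_regions)
edge_index = torch.stack([nodes.repeat_interleave(num_regions), nodes.repeat(num_regions)])
# edge_index has shape (2, 324): every node is connected to every node, including itself
edge_weight = torch.ones(edge_index.shape[1])              # uniform edge weights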
Use the following script to train the V2 model:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--task graph_transformer \
--run best_config_abs_pos \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the V2 model:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--task graph_transformer \
--run best_config_abs_pos \
--num_workers 32
V3 builds on V2 by introducing learnable positional embeddings, enhancing the model's ability to capture spatial relationships between anatomical features.
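Learnable positional embeddings here can be read as one trainable vector per anatomical position, added to the node features before the graph transformer layers; a sketch, not the repository's exact code:
import torch
import torch.nn as nn

num_regions, feat_dim = 18, 2048
pos_embedding = nn.Embedding(num_regions, feat_dim)        # one learnable vector per anatomical region

region_feats = torch.randn(4, num_regions, feat_dim)       # (batch, regions, features)
positions = torch.arange(num_regions)
node_feats = region_feats + pos_embedding(positions)       # broadcast over the batch dimension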
Use the following script to train the V3 model:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--abs_pos \
--task graph_transformer \
--run best_config_abs_pos \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the V3 model:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--abs_pos \
--task graph_transformer \
--run best_config_abs_pos \
--num_workers 32
V4 modifies V3 by making the fully connected edges unique and entirely learnable; the model is supervised globally at the image level. Local multi-level features with gated residuals are used in V4 only.
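One way to read "unique and entirely learnable" edges is a distinct trainable embedding per ordered node pair, passed to the graph transformer alongside the node features; the following is only a hedged sketch of that interpretation:
import torch
import torch.nn as nn

num_regions, edge_dim = 18, 64
num_edges = num_regions * num_regions
edge_embedding = nn.Embedding(num_edges, edge_dim)         # a distinct learnable vector per node pair

nodes = torch.arange(num_regions)
edge_index = torch.stack([nodes.repeat_interleave(num_regions), nodes.repeat(num_regions)])
edge_attr = edge_embedding(torch.arange(num_edges))        # (324, edge_dim), learned with the model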
Use the following script to train the V4 model:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 2 \
--graph_importance 0.0 \
--fully_connected \
--abs_pos \
--accept_edges \
--residual_type 2 \
--task graph_transformer \
--run best_config_with_edges_local_anatomy \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the V4 model:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--num_layers 2 \
--graph_importance 0.0 \
--fully_connected \
--abs_pos \
--accept_edges \
--residual_type 2 \
--task graph_transformer \
--run best_config_with_edges_local_anatomy \
--num_workers 32
V5 alters V4 by omitting the learnable positional embeddings; it is supervised globally at the image level and uses global multi-level features with gated residuals.
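A gated residual blends a block's output with its input through a learned gate instead of a plain sum; a minimal sketch of that idea (the exact formulation in graph_transformer may differ):
import torch
import torch.nn as nn

class GatedResidual(nn.Module):
    """Blend a block's output with its input using a learned, sigmoid gate."""
    def __init__(self, dim):
        super().__init__()
        self.gate = nn.Linear(2 * dim, dim)

    def forward(self, x, residual):
        g = torch.sigmoid(self.gate(torch.cat([x, residual], dim=-1)))
        return g * x + (1 - g) * residual

block_out = torch.randn(4, 18, 256)                        # dummy block output
block_in = torch.randn(4, 18, 256)                         # dummy block input
fused = GatedResidual(256)(block_out, block_in)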
Use the following script to train the V5 model:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--accept_edges \
--residual_type 2 \
--task graph_transformer \
--run best_config_with_edges_without_pos_emb \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the V5 model:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--accept_edges \
--residual_type 2 \
--task graph_transformer \
--run best_config_with_edges_without_pos_emb \
--num_workers 32
V6 is the best configuration: detected anatomies are processed through ResNet50 and then passed through two graph transformer layers with learnable continuous edges and positional embeddings. This model is supervised globally at the image level.
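Once a model such as V6 is trained, retrieval reduces to nearest-neighbour search over the learned embeddings; the repository lists Faiss as a dependency, and a minimal, illustrative retrieval sketch with an exact L2 index looks like this:
import faiss
import numpy as np

dim = 256
gallery = np.random.rand(10000, dim).astype("float32")     # embeddings of the gallery radiographs
queries = np.random.rand(5, dim).astype("float32")         # embeddings of the query radiographs

index = faiss.IndexFlatL2(dim)                             # exact L2 search
index.add(gallery)
distances, indices = index.search(queries, 10)             # top-10 retrieved images per query
print(indices.shape)                                       # (5, 10)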
Use the following script to train the V6 model:
python main.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--grad_accum 8 \
--dropout 0.0 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--abs_pos \
--accept_edges \
--task graph_transformer \
--run best_config_abs_pos_with_edges \
--gpu_ids 0 1 \
--num_workers 20 \
--train \
--log
Use the following script to evaluate the V6 model:
python metrics/temp_metrics_anaxnet.py \
--num_classes 9 \
--batch_size 16 \
--lr 0.0001 \
--num_layers 2 \
--graph_importance 1.0 \
--fully_connected \
--abs_pos \
--accept_edges \
--task graph_transformer \
--run best_config_abs_pos_with_edges \
--num_workers 32