4. Interpreting the Model
Now we would like to interpret the results of our model. In many traditional methylation analyses, this would mean finding differentially methylated regions and uncovering the reason for the difference by querying tools such as gometh, LOLA, etc.
We need a model explanation framework, which is a bit different from the aforementioned process: why did this model make this prediction for this individual? If this were image data, for each image we'd want a heatmap highlighting the regions the model thinks are predictive of a particular class. For molecular data, we want to highlight individual CpGs or groups of CpGs. To this end, we adopted the SHAP framework (https://github.com/slundberg/shap). For more information, you can check out the supplementals of our paper or the original SHAP paper.
The idea, simply put, is that every complex model can be locally approximated by a simpler linear model, with a separate linear model fit for each individual prediction. The coefficients of the simpler model denote the importance/contribution of each CpG to the model's prediction, and you can sum them across groups of CpGs to find importances at a less granular level.
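To make the additivity idea concrete, here is a minimal numpy sketch (all names and shapes are hypothetical, not MethylNet's internal API): per-sample SHAP contributions sum to the prediction's offset from a baseline expectation, and summing contributions over a CpG grouping yields coarser, group-level importances.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: 5 test samples, 6 "CpGs".
shap_values = rng.normal(size=(5, 6))      # per-sample, per-CpG contributions
expected_value = 42.0                      # baseline: mean prediction on background
predictions = expected_value + shap_values.sum(axis=1)  # additivity property

# Summing contributions over a CpG grouping gives coarser importances.
groups = np.array([0, 0, 1, 1, 2, 2])      # hypothetical CpG -> group assignment
group_contrib = np.stack(
    [shap_values[:, groups == g].sum(axis=1) for g in range(3)], axis=1
)

# Group-level contributions still sum to the same prediction offset.
assert np.allclose(group_contrib.sum(axis=1), shap_values.sum(axis=1))
```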
Let's run some model interpretations. Our models are stored in the embeddings/ and predictions/ folders.
methylnet-torque run_torque_job -c "methylnet-interpret produce_shapley_data -mth gradient -ssbs 30 -ns 300 -bs 100 -rc 4. -r 0 -rt 0 -cn Age -nf 4000 -c" -gpu -a "module load python/3-Anaconda && source activate methylnet_pro2" -q gpuq -t 4 -n 1
The methylnet-torque module deploys the embedded command given by the -c option to a PBS/Torque server so it can run on GPUs. Whether or not you have a GPU, you can still run locally by exposing the right options. In any case, I'm running on Torque using a GPU (-gpu -n 1) for 4 hours (-t 4), with the command specified by the -c option, submitted to the GPU queue (-q gpuq) with additional modules to load (-a).
The command itself:
methylnet-interpret produce_shapley_data -mth gradient -ssbs 30 -ns 300 -bs 100 -rc 4. -r 0 -rt 0 -cn Age -nf 4000 -c
We are going to create a ShapleyData object, which stores two things: the SHAP importance of each CpG for each individual, for a given class/multi-output regression target; and the top contributing CpGs, for quick access. The SHAP data is accessed by the ShapleyDataExplorer class, which has a number of subroutines for analyzing CpGs and navigating the ShapleyData, and by the BioInterpreter class, which feeds the top contributing CpGs across classes and individuals into downstream analyses such as LOLA and gometh. For this tutorial, we will focus on more primitive analyses and leave it as an exercise to feed this data into those analyses. The SHAP data is very high dimensional (number of samples x number of CpGs x number of classes), so the ShapleyData class stores a more relevant subset of this data (e.g., if a sample was predicted to be of class 1 and not class 2, only the importances toward class 1 are saved). Also, when training the SHAP model, both the training and test sets should be fed into it: the training set represents background information that the test data regresses on when forming the individual coefficients. See the SHAP and MethylNet papers for more information.
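The subsetting idea can be sketched in a few lines of numpy (hypothetical names and shapes, not the actual ShapleyData internals): for each sample, keep only the importances toward its predicted class, collapsing the samples x CpGs x classes array, then record the top contributing CpGs (analogous to -nf).

```python
import numpy as np

rng = np.random.default_rng(1)

n_samples, n_cpgs, n_classes = 4, 10, 3          # hypothetical sizes
shap_all = rng.normal(size=(n_samples, n_cpgs, n_classes))
predicted = np.array([0, 2, 1, 2])               # predicted class per sample

# Keep only each sample's importances toward its predicted class,
# collapsing (samples x cpgs x classes) to (samples x cpgs).
shap_kept = shap_all[np.arange(n_samples), :, predicted]

# Top contributing CpGs per sample by absolute importance (like -nf).
n_top = 3
top_idx = np.argsort(-np.abs(shap_kept), axis=1)[:, :n_top]
```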
With regard to the command above, -c enables CUDA and -cn is the pheno column to estimate (the sum of the CpG contributions should add up to the age, offset by the expected age across the training data). The SHAP coefficients are found using a gradient-based interpretation method, which uses the gradients calculated between the output and input of the model. These approximate the coefficients of the linear model we hope to create. The SHAP method estimates a coefficient's importance by sampling permutations of activated features to find how important, on average, the CpG is to the output of the model: how far off is the model's prediction if the CpG is excluded? So we run through the entire dataset, broken into batches of batch size (-bs), to estimate these coefficients on the test samples. We do this 300 times (-ns; number of Shapley samples) but batch these samples into jobs of 30 SHAP sample estimates (-ssbs) to make the problem less memory intensive. -rc is a filter that tries to remove poorly performing training and test MethylationArray samples. This method uses the trained model from the predictions folder. -r and -rt set the number of samples used as background noise and for testing/finding coefficients (for categorical/classification tasks, this is the number of test samples per class), respectively. -nf is the number of top contributing CpGs to store in the ShapleyData object, though all are stored as well. The other options can be found in the API documentation. There are two interpretation methods other than gradient (kernel and deep).
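For intuition about the gradient-based method, here is a toy numpy sketch (all names are hypothetical, and this is a deliberate simplification of what SHAP's GradientExplainer does): attribute (input - baseline) times the gradient, with the baseline taken as the mean of the background (training) data. For a purely linear model the gradient is constant, so the attributions recover the additivity property exactly.

```python
import numpy as np

rng = np.random.default_rng(2)

# A linear "model" f(x) = w . x + b; its gradient w.r.t. x is w everywhere.
w = rng.normal(size=8)
b = 0.5
f = lambda X: X @ w + b

background = rng.normal(size=(100, 8))   # training samples act as background
x = rng.normal(size=8)                   # one test sample

# Gradient-style attribution: (input - baseline) * gradient,
# with the baseline set to the mean of the background data.
baseline = background.mean(axis=0)
attributions = (x - baseline) * w        # gradient of a linear model is just w

# For a linear model this recovers SHAP's additivity exactly:
assert np.isclose(attributions.sum(), f(x[None])[0] - f(baseline[None])[0])
```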
At this point, we may want to study the CpG contributions to age prediction at a coarser level, so we bin the test MethylationArray and ShapleyData by age group (the new SHAP data is stored in interpretations/shapley_explanations/shapley_binned.p):
methylnet-interpret bin_regression_shaps -c Age -n 8
We now have 8 evenly spaced age bins to study, turning age into a categorical variable; the binned MethylationArray object now has a new column, "Age_binned", in its pheno matrix.
The CpG contributions can now be summed up by age group, and these age groups can be compared.
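The binning step amounts to something like the following numpy sketch (hypothetical ages and labels; the actual bin edges and label format MethylNet uses may differ): split the observed age range into 8 evenly spaced bins, as with -n 8.

```python
import numpy as np

rng = np.random.default_rng(3)
ages = rng.uniform(20, 90, size=50)      # hypothetical test-set ages

# 8 evenly spaced bins across the observed age range, like -n 8.
n_bins = 8
edges = np.linspace(ages.min(), ages.max(), n_bins + 1)

# digitize on the interior edges -> integer bin labels 0..7
age_binned = np.digitize(ages, edges[1:-1])

# Human-readable interval labels, one per bin.
labels = [f"({edges[i]:.1f}, {edges[i + 1]:.1f}]" for i in range(n_bins)]
```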
The command below will correlate the Shapley coefficients between the age groups and find a hierarchical clustering. -c all selects all of the age groups, though you can choose a subset by specifying this option once per age group. You can do this on the absolute values of the SHAP scores using the -abs option, which was not done here but can be useful if you want to test whether the CpGs related to older age (which should have strong positive contributions) are similar to those found for younger age (which should have negative contributions). -log and -hist will plot a log-scaled histogram of the distribution of the SHAP values in each group, which may be important in the future for deriving test statistics.
methylnet-interpret return_shap_values -log -c all -hist -s interpretations/shapley_explanations/shapley_binned.p -o interpretations/shap_results/ &
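The correlate-then-cluster step can be sketched with numpy/scipy (the group_shaps matrix and its shape are hypothetical, not what the CLI produces): correlate each age group's mean SHAP vector against the others, then hierarchically cluster the groups on correlation distance. Taking np.abs(group_shaps) first mirrors the -abs idea of asking whether the same CpGs matter regardless of direction.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform

rng = np.random.default_rng(4)

# Hypothetical mean SHAP vector per age group: 8 groups x 500 CpGs.
group_shaps = rng.normal(size=(8, 500))

# Correlate groups; use np.abs(group_shaps) instead for the -abs variant.
corr = np.corrcoef(group_shaps)

# Hierarchically cluster the groups on correlation distance (1 - r).
dist = squareform(1 - corr, checks=False)  # condensed distance vector
Z = linkage(dist, method="average")        # (n-1) x 4 linkage matrix
```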
The interpret_biology command will overlap the CpGs with genes via gometh or GSEA, with enhancers/promoters/etc. via LOLA, or with other custom-defined sets (Horvath clocks, Hannum, EpiTOC, IDOL DMRs), and will output the amount of overlap/enrichment/depletion to the console or to log files in the interpretations/ folder. Please check the API documentation for all of the options, and submit an issue if you are trying to figure out the framework or have suggestions. -ov runs an overlap test.
methylnet-interpret interpret_biology -ov -c all -s interpretations/shapley_explanations/shapley_binned.p -cgs hannum
Last but not least, you may want to evaluate and perform hierarchical clusterings of the fine-tuned embeddings or the original VAE embeddings to make sure the model's results conform with expected biology. If you are instead studying the TCGA data, you can use the following commands to perform this analysis and plot the embedding differences, thereby deriving the hierarchical clustering (this runs on the fine-tuned model):
methylnet-interpret interpret_embedding_classes -i ./predictions/vae_mlp_methyl_arr.pkl
pymethyl-visualize plot_heatmap -m distance -fs .6 -i results/class_embedding_differences.csv -o ./results/class_embedding_differences.png -x -y -c &
Note: Please bear in mind the effects of multicollinearity, which alters the coefficients of collinear features away from their true explanatory values; you may want to remove redundant features, down-weight them, or group them. To this end, if you group CpGs, you should be able to train the model end to end using these groupings as long as they are stored in a MethylationArray, and SHAP interpretations should still be computable. You may, however, have to manipulate the data afterward to obtain downstream biological analyses beyond those offered by this platform.
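One simple way to form such groupings is to collapse collinear CpGs into a single averaged feature per group before training. A minimal numpy sketch, with a hypothetical beta matrix and group assignment (MethylationArray construction itself is omitted):

```python
import numpy as np

rng = np.random.default_rng(5)

# Hypothetical beta matrix: 20 samples x 12 CpGs, with a CpG -> group map.
betas = rng.uniform(0, 1, size=(20, 12))
groups = np.repeat(np.arange(4), 3)      # 4 groups of 3 collinear CpGs each

# Collapse collinear CpGs into one feature per group by averaging;
# the resulting samples x groups matrix can back a coarser MethylationArray.
grouped = np.stack(
    [betas[:, groups == g].mean(axis=1) for g in range(4)], axis=1
)
```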