From f44a98828f42461dfe92e59af5bec99c6ad305eb Mon Sep 17 00:00:00 2001 From: Ellen Zhong Date: Wed, 22 Feb 2023 00:02:44 -0500 Subject: [PATCH] Add a jupyter notebook for regenerating and customizing figures #219 --- cryodrgn/commands/analyze.py | 10 + .../templates/cryoDRGN_figures_template.ipynb | 517 ++++++++++++++++++ 2 files changed, 527 insertions(+) create mode 100644 cryodrgn/templates/cryoDRGN_figures_template.ipynb diff --git a/cryodrgn/commands/analyze.py b/cryodrgn/commands/analyze.py index 6e492eeb..33587348 100644 --- a/cryodrgn/commands/analyze.py +++ b/cryodrgn/commands/analyze.py @@ -391,6 +391,16 @@ def main(args): logger.info(f"{out_ipynb} already exists. Skipping") logger.info(out_ipynb) + # copy over template if file doesn't exist + out_ipynb = f"{outdir}/cryoDRGN_figures.ipynb" + if not os.path.exists(out_ipynb): + logger.info("Creating jupyter notebook...") + ipynb = f"{cryodrgn._ROOT}/templates/cryoDRGN_figures_template.ipynb" + shutil.copyfile(ipynb, out_ipynb) + else: + logger.info(f"{out_ipynb} already exists. Skipping") + logger.info(out_ipynb) + logger.info(f"Finished in {dt.now()-t1}") diff --git a/cryodrgn/templates/cryoDRGN_figures_template.ipynb b/cryodrgn/templates/cryoDRGN_figures_template.ipynb new file mode 100644 index 00000000..eefb2a49 --- /dev/null +++ b/cryodrgn/templates/cryoDRGN_figures_template.ipynb @@ -0,0 +1,517 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "12d6bdad", + "metadata": {}, + "source": [ + "# CryoDRGN visualization and figures\n", + "\n", + "This jupyter notebook provides a template for regenerating and customizing cryoDRGN visualizations and figures" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b98dcf23", + "metadata": {}, + "outputs": [], + "source": [ + "from cryodrgn import analysis\n", + "from cryodrgn import utils\n", + "\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "markdown", + "id": "47ec8987", + "metadata": {}, + "source": [ + "### Load results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "829f0e1d", + "metadata": {}, + "outputs": [], + "source": [ + "# Specify the workdir and the epoch number (0-based index) to analyze\n", + "WORKDIR = '..' \n", + "EPOCH = 19 # CHANGE ME" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "551cc751", + "metadata": {}, + "outputs": [], + "source": [ + "# Load z\n", + "z = utils.load_pkl(f'{WORKDIR}/z.{EPOCH}.pkl')\n", + "umap = utils.load_pkl(f'{WORKDIR}/analyze.{EPOCH}/umap.pkl')" + ] + }, + { + "cell_type": "markdown", + "id": "9cce7848", + "metadata": {}, + "source": [ + "# Plot PCA\n", + "\n", + "Visualize the latent space by principal component analysis (PCA)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81518fa1", + "metadata": {}, + "outputs": [], + "source": [ + "pc, pca = analysis.run_pca(z)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a441bce1", + "metadata": {}, + "outputs": [], + "source": [ + "# Style 1 -- Scatter\n", + "\n", + "plt.figure(figsize=(4,4))\n", + "plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1,rasterized=True)\n", + "plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))\n", + "#plt.savefig('pca_style1.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7fe803e6", + "metadata": {}, + "outputs": [], + "source": [ + "# Style 2 -- Scatter with marginals\n", + "\n", + "g = sns.jointplot(x=pc[:,0], y=pc[:,1], alpha=.1, s=1,rasterized=True, height=4)\n", + "g.ax_joint.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "g.ax_joint.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))\n", + "#plt.savefig('pca_style2.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d27280eb", + "metadata": {}, + "outputs": [], + "source": [ + "# Style 3 -- Hexbin/heatmap\n", + "\n", + "g = sns.jointplot(x=pc[:,0], y=pc[:,1], height=4, kind='hex')\n", + "plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))\n", + "#plt.savefig('pca_style3.pdf')" + ] + }, + { + "cell_type": "markdown", + "id": "cea11ef1", + "metadata": {}, + "source": [ + "# Plot UMAP\n", + "\n", + "Visualize the latent space by Uniform Manifold Approximation and Projection (UMAP). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "feb9a1d1", + "metadata": {}, + "outputs": [], + "source": [ + "# Style 1 -- Scatter\n", + "\n", + "plt.figure(figsize=(4,4))\n", + "plt.scatter(umap[:,0], umap[:,1], alpha=.1, s=1,rasterized=True)\n", + "plt.xticks([])\n", + "plt.yticks([])\n", + "plt.xlabel('UMAP1')\n", + "plt.ylabel('UMAP2')\n", + "#plt.savefig('umap_style1.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e5dd8a8", + "metadata": {}, + "outputs": [], + "source": [ + "# Style 2 -- Scatter with marginal distributions\n", + "\n", + "g = sns.jointplot(x=umap[:,0], y=umap[:,1], alpha=.1, s=1,rasterized=True, height=4)\n", + "g.ax_joint.set_xlabel('UMAP1')\n", + "g.ax_joint.set_ylabel('UMAP2')\n", + "#plt.savefig('umap_style2.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc4def2a", + "metadata": {}, + "outputs": [], + "source": [ + "# Style 3 -- Hexbin / heatmap\n", + "\n", + "g = sns.jointplot(x=umap[:,0], y=umap[:,1], kind='hex',height=4)\n", + "g.ax_joint.set_xlabel('UMAP1')\n", + "g.ax_joint.set_ylabel('UMAP2')\n", + "#plt.savefig('umap_style3.pdf')" + ] + }, + { + "cell_type": "markdown", + "id": "7cd98f8e", + "metadata": {}, + "source": [ + "# Plot kmeans samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e992fca8", + "metadata": {}, + "outputs": [], + "source": [ + "# Load points\n", + "K = 20 # CHANGE ME\n", + "kmeans_ind = np.loadtxt(f'{WORKDIR}/analyze.{EPOCH}/kmeans{K}/centers_ind.txt', dtype=int)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f176e1f", + "metadata": {}, + "outputs": [], + "source": [ + "# Default chimerax color map\n", + "colors = analysis._get_chimerax_colors(K)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f171d81", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot kmeans on PCA\n", + "\n", + "f, ax = plt.subplots(figsize=(4,4))\n", + "plt.scatter(pc[:,0], pc[:,1], alpha=.05, s=1,rasterized=True)\n", + "plt.scatter(pc[kmeans_ind,0], pc[kmeans_ind,1], c=colors,edgecolor='black')\n", + "labels = np.arange(len(kmeans_ind))\n", + "centers = pc[kmeans_ind]\n", + "for i in labels:\n", + " ax.annotate(str(i), centers[i, 0:2] + np.array([0.1, 0.1]))\n", + "plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))\n", + "#plt.savefig('pca_w_kmeans.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4071fd5", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot kmeans on UMAP\n", + "\n", + "f, ax = plt.subplots(figsize=(4,4))\n", + "plt.scatter(umap[:,0], umap[:,1], alpha=.05, s=1,rasterized=True)\n", + "plt.scatter(umap[kmeans_ind,0], umap[kmeans_ind,1], c=colors,edgecolor='black')\n", + "labels = np.arange(len(kmeans_ind))\n", + "centers = umap[kmeans_ind]\n", + "for i in labels:\n", + " ax.annotate(str(i), centers[i, 0:2] + np.array([0.1, 0.1]))\n", + "plt.xticks([])\n", + "plt.yticks([])\n", + "plt.xlabel('UMAP1')\n", + "plt.ylabel('UMAP2')\n", + "#plt.savefig('umap_w_kmeans.pdf')" + ] + }, + { + "cell_type": "markdown", + "id": "5a3d7052", + "metadata": {}, + "source": [ + "### Plot PC traversals\n", + "\n", + "Visualize the PC axes traversals. By default, plot the first two PCs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0576ccf", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(4,4))\n", + "plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1,rasterized=True)\n", + "\n", + "# 10 points, from 5th to 95th percentile of PC1 values\n", + "t = np.linspace(np.percentile(pc[:,0],5),np.percentile(pc[:,0],95), 10, endpoint=True)\n", + "plt.scatter(t,np.zeros(10),c='cornflowerblue',edgecolor='white')\n", + "\n", + "plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de260d19", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(4,4))\n", + "plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=1,rasterized=True)\n", + "\n", + "# 10 points, from 5th to 95th percentile of PC2 values\n", + "t = np.linspace(np.percentile(pc[:,1],5),np.percentile(pc[:,1],95),10,endpoint=True)\n", + "plt.scatter(np.zeros(10),t,c='cornflowerblue',edgecolor='white')\n", + "\n", + "plt.xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "plt.ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aefaacef", + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.jointplot(x=pc[:,0], y=pc[:,1], alpha=.1, s=1,rasterized=True, height=4)\n", + "\n", + "t = np.linspace(np.percentile(pc[:,0],5),np.percentile(pc[:,0],95),10,endpoint=True)\n", + "g.ax_joint.scatter(x=t,y=np.zeros(10),c='cornflowerblue',edgecolor='white')\n", + "\n", + "g.ax_joint.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "g.ax_joint.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))\n", + "#plt.savefig('pca_pc1_traversal.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3785ef98", + "metadata": {}, + "outputs": [], + "source": [ + "g = sns.jointplot(x=pc[:,0], y=pc[:,1], alpha=.1, s=1,rasterized=True, height=4)\n", + "t = np.linspace(np.percentile(pc[:,1],5),np.percentile(pc[:,1],95),10,endpoint=True)\n", + "g.ax_joint.scatter(x=np.zeros(10),y=t,c='cornflowerblue',edgecolor='white')\n", + "g.ax_joint.set_xlabel('PC1 ({:.2f})'.format(pca.explained_variance_ratio_[0]))\n", + "g.ax_joint.set_ylabel('PC2 ({:.2f})'.format(pca.explained_variance_ratio_[1]))\n", + "#plt.savefig('pca_pc2_traversal.pdf')" + ] + }, + { + "cell_type": "markdown", + "id": "2e40a6ef", + "metadata": {}, + "source": [ + "### Plot UMAP \n", + "\n", + "Plot the PC axes traversal paths in the UMAP visualization of the latent space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a62cc0e", + "metadata": {}, + "outputs": [], + "source": [ + "z_pc1 = np.loadtxt('pc1/z_values.txt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5298124f", + "metadata": {}, + "outputs": [], + "source": [ + "z_pc1_on_data, pc1_ind = analysis.get_nearest_point(z, z_pc1)\n", + "((z_pc1_on_data - z_pc1)**2).sum(axis=1)**.5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ae119ad", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(4,4))\n", + "plt.scatter(umap[:,0], umap[:,1], alpha=.05, s=1,rasterized=True)\n", + "plt.scatter(umap[pc1_ind,0], umap[pc1_ind,1], c='cornflowerblue',edgecolor='black')\n", + "\n", + "plt.xticks([])\n", + "plt.yticks([])\n", + "plt.xlabel('UMAP1')\n", + "plt.ylabel('UMAP2')\n", + "#plt.savefig('umap_pc1_traversal.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08d1a2a3", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(4,4))\n", + "plt.scatter(umap[:,0], umap[:,1], alpha=.05, s=1,rasterized=True)\n", + "plt.plot(umap[pc1_ind,0], umap[pc1_ind,1], '--',c='k')\n", + "plt.scatter(umap[pc1_ind,0], umap[pc1_ind,1], c='cornflowerblue',edgecolor='black')\n", + "\n", + "plt.xticks([])\n", + "plt.yticks([])\n", + "plt.xlabel('UMAP1')\n", + "plt.ylabel('UMAP2')\n", + "#plt.savefig('umap_pc1_traversal_v2.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0571e10", + "metadata": {}, + "outputs": [], + "source": [ + "z_pc2 = np.loadtxt('pc2/z_values.txt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0eb0c9a0", + "metadata": {}, + "outputs": [], + "source": [ + "z_pc2_on_data, pc2_ind = analysis.get_nearest_point(z, z_pc2)\n", + "((z_pc2_on_data - z_pc2)**2).sum(axis=1)**.5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b355769f", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(4,4))\n", + "plt.scatter(umap[:,0], umap[:,1], alpha=.05, s=1,rasterized=True)\n", + "plt.scatter(umap[pc2_ind,0], umap[pc2_ind,1], c='cornflowerblue',edgecolor='black')\n", + "\n", + "plt.xticks([])\n", + "plt.yticks([])\n", + "plt.xlabel('UMAP1')\n", + "plt.ylabel('UMAP2')\n", + "#plt.savefig('umap_pc2_traversal.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32a455e8", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(4,4))\n", + "plt.scatter(umap[:,0], umap[:,1], alpha=.05, s=1,rasterized=True)\n", + "plt.plot(umap[pc2_ind,0], umap[pc2_ind,1], '--',c='k')\n", + "plt.scatter(umap[pc2_ind,0], umap[pc2_ind,1], c='cornflowerblue',edgecolor='black')\n", + "\n", + "plt.xticks([])\n", + "plt.yticks([])\n", + "plt.xlabel('UMAP1')\n", + "plt.ylabel('UMAP2')\n", + "#plt.savefig('umap_pc2_traversal_v2.pdf')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0856141", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d1a1181", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b24f0d33", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dea70264", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}