Skip to content

Commit

Permalink
Improve and add figures in cryodrgn analyze #219
Browse files Browse the repository at this point in the history
  • Loading branch information
zhonge committed Mar 26, 2023
1 parent 7f683e3 commit 521a459
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 36 deletions.
36 changes: 32 additions & 4 deletions cryodrgn/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,28 @@ def get_ind_for_cluster(
# PLOTTING


def _get_chimerax_colors(K: int) -> List:
colors = [
"#b2b2b2",
"#ffffb2",
"#b2ffff",
"#b2b2ff",
"#ffb2ff",
"#ffb2b2",
"#b2ffb2",
"#e5bf99",
"#99bfe5",
"#cccc99",
]
if K < 10:
colors = colors[0:K]
else:
colors *= K // 10
if K % 10:
colors += colors[0 : (K % 10)]
return colors


def _get_colors(K: int, cmap: Optional[str] = None) -> List:
if cmap is not None:
cm = plt.get_cmap(cmap)
Expand All @@ -244,16 +266,19 @@ def scatter_annotate(
labels: Optional[np.ndarray] = None,
alpha: Union[float, np.ndarray, None] = 0.1,
s: Union[float, np.ndarray, None] = 1,
colors: Optional[list] = None,
) -> Tuple[Figure, Axes]:
fig, ax = plt.subplots()
fig, ax = plt.subplots(figsize=(4, 4))
plt.scatter(x, y, alpha=alpha, s=s, rasterized=True)

# plot cluster centers
if centers_ind is not None:
assert centers is None
centers = np.array([[x[i], y[i]] for i in centers_ind])
if centers is not None:
plt.scatter(centers[:, 0], centers[:, 1], c="k")
if colors is None:
colors = "k"
plt.scatter(centers[:, 0], centers[:, 1], c=colors, edgecolor="black")
if annotate:
assert centers is not None
if labels is None:
Expand All @@ -271,15 +296,18 @@ def scatter_annotate_hex(
centers_ind: Optional[np.ndarray] = None,
annotate: bool = True,
labels: Optional[np.ndarray] = None,
colors: Optional[List] = None,
) -> sns.JointGrid:
g = sns.jointplot(x=x, y=y, kind="hex")
g = sns.jointplot(x=x, y=y, kind="hex", height=4)

# plot cluster centers
if centers_ind is not None:
assert centers is None
centers = np.array([[x[i], y[i]] for i in centers_ind])
if centers is not None:
g.ax_joint.scatter(centers[:, 0], centers[:, 1], color="k", edgecolor="grey")
if colors is None:
colors = "k"
g.ax_joint.scatter(centers[:, 0], centers[:, 1], c=colors, edgecolor="black")
if annotate:
assert centers is not None
if labels is None:
Expand Down
174 changes: 142 additions & 32 deletions cryodrgn/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,70 +133,180 @@ def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20):

# Make some plots
logger.info("Generating plots...")
plt.figure(1)
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], alpha=0.1, s=2)
g.set_axis_labels("PC1", "PC2")
plt.tight_layout()

def plt_pc_labels(x=0, y=1):
plt.xlabel(f"PC{x+1} ({pca.explained_variance_ratio_[x]:.2f})")
plt.ylabel(f"PC{y+1} ({pca.explained_variance_ratio_[y]:.2f})")

def plt_pc_labels_jointplot(g, x=0, y=1):
g.ax_joint.set_xlabel(f"PC{x+1} ({pca.explained_variance_ratio_[x]:.2f})")
g.ax_joint.set_ylabel(f"PC{y+1} ({pca.explained_variance_ratio_[y]:.2f})")

def plt_umap_labels():
plt.xticks([])
plt.yticks([])
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")

def plt_umap_labels_jointplot(g):
g.ax_joint.set_xlabel("UMAP1")
g.ax_joint.set_ylabel("UMAP2")

# PCA -- Style 1 -- Scatter
plt.figure(figsize=(4, 4))
plt.scatter(pc[:, 0], pc[:, 1], alpha=0.1, s=1, rasterized=True)
plt_pc_labels()
plt.savefig(f"{outdir}/z_pca.png")

plt.figure(2)
g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], kind="hex")
g.set_axis_labels("PC1", "PC2")
plt.tight_layout()
# PCA -- Style 2 -- Scatter, with marginals
g = sns.jointplot(pc[:, 0], pc[:, 1], alpha=0.1, s=1, rasterized=True, height=4)
plt_pc_labels_jointplot(g)
plt.savefig(f"{outdir}/z_pca_marginals.png")

# PCA -- Style 3 -- Hexbin
g = sns.jointplot(pc[:, 0], pc[:, 1], height=4, kind="hex")
plt_pc_labels_jointplot(g)
plt.savefig(f"{outdir}/z_pca_hexbin.png")

if umap_emb is not None:
plt.figure(3)
g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], alpha=0.1, s=2)
g.set_axis_labels("UMAP1", "UMAP2")
plt.tight_layout()
# Style 1 -- Scatter
plt.figure(figsize=(4, 4))
plt.scatter(umap_emb[:, 0], umap_emb[:, 1], alpha=0.1, s=1, rasterized=True)
plt_umap_labels()
plt.savefig(f"{outdir}/umap.png")

plt.figure(4)
g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], kind="hex")
g.set_axis_labels("UMAP1", "UMAP2")
plt.tight_layout()
# Style 2 -- Scatter with marginal distributions
g = sns.jointplot(
umap_emb[:, 0], umap_emb[:, 1], alpha=0.1, s=1, rasterized=True, height=4
)
plt_umap_labels_jointplot(g)
plt.savefig(f"{outdir}/umap_marginals.png")

# Style 3 -- Hexbin / heatmap
g = sns.jointplot(umap_emb[:, 0], umap_emb[:, 1], kind="hex", height=4)
plt_umap_labels_jointplot(g)
plt.savefig(f"{outdir}/umap_hexbin.png")

# Plot kmeans sample points
colors = analysis._get_chimerax_colors(K)
analysis.scatter_annotate(
pc[:, 0], pc[:, 1], centers_ind=centers_ind, annotate=True
pc[:, 0],
pc[:, 1],
centers_ind=centers_ind,
annotate=True,
colors=colors,
)
plt.xlabel("PC1")
plt.ylabel("PC2")
plt_pc_labels()
plt.savefig(f"{outdir}/kmeans{K}/z_pca.png")

g = analysis.scatter_annotate_hex(
pc[:, 0], pc[:, 1], centers_ind=centers_ind, annotate=True
pc[:, 0],
pc[:, 1],
centers_ind=centers_ind,
annotate=True,
colors=colors,
)
g.set_axis_labels("PC1", "PC2")
plt.tight_layout()
plt_pc_labels_jointplot(g)
plt.savefig(f"{outdir}/kmeans{K}/z_pca_hex.png")

if umap_emb is not None:
analysis.scatter_annotate(
umap_emb[:, 0], umap_emb[:, 1], centers_ind=centers_ind, annotate=True
umap_emb[:, 0],
umap_emb[:, 1],
centers_ind=centers_ind,
annotate=True,
colors=colors,
)
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt_umap_labels()
plt.savefig(f"{outdir}/kmeans{K}/umap.png")

g = analysis.scatter_annotate_hex(
umap_emb[:, 0], umap_emb[:, 1], centers_ind=centers_ind, annotate=True
umap_emb[:, 0],
umap_emb[:, 1],
centers_ind=centers_ind,
annotate=True,
colors=colors,
)
g.set_axis_labels("UMAP1", "UMAP2")
plt.tight_layout()
plt_umap_labels_jointplot(g)
plt.savefig(f"{outdir}/kmeans{K}/umap_hex.png")

# Plot PC trajectories
for i in range(num_pcs):
start, end = np.percentile(pc[:, i], (5, 95))
z_pc = analysis.get_pc_traj(pca, z.shape[1], 10, i + 1, start, end)
if umap_emb is not None:
# UMAP, colored by PCX
analysis.scatter_color(
umap_emb[:, 0], umap_emb[:, 1], pc[:, i], label=f"PC{i+1}"
umap_emb[:, 0],
umap_emb[:, 1],
pc[:, i],
label=f"PC{i+1}",
)
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt.tight_layout()
plt_umap_labels()
plt.savefig(f"{outdir}/pc{i+1}/umap.png")

# UMAP, with PC traversal
z_pc_on_data, pc_ind = analysis.get_nearest_point(z, z_pc)
dists = ((z_pc_on_data - z_pc) ** 2).sum(axis=1) ** 0.5
if np.any(dists > 2):
logger.warn(
f"Warning: PC{i+1} point locations in UMAP plot may be inaccurate"
)
plt.figure(figsize=(4, 4))
plt.scatter(
umap_emb[:, 0], umap_emb[:, 1], alpha=0.05, s=1, rasterized=True
)
plt.scatter(
umap_emb[pc_ind, 0],
umap_emb[pc_ind, 1],
c="cornflowerblue",
edgecolor="black",
)
plt_umap_labels()
plt.savefig(f"{outdir}/pc{i+1}/umap_traversal.png")

# UMAP, with PC traversal, connected
plt.figure(figsize=(4, 4))
plt.scatter(
umap_emb[:, 0], umap_emb[:, 1], alpha=0.05, s=1, rasterized=True
)
plt.plot(umap_emb[pc_ind, 0], umap_emb[pc_ind, 1], "--", c="k")
plt.scatter(
umap_emb[pc_ind, 0],
umap_emb[pc_ind, 1],
c="cornflowerblue",
edgecolor="black",
)
plt_umap_labels()
plt.savefig(f"{outdir}/pc{i+1}/umap_traversal_connected.png")

# 10 points, from 5th to 95th percentile of PC1 values
t = np.linspace(start, end, 10, endpoint=True)
plt.figure(figsize=(4, 4))
if i > 0 and i == num_pcs - 1:
plt.scatter(pc[:, i - 1], pc[:, i], alpha=0.1, s=1, rasterized=True)
plt.scatter(np.zeros(10), t, c="cornflowerblue", edgecolor="white")
plt_pc_labels(i - 1, i)
else:
plt.scatter(pc[:, i], pc[:, i + 1], alpha=0.1, s=1, rasterized=True)
plt.scatter(t, np.zeros(10), c="cornflowerblue", edgecolor="white")
plt_pc_labels(i, i + 1)
plt.savefig(f"{outdir}/pc{i+1}/pca_traversal.png")

if i > 0 and i == num_pcs - 1:
g = sns.jointplot(
pc[:, i - 1], pc[:, i], alpha=0.1, s=1, rasterized=True, height=4
)
g.ax_joint.scatter(np.zeros(10), t, c="cornflowerblue", edgecolor="white")
plt_pc_labels_jointplot(g, i - 1, i)
else:
g = sns.jointplot(
pc[:, i], pc[:, i + 1], alpha=0.1, s=1, rasterized=True, height=4
)
g.ax_joint.scatter(t, np.zeros(10), c="cornflowerblue", edgecolor="white")
plt_pc_labels_jointplot(g)
plt.savefig(f"{outdir}/pc{i+1}/pca_traversal_hex.png")


class VolumeGenerator:
"""Helper class to call analysis.gen_volumes"""
Expand Down

0 comments on commit 521a459

Please sign in to comment.