Skip to content

Commit

Permalink
different colormaps for rows/columns by default in decorated clusterm…
Browse files Browse the repository at this point in the history
…ap; add tests for decorated clustermap, related to #10
  • Loading branch information
afrendeiro committed May 20, 2020
1 parent 7aec22e commit 4298023
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 49 deletions.
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.DEFAULT_GOAL := install
.DEFAULT_GOAL := all

all: install clean
all: install clean test

move_models_out:
mv _models ../
Expand Down Expand Up @@ -31,6 +31,9 @@ install:
${MAKE} clean
${MAKE} move_models_in

test:
python -m pytest imc/

run:
python imcpipeline/runner.py \
--divvy slurm \
Expand Down
4 changes: 2 additions & 2 deletions imc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def setup_logger(level=logging.INFO):
logger = logging.getLogger("imcpipeline")
logger = logging.getLogger("imc")
logger.setLevel(level)

handler = logging.StreamHandler(sys.stdout)
Expand All @@ -28,7 +28,7 @@ def setup_logger(level=logging.INFO):
LOGGER = setup_logger()

# Setup joblib memory
JOBLIB_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".imcpipeline")
JOBLIB_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".imc")
MEMORY = Memory(location=JOBLIB_CACHE_DIR, verbose=0)

# Decorate seaborn clustermap
Expand Down
14 changes: 6 additions & 8 deletions imc/data_models/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,7 @@ def channel_summary(
res,
cbar_kws=dict(label=red_func.capitalize() + cbar_label),
row_colors=channel_mean,
row_colors_cmap="Greens",
col_colors=cell_density,
col_colors_cmap="BuPu",
metric="correlation",
xticklabels=True,
yticklabels=True,
Expand Down Expand Up @@ -441,9 +439,8 @@ def quantify_cells(
quantification = self.quantify_cell_intensity(samples=samples, rois=rois)
if not set_attribute:
return quantification
else:
self.quantification = quantification
return None
self.quantification = quantification
return None

def quantify_cell_intensity(
self,
Expand Down Expand Up @@ -948,10 +945,11 @@ def measure_adjacency(
)
sns.clustermap(mean_f, cmap="RdBu_r", center=0, robust=True)

__v = np.percentile(melted["value"].abs(), 95)
v = np.percentile(melted["value"].abs(), 95)
n, m = get_grid_dims(len(freqs))
fig, axes = plt.subplots(n, m, figsize=(m * 5, n * 5), sharex=True, sharey=True)
axes = axes.flatten()
i = -1
for i, (dfs, roi) in enumerate(zip(freqs, rois)):
axes[i].set_title(roi.name)
sns.heatmap(
Expand All @@ -963,8 +961,8 @@ def measure_adjacency(
square=True,
xticklabels=True,
yticklabels=True,
vmin=-__v,
vmax=__v,
vmin=-v,
vmax=v,
)
for axs in axes[i + 1 :]:
axs.axis("off")
Expand Down
54 changes: 43 additions & 11 deletions imc/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,41 @@


SEQUENCIAL_CMAPS = [
"Purples", "Greens", "Oranges", "Greys", "Reds", "Blues",
"YlOrBr", "YlOrRd", "OrRd", "PuRd", "RdPu", "BuPu",
"GnBu", "PuBu", "YlGnBu", "PuBuGn", "BuGn", "YlGn",
"binary", "gist_yarg", "gist_gray", "gray", "bone", "pink",
"spring", "summer", "autumn", "winter", "cool", "Wistia",
"hot", "afmhot", "gist_heat", "copper"]
"Purples",
"Greens",
"Oranges",
"Greys",
"Reds",
"Blues",
"YlOrBr",
"YlOrRd",
"OrRd",
"PuRd",
"RdPu",
"BuPu",
"GnBu",
"PuBu",
"YlGnBu",
"PuBuGn",
"BuGn",
"YlGn",
"binary",
"gist_yarg",
"gist_gray",
"gray",
"bone",
"pink",
"spring",
"summer",
"autumn",
"winter",
"cool",
"Wistia",
"hot",
"afmhot",
"gist_heat",
"copper",
]


def to_color_series(x: pd.Series, cmap: Optional[str] = "Greens") -> pd.Series:
Expand All @@ -32,13 +61,16 @@ def to_color_series(x: pd.Series, cmap: Optional[str] = "Greens") -> pd.Series:


def to_color_dataframe(
x: Union[pd.Series, pd.DataFrame], cmaps: Optional[Union[str, List[str]]] = None
x: Union[pd.Series, pd.DataFrame],
cmaps: Optional[Union[str, List[str]]] = None,
offset: int = 0,
) -> pd.DataFrame:
"""Map a numeric pandas DataFrame to RGB values."""
if isinstance(x, pd.Series):
x = x.to_frame()
if cmaps is None:
cmaps = [plt.get_cmap(cmap) for cmap in SEQUENCIAL_CMAPS[: x.shape[1]]]
# the offset is in order to get different colors for rows and columns by default
cmaps = [plt.get_cmap(cmap) for cmap in SEQUENCIAL_CMAPS[offset:]]
if isinstance(cmaps, str):
cmaps = [cmaps]
return pd.concat([to_color_series(x[col], cmap) for col, cmap in zip(x, cmaps)], axis=1)
Expand Down Expand Up @@ -129,7 +161,7 @@ def clustermap(*args, **kwargs):
if isinstance(kwargs[arg + "_colors"], (pd.DataFrame, pd.Series)):
_kwargs[arg + "s"] = kwargs[arg + "_colors"]
kwargs[arg + "_colors"] = to_color_dataframe(
kwargs[arg + "_colors"], cmaps[arg]
x=kwargs[arg + "_colors"], cmaps=cmaps[arg], offset=1 if arg == "row" else 0
)
grid = f(*args, **kwargs)
_add_colorbars(grid, **_kwargs, row_cmaps=cmaps["row"], col_cmaps=cmaps["col"])
Expand Down Expand Up @@ -239,8 +271,8 @@ def get_transparent_cmaps(n: int = 3, from_palette: Optional[str] = "colorblind"
def cell_labels_to_mask(mask: Array, labels: Union[Series, Dict]) -> Array:
"""Replaces integers in `mask` with values from the mapping in `labels`."""
res = np.zeros(mask.shape, dtype=int)
for __k, __v in labels.items():
res[mask == __k] = __v
for k, v in labels.items():
res[mask == k] = v
return res


Expand Down
21 changes: 5 additions & 16 deletions imc/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
Functions for high order operations.
"""

from __future__ import annotations
# ^^ this will fix the type annotatiton of not yet undefined classes
from __future__ import annotations # fix the type annotatiton of not yet undefined classes
from collections import Counter
import os
import re
Expand Down Expand Up @@ -42,12 +41,7 @@
from imc.exceptions import cast
from imc.types import DataFrame, Series, Array, Path, MultiIndexSeries
from imc.utils import read_image_from_file, estimate_noise, double_z_score
from imc.graphics import colorbar_decorator, get_grid_dims, rasterize_scanpy, add_legend

# import leidenalg as la


sns.clustermap = colorbar_decorator(sns.clustermap)
from imc.graphics import get_grid_dims, rasterize_scanpy, add_legend


matplotlib.rcParams["svg.fonttype"] = "none"
Expand Down Expand Up @@ -582,9 +576,7 @@ def single_cell_analysis(

kwargs = dict(
row_colors=row_means,
row_colors_cmaps=["Greens"],
col_colors=col_counts,
col_colors_cmaps=["Purples"],
metric="correlation",
robust=True,
xticklabels=True,
Expand Down Expand Up @@ -742,8 +734,8 @@ def derive_reference_cell_type_labels(
robust=True,
xticklabels=True,
yticklabels=True,
row_colors=cmeans, # row_colors_cmaps=['Greens'],
# col_colors=fractions, col_colors_cmaps=["Purples"]
row_colors=cmeans,
# col_colors=fractions,
)
opts = [
(mean_expr, "original", dict()),
Expand Down Expand Up @@ -775,9 +767,7 @@ def derive_reference_cell_type_labels(
cmap="RdBu_r",
cbar_kws=dict(label="Mean expression (Z-score)"),
row_colors=cmeans,
row_colors_cmaps=["Greens"],
col_colors=fractions,
col_colors_cmaps=["Purples"],
)
opts = [
(mean_expr_z_l, "labeled.both_z", dict()),
Expand All @@ -798,9 +788,7 @@ def derive_reference_cell_type_labels(
xticklabels=True,
yticklabels=True,
row_colors=fractions_l,
row_colors_cmaps=["Purples"],
col_colors=fractions_l,
col_colors_cmaps=["Purples"],
)
grid.savefig(
output_prefix + "mean_expression_per_cluster.labeled.both_z.correlation.svg", **FIG_KWS
Expand Down Expand Up @@ -1030,6 +1018,7 @@ def measure_cell_type_adjacency(
kws2 = dict(vmin=-v, vmax=v, cbar_kws=dict(label="Log odds interaction"))
sns.heatmap(norm_freqs, ax=axes[1], **kws, **kws2)
fig.savefig(output_prefix + "cluster_adjacency_graph.norm_over_random.heatmap.svg", **FIG_KWS)
del kws["square"]
grid = sns.clustermap(norm_freqs, **kws, **kws2)
grid.savefig(
output_prefix + "cluster_adjacency_graph.norm_over_random.clustermap.svg", **FIG_KWS
Expand Down
Empty file added imc/tests/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions imc/tests/test_graphics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env python

import numpy as np
import pandas as pd
import seaborn as sns
import pytest


class Test_colorbar_decorator:
def test_random_plot(self):
x = pd.DataFrame(np.random.random((10, 5))).rename_axis(index="rows", columns="columns")
g = sns.clustermap(x, row_colors=x.mean(1), col_colors=x.mean(0))
assert len(g.fig.get_axes()) == 8
37 changes: 27 additions & 10 deletions imc/types.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
#!/usr/bin/env python

"""
Types used in the library as defined here.
See https://docs.python.org/3/library/typing.html
for more information.
"""

from __future__ import annotations
import os
from typing import Union, TypeVar
import pathlib

import matplotlib
import pandas
import numpy
import matplotlib # type: ignore
import pandas # type: ignore
import numpy # type: ignore


class Path(pathlib.Path):
"""
A pathlib.Path child that allows concatenation with strings using the
addition operator
A pathlib.Path child class that allows concatenation with strings using the
addition operator.
In addition, it implements the ``startswith`` and ``endswith`` methods
just like in the base :obj:`str` type.
"""

_flavour = pathlib._windows_flavour if os.name == "nt" else pathlib._posix_flavour
_flavour = (
pathlib._windows_flavour # pylint: disable=W0212
if os.name == "nt"
else pathlib._posix_flavour # pylint: disable=W0212
)

def __add__(self, string: str) -> "Path":
return Path(str(self) + string)
Expand All @@ -32,11 +48,12 @@ def replace_(self, patt: str, repl: str) -> Path:
GenericType = TypeVar("GenericType")

# type aliasing (done with Union to distinguish from other declared variables)
Axis = Union[matplotlib.axis.Axis]
Figure = Union[matplotlib.figure.Figure]
Patch = Union[matplotlib.patches.Patch]
Array = Union[numpy.ndarray]
DataFrame = Union[pandas.DataFrame]
Series = Union[pandas.Series]
MultiIndexSeries = Union[pandas.Series]
DataFrame = Union[pandas.DataFrame]

Figure = Union[matplotlib.figure.Figure]
Axis = Union[matplotlib.axis.Axis]
Patch = Union[matplotlib.patches.Patch]
ColorMap = Union[matplotlib.colors.LinearSegmentedColormap]
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def parse_requirements(req_file):
classifiers=[
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Development Status :: 3 - Alpha",
"Typing :: Typed",
"License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
Expand Down

0 comments on commit 4298023

Please sign in to comment.