diff --git a/qcodes/plots/qcmatplotlib.py b/qcodes/plots/qcmatplotlib.py index 1d5305b2a7c3..2dd93448b88c 100644 --- a/qcodes/plots/qcmatplotlib.py +++ b/qcodes/plots/qcmatplotlib.py @@ -3,11 +3,13 @@ using the nbagg backend and matplotlib """ from collections import Mapping +from functools import partial import matplotlib.pyplot as plt from matplotlib.transforms import Bbox import numpy as np from numpy.ma import masked_invalid, getmask +from collections import Sequence from .base import BasePlot @@ -18,10 +20,13 @@ class MatPlot(BasePlot): in the constructor, other traces can be added with MatPlot.add() Args: - *args: shortcut to provide the x/y/z data. See BasePlot.add + *args: Sequence of data to plot. Each element will have its own subplot. + An element can be a single array, or a sequence of arrays. In the + latter case, all arrays will be plotted in the same subplot. - figsize (Tuple[Float, Float]): (width, height) tuple in inches to pass to plt.figure - default (8, 5) + figsize (Tuple[Float, Float]): (width, height) tuple in inches to pass + to plt.figure. If not provided, figsize is determined from + subplots shape interval: period in seconds between update checks @@ -35,35 +40,80 @@ class MatPlot(BasePlot): **kwargs: passed along to MatPlot.add() to add the first data trace """ + + # Maximum default number of subplot columns. Used to determine shape of + # subplots when not explicitly provided + max_subplot_columns = 3 + def __init__(self, *args, figsize=None, interval=1, subplots=None, num=None, **kwargs): - super().__init__(interval) + if subplots is None: + # Subplots is equal to number of args, or 1 if no args provided + subplots = max(len(args), 1) + self._init_plot(subplots, figsize, num=num) - if args or kwargs: - self.add(*args, **kwargs) + + # Add data to plot if passed in args, kwargs are passed to all subplots + for k, arg in enumerate(args): + if isinstance(arg, Sequence): + # Arg consists of multiple elements, add all to same subplot + for subarg in arg: + self[k].add(subarg, **kwargs) + else: + # Arg is single element, add to subplot + self[k].add(arg, **kwargs) - def _init_plot(self, subplots=None, figsize=None, num=None): - if figsize is None: - figsize = (8, 5) + self.tight_layout() - if subplots is None: - subplots = (1, 1) + def __getitem__(self, key): + """ + Subplots can be accessed via indices. + Args: + key: subplot idx + Returns: + Subplot with idx key + """ + return self.subplots[key] + + def _init_plot(self, subplots=None, figsize=None, num=None): if isinstance(subplots, Mapping): + if figsize is None: + figsize = (6, 4) self.fig, self.subplots = plt.subplots(figsize=figsize, num=num, **subplots, squeeze=False) else: + # Format subplots as tuple (nrows, ncols) + if isinstance(subplots, int): + # self.max_subplot_columns defines the limit on how many + # subplots can be in one row. Adjust subplot rows and columns + # accordingly + nrows = int(np.ceil(subplots / self.max_subplot_columns)) + ncols = min(subplots, self.max_subplot_columns) + subplots = (nrows, ncols) + + if figsize is None: + # Adjust figsize depending on rows and columns in subplots + figsize = self.default_figsize(subplots) + self.fig, self.subplots = plt.subplots(*subplots, num=num, - figsize=figsize, squeeze=False) + figsize=figsize, + squeeze=False) - # squeeze=False ensures that subplots is always a 2D array independent of the number - # of subplots. + # squeeze=False ensures that subplots is always a 2D array independent + # of the number of subplots. # However the qcodes api assumes that subplots is always a 1D array # so flatten here self.subplots = self.subplots.flatten() + for k, subplot in enumerate(self.subplots): + # Include `add` method to subplots, making it easier to add data to + # subplots. Note that subplot kwarg is 1-based, to adhere to + # Matplotlib standards + subplot.add = partial(self.add, subplot=k+1) + self.title = self.fig.suptitle('') def clear(self, subplots=None, figsize=None): @@ -75,28 +125,35 @@ def clear(self, subplots=None, figsize=None): self.fig.clf() self._init_plot(subplots, figsize, num=self.fig.number) - def add_to_plot(self, **kwargs): + def add_to_plot(self, use_offset=False, **kwargs): """ adds one trace to this MatPlot. - kwargs: with the following exceptions (mostly the data!), these are - passed directly to the matplotlib plotting routine. - - `subplot`: the 1-based axes number to append to (default 1) - - if kwargs include `z`, we will draw a heatmap (ax.pcolormesh): - `x`, `y`, and `z` are passed as positional args to pcolormesh - - without `z` we draw a scatter/lines plot (ax.plot): - `x`, `y`, and `fmt` (if present) are passed as positional args + Args: + use_offset (bool, Optional): Whether or not ticks can have an offset + + kwargs: with the following exceptions (mostly the data!), these are + passed directly to the matplotlib plotting routine. + `subplot`: the 1-based axes number to append to (default 1) + if kwargs include `z`, we will draw a heatmap (ax.pcolormesh): + `x`, `y`, and `z` are passed as positional args to + pcolormesh + without `z` we draw a scatter/lines plot (ax.plot): + `x`, `y`, and `fmt` (if present) are passed as positional + args """ # TODO some way to specify overlaid axes? - ax = self._get_axes(kwargs) + # Note that there is a conversion from subplot kwarg, which is + # 1-based, to subplot idx, which is 0-based. + ax = self[kwargs.get('subplot', 1) - 1] if 'z' in kwargs: plot_object = self._draw_pcolormesh(ax, **kwargs) else: plot_object = self._draw_plot(ax, **kwargs) + # Specify if axes ticks can have offset or not + ax.ticklabel_format(useOffset=use_offset) + self._update_labels(ax, kwargs) prev_default_title = self.get_default_title() @@ -109,9 +166,6 @@ def add_to_plot(self, **kwargs): # in case the user has updated title, don't change it anymore self.title.set_text(self.get_default_title()) - def _get_axes(self, config): - return self.subplots[config.get('subplot', 1) - 1] - def _update_labels(self, ax, config): for axletter in ("x", "y"): if axletter+'label' in config: @@ -146,6 +200,21 @@ def _update_labels(self, ax, config): axsetter = getattr(ax, "set_{}label".format(axletter)) axsetter("{} ({})".format(label, unit)) + @staticmethod + def default_figsize(subplots): + """ + Provides default figsize for given subplots. + Args: + subplots (Tuple[Int, Int]): shape (nrows, ncols) of subplots + + Returns: + Figsize (Tuple[Float, Float])): (width, height) of default figsize + for given subplot shape + """ + if not isinstance(subplots, tuple): + raise TypeError('Subplots {} must be a tuple'.format(subplots)) + return (min(3 + 3 * subplots[1], 12), 1 + 3 * subplots[0]) + def update_plot(self): """ update the plot. The DataSets themselves have already been updated @@ -164,7 +233,7 @@ def update_plot(self): if plot_object: plot_object.remove() - ax = self._get_axes(config) + ax = self[config.get('subplot', 1) - 1] plot_object = self._draw_pcolormesh(ax, **config) trace['plot_object'] = plot_object @@ -202,11 +271,12 @@ def _draw_plot(self, ax, y, x=None, fmt=None, subplot=1, yunit=None, zunit=None, **kwargs): - # NOTE(alexj)stripping out subplot because which subplot we're in is already - # described by ax, and it's not a kwarg to matplotlib's ax.plot. But I - # didn't want to strip it out of kwargs earlier because it should stay - # part of trace['config']. + # NOTE(alexj)stripping out subplot because which subplot we're in is + # already described by ax, and it's not a kwarg to matplotlib's ax.plot. + # But I didn't want to strip it out of kwargs earlier because it should + # stay part of trace['config']. args = [arg for arg in [x, y, fmt] if arg is not None] + line, = ax.plot(*args, **kwargs) return line @@ -217,21 +287,71 @@ def _draw_pcolormesh(self, ax, z, x=None, y=None, subplot=1, xunit=None, yunit=None, zunit=None, + nticks=None, **kwargs): # NOTE(alexj)stripping out subplot because which subplot we're in is already # described by ax, and it's not a kwarg to matplotlib's ax.plot. But I # didn't want to strip it out of kwargs earlier because it should stay # part of trace['config']. - args = [masked_invalid(arg) for arg in [x, y, z] - if arg is not None] - - for arg in args: - if np.all(getmask(arg)): - # if any entire array is masked, don't draw at all - # there's nothing to draw, and anyway it throws a warning - return False + args_masked = [masked_invalid(arg) for arg in [x, y, z] + if arg is not None] + + if np.any([np.all(getmask(arg)) for arg in args_masked]): + # if the z array is masked, don't draw at all + # there's nothing to draw, and anyway it throws a warning + # pcolormesh does not accept masked x and y axes, so we do not need + # to check for them. + return False + + if x is not None and y is not None: + # If x and y are provided, modify the arrays such that they + # correspond to grid corners instead of grid centers. + # This is to ensure that pcolormesh centers correctly and + # does not ignore edge points. + args = [] + for k, arr in enumerate(args_masked[:-1]): + # If a two-dimensional array is provided, only consider the + # first row/column, depending on the axis + if arr.ndim > 1: + arr = arr[0] if k == 0 else arr[:,0] + + if np.ma.is_masked(arr[1]): + # Only the first element is not nan, in this case pad with + # a value, and separate their values by 1 + arr_pad = np.pad(arr, (1, 0), mode='symmetric') + arr_pad[:2] += [-0.5, 0.5] + else: + # Add padding on both sides equal to endpoints + arr_pad = np.pad(arr, (1, 1), mode='symmetric') + # Add differences to edgepoints (may be nan) + arr_pad[0] += arr_pad[1] - arr_pad[2] + arr_pad[-1] += arr_pad[-2] - arr_pad[-3] + + diff = np.ma.diff(arr_pad) / 2 + # Insert value at beginning and end of diff to ensure same + # length + diff = np.insert(diff, 0, diff[0]) + + arr_pad += diff + # Ignore final value + arr_pad = arr_pad[:-1] + args.append(masked_invalid(arr_pad)) + args.append(args_masked[-1]) + else: + # Only the masked value of z is used as a mask + args = args_masked[-1:] + pc = ax.pcolormesh(*args, **kwargs) + # Set x and y limits if arrays are provided + if x is not None and y is not None: + ax.set_xlim(np.nanmin(args[0]), np.nanmax(args[0])) + ax.set_ylim(np.nanmin(args[1]), np.nanmax(args[1])) + + # Specify preferred number of ticks with labels + if nticks and ax.get_xscale() != 'log' and ax.get_yscale != 'log': + ax.locator_params(nbins=nticks) + if getattr(ax, 'qcodes_colorbar', None): # update_normal doesn't seem to work... ax.qcodes_colorbar.update_bruteforce(pc) @@ -255,6 +375,11 @@ def _draw_pcolormesh(self, ax, z, x=None, y=None, subplot=1, label = "{} ({})".format(zlabel, zunit) ax.qcodes_colorbar.set_label(label) + # Scale colors if z has elements + cmin = np.nanmin(args_masked[-1]) + cmax = np.nanmax(args_masked[-1]) + ax.qcodes_colorbar.set_clim(cmin, cmax) + return pc def save(self, filename=None): @@ -269,3 +394,10 @@ def save(self, filename=None): default = "{}.png".format(self.get_default_title()) filename = filename or default self.fig.savefig(filename) + + def tight_layout(self): + """ + Perform a tight layout on the figure. A bit of additional spacing at + the top is also added for the title. + """ + self.fig.tight_layout(rect=[0, 0, 1, 0.95]) \ No newline at end of file