
pass pair plot scatter_kwargs to plt.scatter(), not plt.plot() #1889

Closed
grahamgower opened this issue Oct 15, 2021 · 14 comments

Comments

@grahamgower

grahamgower commented Oct 15, 2021

I want to do a pair plot with a colour map, where the colour of the points indicates the posterior probability. I expected to just be able to pass scatter_kwargs={"c": probs} to arviz.plot_pair(), but it seems the matplotlib backend uses plt.plot() instead of plt.scatter(), and plt.plot() doesn't seem to support this (I think?). Also, it's kinda confusing that this is named scatter_kwargs, when it doesn't use the scatter function.

@tattvams
Contributor

Hey, can you please elaborate on this problem and on what exactly you are trying to achieve? Can you post any screenshots of the problem?

@OriolAbril
Member

OriolAbril commented Feb 27, 2022

iiuc, the main goal is to be able to use the c argument of scatter, which is similar to color in plot but a bit more flexible.

That would mean updating the docs at

Additional keywords passed to :meth:`matplotlib.axes.Axes.plot` when using scatter kind

(the dealiasing actually assumes this is passed down to scatter, so bonus points for fixing a bug we didn't know we had 😅), and then passing it to ax.scatter instead of ax.plot in https://github.com/arviz-devs/arviz/blob/main/arviz/plots/backends/matplotlib/pairplot.py#L177 and https://github.com/arviz-devs/arviz/blob/main/arviz/plots/backends/matplotlib/pairplot.py#L177
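A minimal sketch of the change described above, using a hypothetical helper `draw_pair_points` as a stand-in for the relevant lines of the matplotlib pairplot backend (this is not ArviZ's actual code, and the default `marker` and `s` values are assumptions). The only point it illustrates is that once the user's scatter_kwargs reach `ax.scatter`, a per-point `c` array is mapped through a colormap.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def draw_pair_points(ax, x, y, scatter_kwargs=None):
    """Hypothetical stand-in for the pair plot's marker-drawing step.

    User-supplied scatter_kwargs are forwarded to ax.scatter (not ax.plot),
    so scatter-only parameters such as a per-point ``c`` array work.
    """
    kwargs = {} if scatter_kwargs is None else dict(scatter_kwargs)
    # Assumed defaults; anything the user passes takes precedence.
    kwargs.setdefault("marker", ".")
    kwargs.setdefault("s", 10)
    return ax.scatter(x, y, **kwargs)


rng = np.random.default_rng(1234)
x, y = rng.uniform(0, 10, size=(2, 100))
probs = x + y  # stands in for the posterior probabilities in the original request

fig, ax = plt.subplots()
points = draw_pair_points(ax, x, y, scatter_kwargs={"c": probs, "cmap": "viridis"})
fig.colorbar(points, ax=ax)
```

Because the defaults are set with `setdefault`, a user passing `marker="o"` or a different `s` still gets their own values.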

Side note, if you are not comfortable with git merging and rebasing, it's probably a good idea to wait for #1985 to be merged before starting to work on this.

@tattvams
Contributor

tattvams commented Mar 1, 2022

Hey, so I looked into this a bit, and the c argument in scatter_kwargs seems to work, for example in this code:

import matplotlib.pyplot as plt

import arviz as az

az.style.use("arviz-darkgrid")

centered = az.load_arviz_data("centered_eight")

coords = {"school": ["Choate", "Deerfield"]}
colors = [0.3,0.8,0.7]
az.plot_pair(
    centered, var_names=["theta", "mu", "tau"], coords=coords, divergences=True, textsize=22, scatter_kwargs={"c": colors}
)
plt.show()

[image: the resulting pair plot]

And the reason the dealiasing assumes these are passed down to scatter is that scatter_kwargs is dealiased with kind "scatter", which I am assuming should be the same (correct me if I am wrong and if I should change it). I did commit some changes for ax.plot -> ax.scatter. Let me know if I misunderstood something and should change it.

@OriolAbril
Member

Does scatter_kwargs with a c key work after your changes, or does it already work with the current code?

And the reason why dealiasing is assuming that this is passed down to scatter is because scatter_kwargs returns kind scatter
which I am assuming should be the same (correct me if I am wrong and if I should change it).

The dealiasing should not be changed. It will be correct once the issue is fixed. Right now it is incorrect because it uses "scatter" as the second argument instead of "plot", while the arguments are actually (and wrongly) passed to ax.plot.

@tattvams
Contributor

tattvams commented Mar 2, 2022

It worked without my changes. Also, isn't the dealiasing using the "scatter" keyword only for scatter_kwargs? So is there no other keyword like plot_kwargs?

@grahamgower
Author

grahamgower commented Mar 2, 2022

Hi @tattvams,

It doesn't currently work. Let me illustrate. The following code generates random x and y coordinates, and also some z value based on the x and y coordinates. Matplotlib's scatter() function can plot the value of this z variable using a colour map by passing it in the c= parameter. In the plot, each dot has a different colour, according to the specific z value.

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1234)
x, y = rng.uniform(0, 10, size=(2, 100))
z = x + y
plt.scatter(x, y, c=z)
plt.show()

[Figure_1: the resulting scatter plot]

If you change plt.scatter to plt.plot instead, you will receive the following error message.

Traceback (most recent call last):
  File "/tmp/scatter.py", line 7, in <module>
    plt.plot(x, y, c=z)
  File "/usr/lib/python3.10/site-packages/matplotlib/pyplot.py", line 2757, in plot
    return gca().plot(
  File "/usr/lib/python3.10/site-packages/matplotlib/axes/_axes.py", line 1632, in plot
    lines = [*self._get_lines(*args, data=data, **kwargs)]
  File "/usr/lib/python3.10/site-packages/matplotlib/axes/_base.py", line 312, in __call__
    yield from self._plot_args(this, kwargs)
  File "/usr/lib/python3.10/site-packages/matplotlib/axes/_base.py", line 538, in _plot_args
    return [l[0] for l in result]
  File "/usr/lib/python3.10/site-packages/matplotlib/axes/_base.py", line 538, in <listcomp>
    return [l[0] for l in result]
  File "/usr/lib/python3.10/site-packages/matplotlib/axes/_base.py", line 531, in <genexpr>
    result = (make_artist(x[:, j % ncx], y[:, j % ncy], kw,
  File "/usr/lib/python3.10/site-packages/matplotlib/axes/_base.py", line 351, in _makeline
    seg = mlines.Line2D(x, y, **kw)
  File "/usr/lib/python3.10/site-packages/matplotlib/lines.py", line 370, in __init__
    self.set_color(color)
  File "/usr/lib/python3.10/site-packages/matplotlib/lines.py", line 1030, in set_color
    mcolors._check_color_like(color=color)
  File "/usr/lib/python3.10/site-packages/matplotlib/colors.py", line 130, in _check_color_like
    raise ValueError(f"{v!r} is not a valid value for {k}")
ValueError: array([12.74539857,  9.99356751, 12.25003663,  4.36538875,  8.14602254,
        8.78661633,  9.2902787 ,  5.66688708, 15.8612584 ,  5.79305495,
        7.18222539, 15.3973602 , 11.43216566, 18.12592686,  8.76898665,
       12.25352763, 11.16060829, 10.30860385,  2.99619853, 18.21754524,
        7.88966015,  8.59233907, 13.63997113,  7.03851005,  6.16538883,
       10.80263408,  4.82644868,  9.25285528,  6.94186349,  9.78244537,
       15.27522036, 10.16853085, 16.14819546, 14.5970087 ,  2.51296569,
       11.33555053, 11.44611872, 17.38216807,  0.93636827, 14.14022259,
       14.84411785, 11.91736832, 10.21676171, 15.65048893, 13.45479257,
        6.22682582,  4.33358728, 11.47217771,  1.24086803,  4.64838705,
       10.5199589 , 13.35589471,  7.31870321, 11.48630773, 18.42065178,
       14.74225064, 13.98589875,  5.16273114,  7.96485267,  9.7928    ,
        4.15582   , 12.64464648, 17.21968744, 10.69554244, 10.16331922,
       13.92454338,  9.35741315,  4.71664657,  9.48998161,  8.4562641 ,
       10.18939222,  0.98601694,  7.90308634,  8.90914178, 13.07089323,
        0.67881647, 10.5965448 ,  8.31221063, 13.3089983 ,  7.58058182,
        9.43381122,  9.22571312, 12.49790034, 11.15509301, 10.46866392,
        8.63603748, 12.86335841,  7.78792664, 11.90037994, 10.41934229,
       13.32457332, 10.19184625,  7.07384512,  0.42243579,  6.8327167 ,
       10.90716744,  5.70878519, 15.64657765, 18.88753084,  3.33478009]) is not a valid value for color

Your example above with colors = [0.3,0.8,0.7] doesn't produce an error, because [0.3, 0.8, 0.7] are interpreted as RGB values. In your example, all the points are given that one colour.
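The RGB interpretation can be checked directly with matplotlib's colour helpers `matplotlib.colors.is_color_like` and `matplotlib.colors.to_rgb`. A sketch, assuming a reasonably recent matplotlib: a three-element list of floats in [0, 1] is one valid colour, while a 100-element array of data values is not, which is why `plot` (where `c` means a single colour) rejects it and `scatter` (where `c` holds one value per point) accepts it.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import is_color_like, to_rgb

colors = [0.3, 0.8, 0.7]
print(is_color_like(colors))  # True: a single RGB triple
print(to_rgb(colors))  # (0.3, 0.8, 0.7)

rng = np.random.default_rng(1234)
x, y = rng.uniform(0, 10, size=(2, 100))
z = x + y
print(is_color_like(z))  # False: 100 data values are not one colour

fig, ax = plt.subplots()
# With plot, c is an alias for color: every marker gets the same RGB colour.
(line,) = ax.plot(x, y, "o", c=colors)
# With scatter, c holds one value per point and is mapped through a colormap.
points = ax.scatter(x, y, c=z)
```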

@tattvams
Contributor

tattvams commented Mar 2, 2022


Oh OK, then I guess there is a problem in the matplotlib backend, because there the "c" keyword is used and then passed to scatter, which returns a PathCollection.

@OriolAbril
Member

OriolAbril commented Mar 2, 2022

It is not an issue with matplotlib. plot and scatter have different goals and features, c being one of them.

In scatter we want to plot a collection of dots, not necessarily ordered, where the dots themselves might contain more information than merely their position. This information can be encoded in the size or color of the dots, which are therefore independent properties of each dot. They are accessed with the s and c arguments.

In plot we want to plot an ordered sequence of values, so it defaults to a line plot. It is customizable enough for us to use markers and no lines and generate plots that look like scatter plots, but that is not the goal; it's more of a side effect. plot should be used if we care about the position of the dots and their relation to the previous and next dots/values. Its color and markersize (or linewidth) are applied uniformly to all dots.
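A short sketch contrasting the two calls on the same random data (sizes, colours and point count are arbitrary choices): `plot` with markers and no line produces a single `Line2D` whose colour and marker size apply to every dot, while `scatter` produces one `PathCollection` that stores a separate size and colour for each point.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.uniform(0, 10, size=(2, 50))

fig, ax = plt.subplots()

# plot: one Line2D artist; color and markersize are shared by all markers.
(line,) = ax.plot(x, y, linestyle="none", marker="o", color="C0", markersize=4)

# scatter: one PathCollection; s and c can vary from point to point.
sizes = rng.uniform(5, 50, size=50)
points = ax.scatter(x, y, s=sizes, c=x * y, cmap="viridis")

fig.canvas.draw()  # the colormap is applied when the figure is drawn
print(len(ax.lines), line.get_markersize())
print(points.get_sizes().shape, points.get_facecolors().shape)
```

The plot call yields one line with one marker size, whereas the scatter collection holds 50 sizes and 50 RGBA face colours.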

Now the caveats:

  1. Aliases: for convenience, matplotlib supports aliases for many of its keyword arguments. We can then use lw as an alias of linewidth, ms for markersize, mec for markeredgecolor...
  2. Aliases are function dependent. The aliases I listed above are valid for plot (or Line2D, which is where they are eventually passed), but which aliases exist depends on which function we are calling. If using fill_between instead of plot, linewidths is also accepted in addition to lw and linewidth.
  3. Aliases are function dependent to the point that they can refer to different parameters. In plot, c is an alias for color and is interpreted as the single color to apply uniformly to all dots. In scatter, c is its own parameter with a different behaviour; moreover, color is not an alias for c in scatter, but a separate, valid parameter with the same behaviour as in plot.
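The alias tables can be inspected with `matplotlib.cbook.normalize_kwargs`, which (assuming matplotlib 3.3 or later) accepts an Artist class and rewrites aliases to their canonical names. `plot` draws `Line2D` artists and `scatter` draws a `PathCollection`, so normalizing the same keyword arguments against each class shows how the caveats in the list above play out.

```python
from matplotlib.cbook import normalize_kwargs
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D

# For Line2D (what plot creates): c -> color, lw -> linewidth, ms -> markersize.
print(normalize_kwargs({"c": "red", "lw": 2, "ms": 5}, Line2D))

# For PathCollection (what scatter creates): c is not an alias and is left
# untouched, while lw -> linewidth.
print(normalize_kwargs({"c": "red", "lw": 2}, PathCollection))

# Collections also accept the plural form linewidths as an alias.
print(normalize_kwargs({"linewidths": 1}, PathCollection))
```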

I hope that provides more context on why the dealiasing was wrong and why using plot instead of scatter is an issue here. Both issues are on the ArviZ side.

@tattvams
Contributor

tattvams commented Mar 2, 2022

Hey, by backend I didn't mean matplotlib. What I meant was that maybe there was a reason c didn't accept anything other than RGB values, for example that the cmap functionality had never been added. I understand now that this is not the case and that the kwargs go to plt.plot() instead of plt.scatter(). Can you point me to the source code where the .plot() function is defined? I am unable to find it. Thanks!

@OriolAbril
Copy link
Member

@ahartikainen
Contributor

I think the reason we originally went with the plot interface was speed (scatter was slower for a huge number of points).

I don't know what the current situation is, but I think going with the .scatter solution is fine.
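Whether scatter is still noticeably slower can be measured rather than assumed. A rough timing sketch follows; the point count, repeat count, marker sizes and the Agg backend are arbitrary choices, and results will depend on the machine, matplotlib version and output format, so no numbers are claimed here.

```python
import time

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def time_draw(kind, n_points=50_000, repeats=3, seed=0):
    """Return the best wall-clock time (s) to build and render one marker-only plot."""
    if kind not in ("plot", "scatter"):
        raise ValueError(f"unknown kind: {kind!r}")
    rng = np.random.default_rng(seed)
    x, y = rng.normal(size=(2, n_points))
    best = float("inf")
    for _ in range(repeats):
        fig, ax = plt.subplots()
        start = time.perf_counter()
        if kind == "plot":
            ax.plot(x, y, linestyle="none", marker=".", markersize=2)
        else:
            # s is an area in points**2, so s=4 roughly matches markersize=2.
            ax.scatter(x, y, marker=".", s=4)
        fig.canvas.draw()  # force rendering so the timing includes drawing
        best = min(best, time.perf_counter() - start)
        plt.close(fig)
    return best


for kind in ("plot", "scatter"):
    print(kind, f"{time_draw(kind):.3f} s")
```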

Here is the source for .scatter

https://github.com/matplotlib/matplotlib/blob/f33f5ab12142bcef39e1f8cf5c8fb17d5cce0057/lib/matplotlib/axes/_axes.py#L4278

and here is the source for .plot

https://github.com/matplotlib/matplotlib/blob/f33f5ab12142bcef39e1f8cf5c8fb17d5cce0057/lib/matplotlib/axes/_axes.py#L1396
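The links above point at one specific matplotlib commit. To find the same definitions in whatever matplotlib version is installed locally, the standard `inspect` module can be used; `inspect.unwrap` is applied because these methods are wrapped by decorators, and without it `getsourcefile` would report the wrapper's file instead.

```python
import inspect

from matplotlib.axes import Axes

for name in ("plot", "scatter"):
    func = inspect.unwrap(getattr(Axes, name))  # strip decorator wrappers
    _, start = inspect.getsourcelines(func)
    print(f"Axes.{name}: {inspect.getsourcefile(func)}, line {start}")
```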

@grahamgower
Author

Maybe setting rasterized=True improves the speed when there are many points?

The main reason I didn't immediately send a patch fixing this in ArviZ is that I wasn't sure whether such a change raised backwards-compatibility concerns. A regression in figure size or speed would also be an unwelcome behaviour change.

@ahartikainen
Contributor

We could probably add that as a default.
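One way to make `rasterized=True` a default while still letting users override it is `dict.setdefault` on the user-supplied kwargs. A sketch, with `scatter_with_defaults` as a hypothetical helper (not ArviZ code); note that rasterizing only changes the output of vector formats such as PDF or SVG, where the markers are then embedded as a bitmap.

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np


def scatter_with_defaults(ax, x, y, **scatter_kwargs):
    """Hypothetical helper: rasterize markers by default unless the caller opts out."""
    scatter_kwargs.setdefault("rasterized", True)
    return ax.scatter(x, y, **scatter_kwargs)


rng = np.random.default_rng(0)
x, y = rng.normal(size=(2, 1_000))
fig, ax = plt.subplots()
default = scatter_with_defaults(ax, x, y)
opted_out = scatter_with_defaults(ax, x, y, rasterized=False)
print(default.get_rasterized(), opted_out.get_rasterized())  # True False
```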

Tbh I think scatter is fast enough these days. Speed is always an issue when handling large datasets, but in most cases the datasets are not that large.

@OriolAbril
Member

closed by #2069


4 participants