Skip to content

Commit

Permalink
Add draw_interactive() method
Browse files Browse the repository at this point in the history
  • Loading branch information
efekhari27 committed Jul 3, 2023
1 parent 07271ad commit 68357fa
Show file tree
Hide file tree
Showing 7 changed files with 482 additions and 399 deletions.
136 changes: 121 additions & 15 deletions copulogram/Copulogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,27 @@
@author: Elias Fekhari
"""


#%%
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import rc
import matplotlib.pyplot as plt

# Interactive imports
from itertools import product
from bokeh.io import show
from bokeh.layouts import gridplot
from bokeh.models import (BasicTicker, Circle, ColumnDataSource,
DataRange1d, Grid, LassoSelectTool, LinearAxis,
Plot, ResetTool)
from bokeh.transform import factor_cmap, linear_cmap
from bokeh.palettes import Category10, Viridis


class Copulogram:
"""
Interactive plot for multivariate distributions.
Draws a plot for multivariate distributions.
The lower triangle is a matrixplot of the data (without transformation),
while the upper triangle is a matrixplot of the ranked data.
Expand Down Expand Up @@ -143,28 +154,123 @@ def draw(self,
return copulogram

def draw_interactive(self,
                     color='navy',
                     alpha=1.,
                     hue=None,
                     hue_palette=None,
                     marker='o',
                     subplot_size=5
                     ):
    """
    Draws the copulogram as an interactive Bokeh grid of scatter plots.

    One subplot is built for every ordered pair of plotted columns.
    Subplots on or below the diagonal use the numeric data as-is, the
    ones above the diagonal use the ranked data. The scatter renderer of
    the diagonal subplots is hidden. The resulting grid is displayed
    with ``bokeh.io.show``.

    Parameters
    ----------
    color : str
        Fill color of the points when ``hue`` is None.
    alpha : float
        Fill opacity of the points.
    hue : str, optional
        Name of a column of ``self.data`` used to color the points.
        This column is excluded from the plotted columns. An object-dtype
        column is mapped with ``factor_cmap``, any other dtype with
        ``linear_cmap`` between the column minimum and maximum.
    hue_palette : str, optional
        Palette handed to the color mapper. When None, "Category10_3"
        is used for an object-dtype ``hue`` and "Spectral6" otherwise.
    marker : str
        Currently unused by this method.
    subplot_size : int
        Value passed as ``min_border`` to every subplot.

    Returns
    -------
    None
    """
    df = self.data.copy(deep=True)
    df_numeric = df._get_numeric_data()
    rdf = df_numeric.rank()

    # The hue column only drives the coloring; it is not plotted itself.
    plotted_cols = [name for name in df_numeric.columns if name != hue]
    dim = len(plotted_cols)

    if hue is not None:
        # Both sources carry the raw (un-ranked) hue values so that the
        # color mapping is identical in the two triangles.
        df_numeric[hue] = df[hue]
        rdf[hue] = df[hue]
    source = ColumnDataSource(data=df_numeric)
    rsource = ColumnDataSource(data=rdf)

    # The color specification does not depend on the subplot: build it once.
    if hue is None:
        scatter_color = color
    elif df[hue].dtype == 'O':
        if hue_palette is None:
            hue_palette = "Category10_3"
        scatter_color = factor_cmap(hue, hue_palette, sorted(df[hue].unique()))
    else:
        if hue_palette is None:
            hue_palette = "Spectral6"
        scatter_color = linear_cmap(hue, hue_palette, low=df[hue].min(), high=df[hue].max())

    plot_list = []
    for i, (y, x) in enumerate(product(plotted_cols, plotted_cols)):
        # Position of the subplot in the (dim x dim) grid.
        row, col = divmod(i, dim)

        circle = Circle(x=x, y=y, fill_alpha=alpha, size=5, line_color=None,
                        fill_color=scatter_color)

        # NOTE(review): the (min, max) tuples below are passed to DataRange1d
        # positionally; confirm this is accepted by the installed Bokeh
        # version (the documented keyword for such a tuple is `bounds`).
        if col <= row:
            # Lower triangle (diagonal included): data without transformation.
            p = Plot(x_range=DataRange1d((df[x].min(), df[x].max())),
                     y_range=DataRange1d((df[y].min(), df[y].max())),
                     background_fill_color="#fafafa",
                     border_fill_color="white", width=200, height=200,
                     min_border=subplot_size)
            r = p.add_glyph(source, circle)
            if col == row:
                # Diagonal: hide the scatter of a variable against itself.
                r.visible = False
                # NOTE(review): the Grid layouts are only added further
                # below, so this assignment may have no effect -- confirm.
                p.grid.grid_line_color = None
        else:
            # Upper triangle: ranked data.
            p = Plot(x_range=DataRange1d((rdf[x].min(), rdf[x].max())),
                     y_range=DataRange1d((rdf[y].min(), rdf[y].max())),
                     background_fill_color="#fafafa",
                     border_fill_color="white", width=200, height=200,
                     min_border=subplot_size)
            r = p.add_glyph(rsource, circle)
        p.x_range.renderers.append(r)
        p.y_range.renderers.append(r)

        # First column: add a labelled y axis and widen the plot for it.
        if col == 0:
            p.min_border_left = p.min_border + 4
            p.width += 40
            yaxis = LinearAxis(axis_label=y)
            yaxis.major_label_orientation = "vertical"
            p.add_layout(yaxis, "left")
            yticker = yaxis.ticker
        else:
            yticker = BasicTicker()
        p.add_layout(Grid(dimension=1, ticker=yticker))

        # Last row: add a labelled x axis and heighten the plot for it.
        if row == dim - 1:
            p.min_border_bottom = p.min_border + 40
            p.height += 40
            xaxis = LinearAxis(axis_label=x)
            p.add_layout(xaxis, "below")
            xticker = xaxis.ticker
        else:
            xticker = BasicTicker()
        p.add_layout(Grid(dimension=0, ticker=xticker))

        p.add_tools(LassoSelectTool(), ResetTool())
        plot_list.append(p)

    gridp = gridplot(plot_list, ncols=dim)
    show(gridp)

    return None

## TODO
# Include contours
# Add interactive aspect
#
##TODO:
# Add docstrings
# Remove the misleading yticks from the top left plot? Ideally we should add the index ticks on the top left of the plot
# Add a color bar to the interactive method, following: https://docs.bokeh.org/en/latest/docs/examples/basic/data/color_mappers.html


#%%
if __name__ == "__main__":
    # Demo: interactive copulogram of the first 1000 rows of the wind/wave
    # dataset, colored by a synthetic log-scaled output.
    # Alternative dataset: data = sns.load_dataset('iris')
    import pandas as pd

    data = pd.read_csv("../examples/data/wind_waves_ANEMOC_1H.csv", index_col=0).iloc[:1000]

    # Synthetic output: a wind term plus a wave term, angles scaled by pi/180.
    wind_term = data["U_hub (m/s)"] ** 3 * ((np.pi / 180) * data["θ_wind (deg)"])
    wave_term = (data["Hs (m)"] ** 2 * data["Tp (s)"]) / ((np.pi / 180) * data["θ_wave_new (deg)"])
    output = wind_term + wave_term
    data['output'] = np.log10(output)

    copulogram = Copulogram(data)
    copulogram.draw_interactive(hue="output")
# %%
Loading

0 comments on commit 68357fa

Please sign in to comment.