Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add waffle representation & partial plot #28

Merged
merged 9 commits into from
May 25, 2021
Merged
286 changes: 263 additions & 23 deletions src/openalea/strawberry/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@
from openalea.mtg.algo import orders
from openalea.mtg import stat, algo, traversal

from pandas.core.groupby.groupby import DataError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put the imports from the Python standard library first,
then the third-party libraries,
then the openalea ones (mtg & co).


import numpy as np
from matplotlib.colors import to_rgb
import matplotlib.patches as mpatches
import plotly.express as px
import plotly.graph_objs as go


import six
from six.moves import map
from six.moves import range
Expand Down Expand Up @@ -364,7 +373,8 @@ def _module_variables(g):
module_variables['type_of_crown'] = type_of_crown # Type de crowns (Primary Crown:1, Branch crown:2 extension crown:3)
module_variables['crown_status'] = Crown_status
module_variables['complete_module'] = complete #(True: complete, False: incomplete)

module_variables['stage']= stage

return module_variables


Expand Down Expand Up @@ -599,6 +609,12 @@ def mean_leaf_area(vid,g):
def complete(vid, g):
    """Return the ``"complete"`` property of vertex *vid*, or False when it is not set."""
    completeness = g.property("complete")
    return completeness.get(vid, False)


def stage(vid, g):
    """Return the stage recorded on the first component of *vid* that has one.

    The ``'Stade'`` property is looked up on the components of ``vid`` in the
    order yielded by ``g.components_iter(vid)``. Iteration stops at the first
    component carrying the property, so no intermediate list is built.

    Returns:
    --------
    The ``'Stade'`` value of that component, or None when no component of
    ``vid`` has the property.
    """
    _stage = g.property('Stade')
    return next((_stage[cid] for cid in g.components_iter(vid) if cid in _stage), None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not get it.
Why not just take the first element?
You do not need to convert it to a list,
and you can use components_iter.



########################## Extraction on node scale ############################################

def extract_at_node_scale(g, vids=[], convert=convert):
Expand All @@ -612,7 +628,7 @@ def extract_at_node_scale(g, vids=[], convert=convert):
orders = algo.orders(g,scale=2)

# Define all the rows
props = ['node_id', 'rank', 'branching_type', 'complete','nb_modules_branching','nb_branch_crown_branching','nb_extension_crown_branching','branching_length', 'Genotype', 'order', 'date','plant']
props = ['node_id', 'rank', 'branching_type', 'complete','nb_modules_branching','nb_branch_crown_branching','nb_extension_crown_branching','branching_length', 'stage', 'Genotype', 'order', 'date','plant']
for prop in props:
node_df[prop] = []

Expand All @@ -635,6 +651,7 @@ def extract_at_node_scale(g, vids=[], convert=convert):
node_df['order'].append(orders[g.complex(vid)]) #scale=2
node_df['plant'].append(plant(vid, g)) #scale=1
node_df['date'].append(date(vid, g)) #scale=1
node_df['stage'].append(stage(vid, g)) # scale=3

df = pd.DataFrame(node_df)

Expand Down Expand Up @@ -795,29 +812,252 @@ def nb_visible_leaves_tree(v, g):
else:
return sum(nb_visible_leaves(m,g) for m in module_tree(v, g))

def genotype(vid, g):
    """Return the ``'Genotype'`` value attached to the scale-1 complex of *vid*.

    NOTE(review): ``property`` is presumably the module-level helper of this
    file (not the builtin) -- confirm against the rest of the module.
    """
    plant_vid = g.complex_at_scale(vid, scale=1)
    return property(g, 'Genotype')[plant_vid]

def stage_tree(vid, g):
    """Return the list of stages of the modules in the module tree of *vid*.

    Each element is the result of ``stage(m, g)`` for a module ``m`` yielded by
    ``module_tree(vid, g)``, so it can be None for a module without a stage.
    """
    # The parameter is named ``vid``; the former body referenced an undefined ``v``.
    return [stage(m, g) for m in module_tree(vid, g)]

def plant(vid, g):
    """Return the ``'Plante'`` value attached to the scale-1 complex of *vid*.

    NOTE(review): ``property`` is presumably the module-level helper of this
    file (not the builtin) -- confirm against the rest of the module.
    """
    plant_vid = g.complex_at_scale(vid, scale=1)
    plant_ids = property(g, 'Plante')
    return plant_ids[plant_vid]

def date(vid, g):
    """Return the ``'Sample_date'`` value attached to the scale-1 complex of *vid*."""
    # Sampling dates per genotype (sampling index: date), kept for reference:
    #   Capriss:    1:'2014/12/10', 2:'2015/01/07', 3:'2015/02/15', 4:'2015/03/02', 5:'2015/04/03', 6:'2015/05/27'
    #   Ciflorette: 1:'2014/12/04', 2:'2015/01/07', 3:'2015/02/13', 4:'2015/03/02', 5:'2015/03/30', 6:'2015/05/27'
    #   Cir107:     1:'2014/12/10', 2:'2015/01/08', 3:'2015/02/11', 4:'2015/03/04', 5:'2015/04/02', 6:'2015/05/20'
    #   Clery:      1:'2014/12/10', 2:'2015/01/07', 3:'2015/02/15', 4:'2015/03/02', 5:'2015/04/03', 6:'2015/05/27'
    #   Darselect:  1:'2014/12/10', 2:'2015/01/09', 3:'2015/02/11', 4:'2015/03/06', 5:'2015/04/03', 6:'2015/05/20'
    #   Gariguette: 1:'2014/12/10', 2:'2015/01/08', 3:'2015/02/12', 4:'2015/03/06', 5:'2015/04/02', 6:'2015/05/19'
    # Reverse lookup (date: sampling index), not used by the code:
    #   '2014/12/10':1, '2015/01/07':2, '2015/02/15':3, '2015/03/02':4, '2015/04/03':5, '2015/05/27':6,
    #   '2014/12/04':1, '2015/02/13':3, '2015/03/30':5,
    #   '2015/01/08':2, '2015/02/11':3, '2015/03/04':4, '2015/04/02':5, '2015/05/20':6,
    #   '2015/01/09':2, '2015/02/12':3, '2015/03/06':4, '2015/05/19':6
    plant_vid = g.complex_at_scale(vid, scale=1)
    sample_dates = g.property('Sample_date')
    return sample_dates[plant_vid]
######################### Transformation of dataframe ######################################
def df2waffle(df, date, index, variable, order=None, aggfunc=None, crosstab=None, *args, **kwargs):
    '''
    Reshape an extraction dataframe into a "waffle" table for one sampling date.

    This function is available for extraction at node scale (index='rank') and
    extraction at module scale (index='order').

    Parameters:
    -----------
    df: dataframe from the extract functions at the different scales (module and node scale);
        it must have 'date' and 'plant' columns, plus an 'order' column when `order` is given
        or index='order'
    date: date which must be processed (compared with the 'date' column)
    index: column used as index of the result ('rank', 'order' or any other column name)
    variable: variable which must be processed
    order: if not None, keep only the rows with this order (0 is a valid order)
    aggfunc: aggregation function given to `pivot_table` (index='order' without crosstab)
    crosstab: if true and index='order', return for each order the percentage
        (rounded to 2 decimals) of each value of `variable`
    *args, **kwargs: accepted for compatibility, not used

    Returns:
    --------
    a dataframe in "waffle" shape, sorted by decreasing index, in which missing cells are
    replaced by '': index=`index` and columns=plant (columns=values of `variable` with crosstab).
    An empty dataframe is returned when the aggregation raises a DataError.
    '''
    selection = df['date'] == date
    # Compare with None: a plain truth test would silently drop the filter for order 0
    if order is not None:
        selection = selection & (df['order'] == order)
    data = df[selection]

    if index == 'rank':
        res = data.pivot(index='rank', columns='plant', values=variable)
    elif index == 'order':
        if crosstab:
            res = pd.crosstab(index=data['order'], columns=data[variable], normalize='index')
            res = (res * 100).round(2)
        else:
            # Catch data error: when values are string and aggfunc compute numbers
            try:
                res = data.pivot_table(index='order', columns='plant', values=variable, aggfunc=aggfunc)
            except DataError:
                print("ERROR, the aggregate function does not handle the data type (float func on str?)")
                return pd.DataFrame()
    else:
        res = data.pivot(index=index, columns='plant', values=variable)

    # If use plotly heatmap -> comment "res = res.fillna('')"
    if res.isnull().values.any():
        res = res.fillna('')
    return res.sort_index(ascending=False)



def plot_waffle_plotly_heatmap(df, layout=None, legend_name=None):
    """
    Plot a "waffle" dataframe as a plotly heatmap.

    Parameters:
    -----------
    df: dataframe in "waffle" shape; columns are used as x, index as y and values as z
    layout: dict of layout parameters; only 'height' and 'width' (default 500) are used here
    legend_name: accepted for signature compatibility with the other backends, not used

    Returns:
    --------
    a plotly.graph_objs.Figure
    """
    layout = {} if layout is None else layout

    def df_to_plotly(df):
        return {'z': df.values.tolist(),
                'x': df.columns.tolist(),
                'y': df.index.tolist()}

    # Only the figure size is read from `layout`: the other keys were computed
    # but never applied, and their defaults did arithmetic on the index, which
    # failed for non-numeric indexes.
    hm_layout = go.Layout(plot_bgcolor='rgba(0,0,0,0)',
                          autosize=False,
                          width=layout.get('width', 500),
                          height=layout.get('height', 500),
                          )

    data = go.Heatmap(df_to_plotly(df),
                      xgap=1,
                      ygap=1,
                      colorscale="aggrnyl",
                      )

    return go.Figure(data=data, layout=hm_layout)


def plot_waffle_plotly_imshow(df, layout=None, legend_name=None):
    """
    Plot a "waffle" dataframe with plotly express imshow.

    Parameters:
    -----------
    df: dataframe in "waffle" shape; its values are displayed as an image
    layout: accepted for signature compatibility with the other backends, not used
    legend_name: accepted for signature compatibility with the other backends, not used

    Returns:
    --------
    the figure returned by plotly.express.imshow
    """
    # The RGBA array that used to be built here was never passed to imshow;
    # it has been removed since it only added cost and failure modes.
    data = np.array(df)
    nb_rows, nb_columns = data.shape[0], data.shape[1]

    # y labels are numbered in decreasing order
    yticks = list(range(0, nb_rows))
    yticks.reverse()

    fig = px.imshow(data,
                    labels={'x': 'Plant', 'y': 'Node'},
                    x=list(range(1, nb_columns + 1)),
                    y=yticks,
                    origin='lower',
                    color_continuous_scale='aggrnyl',
                    )
    fig.update_layout(plot_bgcolor='rgba(0,0,0,0)')
    return fig


def plot_waffle_matplotlib(df, layout=None, legend_name=None):
    """
    Plot a "waffle" dataframe with matplotlib: one coloured cell per value.

    Parameters:
    -----------
    df: dataframe in "waffle" shape; '' cells are drawn in white
    layout: dict of layout parameters:
        height/width: size of the picture in inch (default 500)
        x/ylabel: label of the x/y axis
        x/yticks: ticks of the x/y axis
        x/yticks_label: labels of the ticks on the x/y axis
        title: title
    legend_name: dict mapping a value of the dataframe to its label in the legend

    Returns:
    --------
    the AxesImage returned by `ax.imshow` (not the Figure; the return value is
    kept unchanged for existing callers)
    """
    layout = {} if layout is None else layout
    legend_name = {} if legend_name is None else legend_name

    height = layout.get('height', 500)
    width = layout.get('width', 500)
    xlabel = layout.get('xlabel', 'Plant')
    ylabel = layout.get('ylabel', '')
    title = layout.get('title', '')

    # Tick defaults are built only when needed: the yticks default does
    # arithmetic on the index and must not run when the caller gives ticks.
    xticks = layout['xticks'] if 'xticks' in layout else range(0, len(df.columns))
    xticks_label = layout['xticks_label'] if 'xticks_label' in layout else list(df.columns)
    yticks = layout['yticks'] if 'yticks' in layout else [l - 1 for l in df.index]
    yticks_label = layout['yticks_label'] if 'yticks_label' in layout else list(range(0, len(df.index)))

    colormap_used = plt.cm.coolwarm

    # Sort the variables. When variables are int or float, remove the str('') (that replaced the NaN) before sorting
    values = list(set(df.values.flatten()))
    if '' in values:
        values.remove('')
    try:
        values.sort()
    except TypeError:
        # Mixed types: order by string form but keep the original objects,
        # so that they still match the cells of the dataframe below.
        values.sort(key=str)
    values.insert(0, '')

    w_height = len(df.index)
    w_width = len(df.columns)
    color_map = {val: colormap_used(i / len(values)) for i, val in enumerate(values)}

    # Add the "empty" variable - and set its color as white
    color_map[''] = (1., 1., 1., 1.)

    data = np.array(df)

    # Create an array where each cell is a colormap value RGBA
    data_3d = np.empty((data.shape[0], data.shape[1], 4), dtype=float)
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            data_3d[i, j] = color_map[data[i, j]]

    # display the plot
    figure, ax = plt.subplots(1, 1)
    # set_size_inches expects (width, height)
    figure.set_size_inches(width, height)
    image = ax.imshow(data_3d)

    # Minor ticks
    ax.set_xticks(np.arange(-.5, w_width, 1), minor=True)
    ax.set_yticks(np.arange(-.5, w_height, 1), minor=True)

    # Gridlines based on minor ticks
    ax.grid(which='minor', color='w', linestyle='-', linewidth=2)

    # Manually build a categorical legend; values[0] is the '' placeholder and is skipped
    legend_handles = [
        mpatches.Patch(color=color_map[val], label=legend_name.get(val, val))
        for val in values[1:]
    ]

    # Add the legend.
    plt.legend(handles=legend_handles, loc=(1, 0))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    plt.xticks(ticks=xticks, labels=xticks_label)
    plt.yticks(ticks=yticks, labels=yticks_label)

    plt.title(title)

    plt.show()

    return image


def plot_waffle(df, layout=None, legend_name=None, savepath=None, plot_func='matplotlib'):
    """
    Plot a dataframe in "waffle" shape

    df: dataframe in "waffle" shape
    layout: dict of layout parameters:
        height/width: size of the picture in inch
        x/ylabel: label of the x/y axis
        x/yticks: ticks of the x/y axis
        x/yticks_labels: labels of the ticks on the x/y axis
        title: title
    legend_name: dict mapping a value of the dataframe to its label in the legend
    savepath: if given, path where the picture is saved
    plot_func: library used for the ploting:
        matplotlib: matplotlib.pyplot.subplot.imshow
        plotly.imshow: plotly.express.imshow
        plotly.heatmap: plotly.graph_objs.heatmap

    Returns the object produced by the selected backend, or {} when the
    matplotlib backend raised a ValueError.

    Raises ValueError when plot_func is not one of the supported names.
    """
    import warnings

    # Normalise here so the backends always receive dicts
    layout = {} if layout is None else layout
    legend_name = {} if legend_name is None else legend_name

    ## Axes not working - Plotly heatmap
    if plot_func == 'plotly.heatmap':
        fig = plot_waffle_plotly_heatmap(df=df, layout=layout, legend_name=legend_name)

    # Plotly imshow
    elif plot_func == 'plotly.imshow':
        fig = plot_waffle_plotly_imshow(df=df, layout=layout, legend_name=legend_name)

    # With matplotlib
    elif plot_func == 'matplotlib':
        try:
            fig = plot_waffle_matplotlib(df=df, layout=layout, legend_name=legend_name)
        except ValueError as err:
            # Best effort: keep returning an empty result, but say why
            warnings.warn("waffle plot with matplotlib failed: %s" % err)
            fig = {}

    else:
        # Previously fell through and failed on the unbound `fig`
        raise ValueError(
            "unknown plot_func %r: expected 'matplotlib', 'plotly.imshow' or 'plotly.heatmap'"
            % (plot_func,)
        )

    if savepath:
        if plot_func == 'matplotlib':
            plt.savefig(savepath)
        else:
            # plt.savefig would save the current pyplot figure, not the plotly one
            fig.write_image(savepath)

    return fig


def plot_pie(df):
    """Return a plotly express pie chart of the column means of *df*, one slice per column."""
    column_means = df.mean(axis=0)
    slice_names = df.columns
    return px.pie(df, values=column_means, names=slice_names)
5 changes: 0 additions & 5 deletions src/openalea/strawberry/visu3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,3 @@ def color_code(g):
nid.color = (0, int(127+127/(len(stades)-1)*(i)),255)
else:
nid.color = (153, 102, 51)


# 2D visualization
#############################################################################
# TODO: add visualization functions
26 changes: 26 additions & 0 deletions test/test_waffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathlib import Path
import os
from openalea.mtg.io import read_mtg_file, write_mtg
from openalea.strawberry.analysis import extract_at_node_scale, extract_at_module_scale
from openalea.deploy.shared_data import shared_data
import openalea.strawberry

from openalea.strawberry.analysis import df2waffle

def name(f):
    """Return the base name of path *f* without its extension."""
    base = f.basename()
    return base.splitext()[0]

def test_df2waffle():
    """Check the shape of the waffle tables built from the Capriss MTG at node and module scale."""
    mtg_files = shared_data(openalea.strawberry).glob('*.mtg')
    mtg_path = {name(mtg_file): mtg_file for mtg_file in mtg_files}
    mtg = read_mtg_file(mtg_path['Capriss'])

    # Node scale: one row per rank, one column per plant
    node_df = extract_at_node_scale(mtg)
    node_scale = df2waffle(node_df, index='rank', date='2015/03/02', variable='branching_type')
    assert node_scale.shape == (20, 9)

    # Module scale: one row per order, one column per plant
    module_df = extract_at_module_scale(mtg)
    module_scale = df2waffle(
        module_df,
        index='order',
        date='2015/03/02',
        variable='crown_status',
        aggfunc='median',
    )
    assert module_scale.shape == (3, 9)