Skip to content

Commit

Permalink
Merge pull request #28 from openalea-incubator/waffle
Browse files Browse the repository at this point in the history
Add waffle representation & partial plot
  • Loading branch information
pomme-abricot authored May 25, 2021
2 parents 8fe8f6a + 71d56b1 commit 7150ae4
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 30 deletions.
289 changes: 264 additions & 25 deletions src/openalea/strawberry/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@
import matplotlib.pyplot as plt
from itertools import chain

from openalea.mtg.algo import orders
from openalea.mtg import stat, algo, traversal
from pandas.core.groupby.groupby import DataError

import numpy as np
from matplotlib.colors import to_rgb
import matplotlib.patches as mpatches
import plotly.express as px
import plotly.graph_objs as go

import six
from six.moves import map
from six.moves import range

from openalea.mtg.algo import orders
from openalea.mtg import stat, algo, traversal


convert = dict(Stade='Stade',
Fleurs_ouverte='FLWRNUMBER_OPEN',
Expand Down Expand Up @@ -364,7 +372,8 @@ def _module_variables(g):
module_variables['type_of_crown'] = type_of_crown # Type de crowns (Primary Crown:1, Branch crown:2 extension crown:3)
module_variables['crown_status'] = Crown_status
module_variables['complete_module'] = complete #(True: complete, False: incomplete)

module_variables['stage']= stage

return module_variables


Expand Down Expand Up @@ -599,6 +608,12 @@ def mean_leaf_area(vid,g):
def complete(vid, g):
return g.property("complete").get(vid, False)


def stage(vid, g):
_stage = g.property('Stade')
return next((_stage[cid] for cid in g.components(vid) if cid in _stage), None)


########################## Extraction on node scale ############################################

def extract_at_node_scale(g, vids=[], convert=convert):
Expand All @@ -612,7 +627,7 @@ def extract_at_node_scale(g, vids=[], convert=convert):
orders = algo.orders(g,scale=2)

# Define all the rows
props = ['node_id', 'rank', 'branching_type', 'complete','nb_modules_branching','nb_branch_crown_branching','nb_extension_crown_branching','branching_length', 'Genotype', 'order', 'date','plant']
props = ['node_id', 'rank', 'branching_type', 'complete','nb_modules_branching','nb_branch_crown_branching','nb_extension_crown_branching','branching_length', 'stage', 'Genotype', 'order', 'date','plant']
for prop in props:
node_df[prop] = []

Expand All @@ -635,6 +650,7 @@ def extract_at_node_scale(g, vids=[], convert=convert):
node_df['order'].append(orders[g.complex(vid)]) #scale=2
node_df['plant'].append(plant(vid, g)) #scale=1
node_df['date'].append(date(vid, g)) #scale=1
node_df['stage'].append(stage(vid, g)) # scale=3

df = pd.DataFrame(node_df)

Expand Down Expand Up @@ -795,29 +811,252 @@ def nb_visible_leaves_tree(v, g):
else:
return sum(nb_visible_leaves(m,g) for m in module_tree(v, g))

def genotype(vid, g):
cpx = g.complex_at_scale(vid, scale=1)
_genotype = property(g, 'Genotype')[cpx]
return _genotype

def stage_tree(vid, g):
return list(stage(m,g) for m in module_tree(v, g))

def plant(vid, g):
cpx = g.complex_at_scale(vid, scale=1)
return property(g, 'Plante')[cpx]

def date(vid, g):
# Capriss: 1:'2014/12/10', 2:'2015/01/07',3:'2015/02/15',4:'2015/03/02',5:'2015/04/03',6:'2015/05/27
# Ciflorette: 1:'2014/12/04',2:'2015/01/07',3:'2015/02/13',4:'2015/03/02',5:'2015/03/30',6:'2015/05/27'
# Cir107: 1:'2014/12/10',2:'2015/01/08',3:'2015/02/11',4:'2015/03/04',5:'2015/04/02',6:'2015/05/20'
# Clery: 1:'2014/12/10', 2:'2015/01/07',3:'2015/02/15',4:'2015/03/02',5:'2015/04/03',6:'2015/05/27'
# Darselect: 1:'2014/12/10', 2:'2015/01/09', 3:'2015/02/11', 4:'2015/03/06',5:'2015/04/03',6:'2015/05/20'
# Gariguette: 1:'2014/12/10', 2:'2015/01/08', 3:'2015/02/12',4:'2015/03/06',5:'2015/04/02',6:'2015/05/19'
# d = {'2014/12/10':1,'2015/01/07':2,'2015/02/15':3,'2015/03/02':4,'2015/04/03':5,'2015/05/27':6,
# '2014/12/04':1,'2015/02/13':3,'2015/03/30':5,
# '2015/01/08':2,'2015/02/11':3,'2015/03/04':4,'2015/04/02':5,'2015/05/20':6,
# '2015/01/09':2,'2015/02/12':3,'2015/03/06':4,'2015/05/19':6}
cpx = g.complex_at_scale(vid, scale=1)
# _date = g.property('Sample_date')[cpx]
return g.property('Sample_date')[cpx]
######################### Transformation of dataframe ######################################
def df2waffle(df, date, index, variable, order=None, aggfunc=None, crosstab=None, *args, **kwargs):
'''
Transpose dataframe by variable with plant in columns and rank or order in index
This function are available for extraction at node scale (index='rank') and
extraction at module scale (index= 'order')
Parameters:
-----------
df: dataframe from extract function at differente scale (modules and nodes scale)
date_selected: date which must be processed
variable: variable which must be processed
Returns:
--------
a dataframe in "waffle" shape: index=date, & columns=variable
'''

if order:
data=df[(df['date']==date) & (df['order']==order)]
else:
data=df[df['date']==date]

if index=='rank':
res = data.pivot(index='rank',columns='plant',values=variable)
elif index=='order':
if crosstab:
res = pd.crosstab(index=data['order'], columns=data[variable], normalize='index')
res=res*100
res = res.round(2)
else:
# Catch data error: when values are string and aggfunc compute numbers
try:
res= data.pivot_table(index='order',columns='plant',values=variable, aggfunc=aggfunc)
except DataError:
print("ERROR, the aggregate function does not handle the data type (float func on str?)")
return pd.DataFrame()

else:
res = data.pivot(index=index,columns='plant',values=variable)

# If use plotly heatmap -> comment "res = res.fillna('')"
if res.isnull().values.any():
res = res.fillna('')
res = res.sort_index(ascending=False)
return res



def plot_waffle_plotly_heatmap(df, layout={}, legend_name={}):

def df_to_plotly(df):
return {'z': df.values.tolist(),
'x': df.columns.tolist(),
'y': df.index.tolist()}

height = layout.get('height', 500)
width = layout.get('width', 500)
xlabel = layout.get('xlabel', 'Plant')
xticks = layout.get('xticks', range(0,len(df.columns)))
xticks_label = layout.get('xticks_label', list(df.columns))
ylabel = layout.get('ylabel', '')
yticks = layout.get('yticks', [l-1 for l in list(df.index)])
yticks_label = layout.get('yticks_label', list(range(0,len(df.index))))
title = layout.get('title', '')

hm_layout = go.Layout(plot_bgcolor='rgba(0,0,0,0)',
# xaxis=dict(zeroline=False),
# yaxis=dict(zeroline=False, ),
autosize=False,
width=width, height=height
)

data = go.Heatmap(df_to_plotly(df),
xgap=1,
ygap=1,
colorscale="aggrnyl"
)

fig = go.Figure(data=data, layout=hm_layout)

return fig


def plot_waffle_plotly_imshow(df, layout={}, legend_name={}):
colormap_used = plt.cm.coolwarm

values = list(set(df.values.flatten()))
if '' in values:
values.remove('')
try:
values.sort()
except TypeError:
values = [str(i) for i in values]
values.sort()
values.insert(0,'')

color_map = {val: colormap_used(i/len(values)) for i, val in enumerate(values)}

# Add the "empty" variable - and set its color as white
color_map[''] = (1., 1., 1., 1.)

data = np.array(df)

# Create an array where each cell is a colormap value RGBA
data_3d = np.ndarray(shape=(data.shape[0], data.shape[1], 4), dtype=float)
for i in range(0, data.shape[0]):
for j in range(0, data.shape[1]):
data_3d[i][j] = color_map[data[i][j]]

# drop the A
data_3d_rgb = np.array([[to_rgb([v for v in row]) for row in col] for col in data_3d], dtype=np.float64)

yticks = list(range(0,data.shape[0]))
yticks.reverse()

fig = px.imshow(data,
labels={'x':'Plant', 'y':'Node'},
x=list(range(1,data.shape[1]+1)),
y=yticks,
origin='lower',
color_continuous_scale='aggrnyl',
# colorbar={}
)
fig.update_layout(plot_bgcolor='rgba(0,0,0,0)',
)
return fig


def plot_waffle_matplotlib(df, layout={}, legend_name={}):
height = layout.get('height', 500)
width = layout.get('width', 500)
xlabel = layout.get('xlabel', 'Plant')
xticks = layout.get('xticks', range(0,len(df.columns)))
xticks_label = layout.get('xticks_label', list(df.columns))
ylabel = layout.get('ylabel', '')
yticks = layout.get('yticks', [l-1 for l in list(df.index)])
yticks_label = layout.get('yticks_label', list(range(0,len(df.index))))
title = layout.get('title', '')

colormap_used = plt.cm.coolwarm

# Sort the variables. When variables are int or float, remove the str('') (that replaced the NaN) before sorting
values = list(set(df.values.flatten()))
if '' in values:
values.remove('')
try:
values.sort()
except TypeError:
values = [str(i) for i in values]
values.sort()
values.insert(0,'')

w_height = len(df.index)
w_width = len(df.columns)
color_map = {val: colormap_used(i/len(values)) for i, val in enumerate(values)}

# Add the "empty" variable - and set its color as white
color_map[''] = (1., 1., 1., 1.)

data = np.array(df)

# Create an array where each cell is a colormap value RGBA
data_3d = np.ndarray(shape=(data.shape[0], data.shape[1], 4), dtype=float)
for i in range(0, data.shape[0]):
for j in range(0, data.shape[1]):
data_3d[i][j] = color_map[data[i][j]]

# display the plot
fig, ax = plt.subplots(1,1)
fig.set_size_inches(height, width)
fig = ax.imshow(data_3d)

# Get the axis.
ax = plt.gca()

# Minor ticks
ax.set_xticks(np.arange(-.5, (w_width), 1), minor=True);
ax.set_yticks(np.arange(-.5, (w_height), 1), minor=True);

# Gridlines based on minor ticks
ax.grid(which='minor', color='w', linestyle='-', linewidth=2)

# Manually constructing a legend solves your "catagorical" problem.
legend_handles = []

for i, val in enumerate(values):
if val!= "":
color_val = color_map[val]
legend_handles.append(mpatches.Patch(color=color_val, label=legend_name.get(val, val)))

# Add the legend.
plt.legend(handles=legend_handles, loc=(1,0))
plt.xlabel(xlabel)
plt.ylabel(ylabel)

plt.xticks(ticks=xticks, labels=xticks_label)
plt.yticks(ticks=yticks, labels=yticks_label)

plt.title(title)

plt.show()

return fig


def plot_waffle(df, layout={}, legend_name={}, savepath=None, plot_func='matplotlib'):
"""
Plot a dataframe in "waffle" shape
layout: dict of layout parameters:
height/width: size of the picture in inch
x/ylabel: label of the x/y axis
x/yticks: ticks of the x/y axis
x/yticks_labels: labels of the ticks on the x/y axis
title: title
plot_func: library used for the ploting:
matplotlib: matplotlib.pyplot.subplot.imshow
plotly.imshow: plotly.express.imshow
plotly.heatmap: plotly.graph_objs.heatmap
"""

## Axes not working - Plotly heatmap
if plot_func=='plotly.heatmap':
fig= plot_waffle_plotly_heatmap(df=df, layout=layout, legend_name=legend_name)

# Plotly imshow
elif plot_func=='plotly.imshow':
fig= plot_waffle_plotly_imshow(df=df, layout=layout, legend_name=legend_name)

# With matplotlib
elif plot_func=='matplotlib':
try:
fig= plot_waffle_matplotlib(df=df, layout=layout, legend_name=legend_name)
except ValueError:
fig={}

if savepath:
plt.savefig(savepath)

return fig


def plot_pie(df):
return px.pie(df, values=df.mean(axis=0), names=df.columns)
5 changes: 0 additions & 5 deletions src/openalea/strawberry/visu3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,3 @@ def color_code(g):
nid.color = (0, int(127+127/(len(stades)-1)*(i)),255)
else:
nid.color = (153, 102, 51)


# 2D visualization
#############################################################################
# TODO: add visualization functions
26 changes: 26 additions & 0 deletions test/test_waffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathlib import Path
import os
from openalea.mtg.io import read_mtg_file, write_mtg
from openalea.strawberry.analysis import extract_at_node_scale, extract_at_module_scale
from openalea.deploy.shared_data import shared_data
import openalea.strawberry

from openalea.strawberry.analysis import df2waffle

def name(f):
return f.basename().splitext()[0]

def test_df2waffle():
files = shared_data(openalea.strawberry).glob('*.mtg')
mtg_path = dict((name(f), f) for f in files)
mtg = read_mtg_file(mtg_path['Capriss'])

df = extract_at_node_scale(mtg)

node_scale = df2waffle(df, index='rank', date='2015/03/02', variable='branching_type')
assert node_scale.shape == (20, 9)

df = extract_at_module_scale(mtg)
module_scale=df2waffle(df, index='order', date='2015/03/02', variable='crown_status', aggfunc='median')
assert module_scale.shape == (3, 9)

0 comments on commit 7150ae4

Please sign in to comment.