# sankey.py -- forked from anazalea/pySankey
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 24 18:43:02 2016
@author: anneya
.-.
.--.( ).--.
<-. .-.-.(.-> )_ .--.
`-`( )-' `) )
(o o ) `)`-'
( ) ,)
( () ) )
`---"\ , , ,/`
`--' `--' `--'
| | | |
| | | |
' | ' |
Produces Sankey Diagrams with matplotlib
Copyright (C) 2016 Anneya Golob
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict
# Module-level plotting configuration. These calls run when the module is
# imported and change matplotlib's global rc settings, so they affect every
# figure drawn in the process, not only diagrams produced by this module.
plt.rc('text',usetex=False)  # render text without TeX
plt.rc('font',family='serif')  # serif font family for all text
# seaborn is imported after the rc calls above; its "white" style is then
# applied with the font family overridden to serif.
import seaborn as sns;
sns.set_style("white",{'font.family':[u'serif']})
def sankey(before, after, ax, colorDict=None, aspect=4, rightColor=False,
           beforeLabels=None, afterLabels=None):
    '''
    Draw a Sankey diagram on `ax` showing the flow from before --> after.

    Inputs:
        before = sequence (list, NumPy array or pandas Series) holding one
            label per observation; drawn on the left of the diagram
        after = sequence of the corresponding labels on the right of the
            diagram, len(after) == len(before)
        ax = axes-like object to draw on; only its fill_between, text and
            axis methods are used
        colorDict = dictionary of colors to use for each label
            {'label': 'color'}. If omitted or empty, an "hls" seaborn
            palette is generated; an empty dictionary supplied by the
            caller is filled in place with that palette
        aspect = vertical extent of the diagram in units of horizontal
            extent
        rightColor = if True, each strip is colored according to its right
            (after) label; otherwise according to its left (before) label
        beforeLabels = if given and non-empty, every 'before' label that is
            not listed is folded into a single 'other' label
        afterLabels = same as beforeLabels, applied to the 'after' labels
    Output:
        None
    Raises:
        KeyError if a non-empty colorDict has no entry for a label that
        must be drawn (raised before anything is drawn).
    '''
    # The DataFrame constructor aligns a Series on its own index, which would
    # silently turn the labels into NaN when that index is not 0..n-1.
    if isinstance(before, pd.Series):
        before = before.reset_index(drop=True)
    if isinstance(after, pd.Series):
        after = after.reset_index(drop=True)
    df = pd.DataFrame({'before': before, 'after': after},
                      index=range(len(before)))

    # Fold labels the caller did not ask for into 'other'. The columns are
    # reassigned instead of using a chained in-place replace, which does not
    # reach the frame under pandas copy-on-write.
    if beforeLabels is not None and len(beforeLabels) > 0:
        df['before'] = _foldLabels(df['before'], beforeLabels)
    if afterLabels is not None and len(afterLabels) > 0:
        df['after'] = _foldLabels(df['after'], afterLabels)

    # All labels that appear 'before' or 'after', in order of first
    # appearance (before labels first).
    allLabels = list(
        pd.concat([df['before'], df['after']], ignore_index=True).unique())

    # A fresh dict per call: a shared mutable default would keep the palette
    # generated for one call and reuse it for the next.
    if colorDict is None:
        colorDict = {}
    if not colorDict:
        if allLabels:
            palette = sns.color_palette("hls", len(allLabels))
            for label, color in zip(allLabels, palette):
                colorDict[label] = color
    else:
        missing = [label for label in allLabels if label not in colorDict]
        if missing:
            raise KeyError(
                'colorDict has no color for label(s): {0}'.format(missing))

    # Boolean row masks per label, computed once and reused below.
    fromMask = dict((label, df['before'] == label) for label in allLabels)
    toMask = dict((label, df['after'] == label) for label in allLabels)

    # Width of every individual strip: flows[src][dst] observations.
    flows = {}
    for src in allLabels:
        flows[src] = dict(
            (dst, int((fromMask[src] & toMask[dst]).sum()))
            for dst in allLabels)

    # Stack the labels vertically. Each label gets a slot as tall as the
    # larger of its two sides, with the smaller side centered in the slot and
    # a gap of 2% of the number of observations between slots.
    gap = 0.02 * len(df)
    widths = {}
    topEdge = 0  # stays 0 for empty input, so nothing is drawn
    for i, label in enumerate(allLabels):
        nLeft = int(fromMask[label].sum())
        nRight = int(toMask[label].sum())
        bottom = 0 if i == 0 else topEdge + gap
        top = bottom + max(nLeft, nRight)
        middle = 0.5 * (top + bottom)
        widths[label] = {
            'left': nLeft,
            'right': nRight,
            'leftBottom': middle - 0.5 * nLeft,
            'rightBottom': middle - 0.5 * nRight,
        }
        topEdge = top

    # Total horizontal extent of the diagram
    xMax = float(topEdge) / aspect

    # Draw vertical bars on left and right of each label's section & print
    # the label next to each bar.
    for label in allLabels:
        w = widths[label]
        ax.fill_between([-0.02 * xMax, 0],
                        2 * [w['leftBottom']],
                        2 * [w['leftBottom'] + w['left']],
                        color=colorDict[label], alpha=0.99)
        ax.fill_between([xMax, 1.02 * xMax],
                        2 * [w['rightBottom']],
                        2 * [w['rightBottom'] + w['right']],
                        color=colorDict[label], alpha=0.99)
        ax.text(-0.05 * xMax, w['leftBottom'] + 0.5 * w['left'], label,
                ha='right', va='center')
        ax.text(1.05 * xMax, w['rightBottom'] + 0.5 * w['right'], label,
                ha='left', va='center')

    # Plot strips. The cursors track where the next strip starts on each side
    # of every label.
    leftCursor = dict((label, widths[label]['leftBottom']) for label in allLabels)
    rightCursor = dict((label, widths[label]['rightBottom']) for label in allLabels)
    kernel = 0.05 * np.ones(20)
    for src in allLabels:
        for dst in allLabels:
            n = flows[src][dst]
            if n == 0:
                # No observations flow src --> dst: skip the degenerate
                # zero-height polygon.
                continue
            half = 0.5 * n
            # Center line of the strip: a step from the left height to the
            # right height, smoothed by two passes of a moving average.
            ys = np.array(50 * [leftCursor[src] + half]
                          + 50 * [rightCursor[dst] + half])
            ys = np.convolve(ys, kernel, mode='valid')
            ys = np.convolve(ys, kernel, mode='valid')
            # Advance the cursors so the next strip starts at the right place
            leftCursor[src] += n
            rightCursor[dst] += n
            stripColor = colorDict[dst if rightColor else src]
            ax.fill_between(np.linspace(0, xMax, len(ys)),
                            ys - half, ys + half,
                            alpha=0.65, color=stripColor)
    ax.axis('off')


def _foldLabels(values, keep):
    '''Return `values` as a list with every entry not in `keep` replaced by 'other'.'''
    return [value if value in keep else 'other' for value in values]