Skip to content

Commit

Permalink
tidy code
Browse files Browse the repository at this point in the history
  • Loading branch information
crazyzlj committed Sep 4, 2017
1 parent a45e7b0 commit 0e07bab
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ doc/img/
.nfs*

# add nogit directory which contains personal test code
seims/nogit/*
/seims/nogit/*
30 changes: 30 additions & 0 deletions seims/nogit/export_scenario_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#! /usr/bin/env python
# -*- coding: utf-8 -*-
"""Export scenario to tiff.
@author : Liangjun Zhu
@changelog:
"""
from seims.pygeoc.pygeoc.utils.utils import get_config_parser
from seims.scenario_analysis.slpposunits.config import SASPUConfig
from seims.scenario_analysis.slpposunits.scenario import SPScenario


def export_scenario_demo():
"""Export scenario as raster data."""
cf = get_config_parser()
cfg = SASPUConfig(cf)
gene_array = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0,
1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 2.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0,
0.0, 2.0, 0.0, 2.0, 0.0, 2.0, 0.0, 2.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0,
0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
out_raster = r'C:\z_data\ChangTing\seims_models\NSGA2_Output\0830_method3\rule_mth3_3rd\gen100_158736296.tif'
sce = SPScenario(cfg)
setattr(sce, 'gene_values', gene_array)
sce.export_scenario_to_gtiff(out_raster)


if __name__ == '__main__':
export_scenario_demo()
4 changes: 2 additions & 2 deletions seims/preprocess/database/model_param_ini.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ Sediment,USLE_K,the soil erodibility factor,none,ALL,-9999,1,RC,1,0
Sediment,USLE_LS_COEF,the LS coefficient,none,ALL,1,0,AC,10,0
Sediment,USLE_P,the erosion control practice factor,none,ALL,-9999,1,RC,1,0
Sediment,vcd,channel degredation code,None,SEDR,0,0,VC,1,0
Sediment,vcrit,critical veldeg Ctiy for sediment deposition,m/s,ChannelRouting,0.01,0,AC,1,0
Sediment,vcrit,critical veloctiy for sediment deposition,m/s,ChannelRouting,0.01,0,AC,1,0
Snow,c_rain,Rainfall impact factor,mm/mm/deg C/delta_t,SNO_DD,0.04,0,AC,1,0
Snow,c_snow,temperature impact factor,mm/deg C/delta_t,SNO_DD,3,0,AC,6,0
Snow,c_snow12,Melt factor on December 21,mm/deg C/day,SNO_SP,6.5,0,AC,9,6
Expand Down Expand Up @@ -188,7 +188,7 @@ WaterBalance,k_soil10,ratio between soil temperature at 10 cm and the mean,none,
WaterBalance,P_max,Maximum rainfall intensity when k_run = 1.0,mm,SUR_MR,30,0,AC,1000,10
WaterBalance,Poreindex,Pore size distribution index,none,ALL,-9999,1,RC,1.2,0.8
WaterBalance,Porosity,Soil porosity,m3/m3,ALL,-9999,1,RC,0.8,0.2
WaterBalance,pot_k,hydraulic conductivity of soil surface of pothole,mm/hr,IMP_SWAT,0.1,0,AC,10,0.01
WaterBalance,pot_k,hydraulic conductivity of soil surface of pothole,mm/hr,IMP_SWAT,0.1,0,AC,100,0.01
WaterBalance,Rootdepth,Root depth,m,ALL,-9999,1,RC,1.5,0.2
WaterBalance,Runoff_co,Potential runoff coefficient,none,SUR_MR,-9999,1,RC,1,0.5
WaterBalance,rv_co,Groundwater revap coefficient,none,GW_RSVR,0,0,AC,0.2,0
Expand Down
3 changes: 2 additions & 1 deletion seims/scenario_analysis/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import random
from datetime import timedelta
from subprocess import CalledProcessError
from pymongo.errors import NetworkTimeout

from bson.objectid import ObjectId
from pymongo.errors import NetworkTimeout

from seims.preprocess.db_mongodb import ConnectMongoDB
from seims.pygeoc.pygeoc.utils.utils import UtilClass, StringClass, get_config_parser
Expand Down
13 changes: 9 additions & 4 deletions seims/scenario_analysis/slpposunits/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def calculate_environment(self):
self.environment = self.worst_env
return

def export_scenario_to_gtiff(self):
def export_scenario_to_gtiff(self, outpath=None):
"""Export scenario to GTiff.
TODO: Read Raster from MongoDB should be extracted to pygeoc.
Expand Down Expand Up @@ -281,8 +281,8 @@ def export_scenario_to_gtiff(self):

for k, v in v_dict.iteritems():
slppos_data[slppos_data == k] = v

outpath = self.scenario_dir + os.sep + 'Scenario_%d.tif' % self.ID
if outpath is None:
outpath = self.scenario_dir + os.sep + 'Scenario_%d.tif' % self.ID
RasterUtilClass.write_gtiff_file(outpath, ysize, xsize, slppos_data, geotransform,
srs, nodata_value)
client.close()
Expand Down Expand Up @@ -347,7 +347,8 @@ def scenario_effectiveness(cf, individual):
return sce.economy, sce.environment, curid


if __name__ == '__main__':
def main():
"""TEST CODE"""
cf = get_config_parser()
cfg = SASPUConfig(cf)

Expand All @@ -372,3 +373,7 @@ def scenario_effectiveness(cf, individual):
# econ, env, sceid = scenario_effectiveness(cfg, init_gene_values)
# print ('Scenario %d: %s\n' % (sceid, ', '.join(str(v) for v in init_gene_values)))
# print ('Effectiveness:\n\teconomy: %f\n\tenvironment: %f\n' % (econ, env))


if __name__ == '__main__':
main()
140 changes: 112 additions & 28 deletions seims/scenario_analysis/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def plot_pareto_front(pop, ws, gen_id):
img_path = ws + os.sep + 'Pareto_Gen_%d_Pop_%d.png' % (gen_id, pop_size)
plt.savefig(img_path)
# plt.show()
# close current plot in case of 'figure.max_open_warning'
plt.cla()
plt.clf()
plt.close()


def read_pareto_points_from_txt(txt_file, sce_name, xname, yname, gens):
Expand All @@ -58,9 +62,9 @@ def read_pareto_points_from_txt(txt_file, sce_name, xname, yname, gens):
if len(values) != 1:
continue
gen = int(values[0])
if gen not in gens:
found = False
continue
# if gen not in gens:
# found = False
# continue
found = True
cur_gen = gen
pareto_popnum[cur_gen] = list()
Expand Down Expand Up @@ -146,11 +150,11 @@ def read_pareto_popsize_from_txt(txt_file, sce_name='scenario'):
return genids, acc_num


def plot_pareto_fronts_by_method(method_files, sce_name, xname, yname, gens, ws):
def plot_pareto_fronts_by_method(method_paths, sce_name, xname, yname, gens, ws):
"""
Plot Pareto fronts of different method at a same generation for comparision.
Args:
method_files(dict): key is method name (which also displayed in legend), value is file path.
method_paths(dict): key is method name (which also displayed in legend), value is file path.
sce_name(str): Scenario ID field name.
xname(list): the first is x field name in log file, and the second on is for plot,
the third and forth values are low and high limit (optional).
Expand All @@ -160,7 +164,8 @@ def plot_pareto_fronts_by_method(method_files, sce_name, xname, yname, gens, ws)
"""
pareto_data = dict()
acc_pop_size = dict()
for k, v in method_files.iteritems():
for k, v in method_paths.iteritems():
v = v + os.sep + 'runtime.log'
pareto_data[k], acc_pop_size[k] = read_pareto_points_from_txt(v, sce_name, xname,
yname, gens)
# print (pareto_data)
Expand All @@ -172,14 +177,16 @@ def plot_pareto_fronts_by_method(method_files, sce_name, xname, yname, gens, ws)
plt.rcParams['ytick.direction'] = 'out'
plt.rcParams['font.family'] = 'Times New Roman'
markers = ['.', '+', '*', 'x', 'd', 'h', 's', '<', '>']
linestyles = ['-', '--', '-.', ':']
# plot accumulate pop size
fig, ax = plt.subplots(figsize=(12, 8))
fig, ax = plt.subplots(figsize=(9, 8))
mark_idx = 0
for method, gen_popsize in acc_pop_size.iteritems():
xdata = gen_popsize[0]
ydata = gen_popsize[1]
plt.plot(xdata, ydata, marker=markers[mark_idx], color='black', markersize=20,
label=method, linewidth=1)
print ('Evaluated pop size: %s - %d' % (method, sum(ydata)))
plt.plot(xdata, ydata, linestyle=linestyles[mark_idx], color='black',
label=method, linewidth=2)
mark_idx += 1
plt.legend(fontsize=24, loc=2)
xaxis = plt.gca().xaxis
Expand All @@ -199,9 +206,52 @@ def plot_pareto_fronts_by_method(method_files, sce_name, xname, yname, gens, ws)
plt.cla()
plt.clf()
plt.close()

# plot Pareto points of all generations
mark_idx = 0
for method, gen_popsize in pareto_data.iteritems():
fig, ax = plt.subplots(figsize=(9, 8))
xdata = list()
ydata = list()
for gen, gendata in gen_popsize.iteritems():
xdata += gen_popsize[gen][xname[0]]
ydata += gen_popsize[gen][yname[0]]
plt.scatter(xdata, ydata, marker=markers[mark_idx], s=20,
color='black', label=method)
mark_idx += 1
xaxis = plt.gca().xaxis
yaxis = plt.gca().yaxis
for xlebal in xaxis.get_ticklabels():
xlebal.set_fontsize(20)
for ylebal in yaxis.get_ticklabels():
ylebal.set_fontsize(20)
plt.xlabel(xlabel_str, fontsize=20)
plt.ylabel(ylabel_str, fontsize=20)
# set xy axis limit
curxlim = ax.get_xlim()
if len(xname) >= 3:
if curxlim[0] < xname[2]:
ax.set_xlim(left=xname[2])
if len(xname) >= 4 and curxlim[1] > xname[3]:
ax.set_xlim(right=xname[3])
curylim = ax.get_ylim()
if len(yname) >= 3:
if curylim[0] < yname[2]:
ax.set_ylim(bottom=yname[2])
if len(yname) >= 4 and curylim[1] > yname[3]:
ax.set_ylim(top=yname[3])
plt.tight_layout()
fpath = ws + os.sep + method + '-Pareto.png'
plt.savefig(fpath, dpi=300)
print ('%s saved!' % fpath)
# close current plot in case of 'figure.max_open_warning'
plt.cla()
plt.clf()
plt.close()

# plot comparision of Pareto fronts
for gen in gens:
fig, ax = plt.subplots(figsize=(12, 8))
fig, ax = plt.subplots(figsize=(9, 8))
mark_idx = 0
gen_existed = True
for method, gen_popsize in pareto_data.iteritems():
Expand Down Expand Up @@ -232,13 +282,13 @@ def plot_pareto_fronts_by_method(method_files, sce_name, xname, yname, gens, ws)
if len(xname) >= 3:
if curxlim[0] < xname[2]:
ax.set_xlim(left=xname[2])
if len(xname) >= 4 and curxlim[1] > xname[3]:
if len(xname) >= 4: # and curxlim[1] > xname[3]:
ax.set_xlim(right=xname[3])
curylim = ax.get_ylim()
if len(yname) >= 3:
if curylim[0] < yname[2]:
ax.set_ylim(bottom=yname[2])
if len(yname) >= 4 and curylim[1] > yname[3]:
if len(yname) >= 4: # and curylim[1] > yname[3]:
ax.set_ylim(top=yname[3])

plt.legend(fontsize=24, loc=2)
Expand All @@ -252,19 +302,53 @@ def plot_pareto_fronts_by_method(method_files, sce_name, xname, yname, gens, ws)
plt.close()


def main():
"""Main Entrance."""
base_dir = r'C:\z_data\ChangTing\seims_models\NSGA2_OUTPUT\0829_constrait'
method_pareto = {'Rule based': base_dir + os.sep + 'rule_mth3/runtime.log',
'Random': base_dir + os.sep + 'rdm_cxhill/runtime.log'}
scenario_id = 'scenario'
# xaxis = ['economy', 'Economical benefit (1,000 USD$)']
xaxis = ['economy', 'Economical benefit (10,000 RMBY)']
yaxis = ['environment', 'Reduction rate of soil erosion']
draw_gens = range(1, 110)

plot_pareto_fronts_by_method(method_pareto, scenario_id, xaxis, yaxis, draw_gens, base_dir)


if __name__ == '__main__':
main()
def plot_hypervolume_by_method(method_paths, ws):
"""Plot hypervolume"""
hyperv = dict()
for k, v in method_paths.iteritems():
v = v + os.sep + 'hypervolume.txt'
x = list()
y = list()
f = open(v)
for line in f:
values = StringClass.extract_numeric_values_from_string(line)
if values is None:
continue
if len(values) != 2:
continue
x.append(int(values[0]))
y.append(values[1])
f.close()
if len(x) == len(y) > 0:
hyperv[k] = [x[:], y[:]]
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'
plt.rcParams['font.family'] = 'Times New Roman'
linestyles = ['-', '--', '-.', ':']
# plot accumulate pop size
fig, ax = plt.subplots(figsize=(10, 8))
mark_idx = 0
for method, gen_hyperv in hyperv.iteritems():
xdata = gen_hyperv[0]
ydata = gen_hyperv[1]
plt.plot(xdata, ydata, linestyle=linestyles[mark_idx], color='black',
label=method, linewidth=2)
mark_idx += 1
plt.legend(fontsize=24, loc=2)
xaxis = plt.gca().xaxis
yaxis = plt.gca().yaxis
for xlebal in xaxis.get_ticklabels():
xlebal.set_fontsize(20)
for ylebal in yaxis.get_ticklabels():
ylebal.set_fontsize(20)
plt.xlabel('Generation count', fontsize=20)
plt.ylabel('Hypervolume', fontsize=20)
ax.set_xlim(left=0, right=ax.get_xlim()[1] + 2)
plt.tight_layout()
fpath = ws + os.sep + 'hypervolume.png'
plt.savefig(fpath, dpi=300)
print ('%s saved!' % fpath)
# close current plot in case of 'figure.max_open_warning'
plt.cla()
plt.clf()
plt.close()

0 comments on commit 0e07bab

Please sign in to comment.