-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathResultVisualizationHelper.py
28 lines (25 loc) · 1.02 KB
/
ResultVisualizationHelper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import csv
# Default location of the loss log read by this script and of the chart it writes.
LOSSES_CSV_PATH = 'results/losses.csv'
OUTPUT_IMAGE_PATH = 'results/lossesOverTime.jpg'


def load_losses(path=LOSSES_CSV_PATH):
    """Read training and validation loss series from a two-column CSV file.

    Each relevant row has the form ``<label>,<value>`` where ``<label>`` is
    ``training`` or ``validation``. Rows with any other label are ignored and
    blank lines are skipped. Values keep their file order and are stored as
    absolute values (the sign is discarded).

    Args:
        path: Path of the CSV file to read.

    Returns:
        A ``(training_losses, validation_losses)`` tuple of float lists.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a ``training``/``validation`` row has no value column
            or its value is not a number.
    """
    series = {'training': [], 'validation': []}
    # newline='' is what the csv module requires of the file object it reads.
    with open(path, 'r', newline='', encoding='utf-8') as file:
        for line_no, row in enumerate(csv.reader(file), start=1):
            # csv.reader yields [] for blank lines; indexing it would raise IndexError.
            if not row or row[0] not in series:
                continue
            if len(row) < 2:
                raise ValueError(f"{path}:{line_no}: missing loss value for {row[0]!r}")
            try:
                # float, not int: loss values such as "0.25" must be accepted.
                value = float(row[1])
            except ValueError as err:
                raise ValueError(f"{path}:{line_no}: invalid loss value {row[1]!r}") from err
            # NOTE(review): abs() implies the logged values may be negative
            # (e.g. log-likelihoods) - confirm why only magnitudes are plotted.
            series[row[0]].append(abs(value))
    return series['training'], series['validation']


def plot_losses(training_losses, validation_losses, *,
                output_path=OUTPUT_IMAGE_PATH,
                training_ylim=(0, 5000),
                validation_ylim=(0, 10_000_000_000),
                show=True):
    """Plot training (left axis, red) and validation (right axis, blue) losses.

    Both series are drawn against their 0-based index, labelled "Epoch".

    Args:
        training_losses: Training loss values, one per epoch.
        validation_losses: Validation loss values, one per epoch.
        output_path: Where to save the chart as JPEG; ``None`` skips saving.
        training_ylim: ``(bottom, top)`` for the left axis; ``None`` autoscales.
            Values outside a fixed range are clipped from view.
        validation_ylim: ``(bottom, top)`` for the right axis; ``None`` autoscales.
        show: Whether to open the interactive window via ``plt.show()``.

    Returns:
        The created matplotlib ``Figure``.
    """
    fig, ax = plt.subplots()
    ax.plot(range(len(training_losses)), training_losses, color="red", marker="o")
    ax.set_xlabel("Epoch", fontsize=14)
    if training_ylim is not None:
        ax.set_ylim(training_ylim)
    # Epochs are whole numbers, so keep the x ticks on integers.
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_ylabel("Training Loss / batch", color="red", fontsize=14)

    # Second y-axis sharing the x-axis: the two series live on very different scales.
    ax2 = ax.twinx()
    ax2.plot(range(len(validation_losses)), validation_losses, color="blue", marker="o")
    ax2.set_ylabel("Validation Loss / batch", color="blue", fontsize=14)
    if validation_ylim is not None:
        ax2.set_ylim(validation_ylim)
    # Print full numbers on the right axis instead of scientific notation/offset.
    ax2.ticklabel_format(useOffset=False, style='plain')

    # Save before show(): show() blocks until the window is closed, and the
    # image should not depend on the figure surviving that close.
    if output_path is not None:
        fig.savefig(output_path, format='jpeg', dpi=100, bbox_inches='tight')
    if show:
        plt.show()
    return fig


def main():
    """Load the loss log and render and save the loss-over-time chart."""
    training_losses, validation_losses = load_losses(LOSSES_CSV_PATH)
    plot_losses(training_losses, validation_losses, output_path=OUTPUT_IMAGE_PATH)


if __name__ == '__main__':
    main()