-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
92 lines (73 loc) · 3.25 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from load_data import *
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from matplotlib.style import use
# Matplotlib 3.6 renamed the bundled seaborn style to "seaborn-v0_8" and 3.8
# removed the old "seaborn" name (use() then raises OSError). Prefer the new
# name and fall back to the old one on older Matplotlib releases.
use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
def plot_predict_test_data(ground_truth_data, input_data, predictions, *,
                           figsize=(16, 9)):
    """Plot a ground-truth series with test-set predictions overlaid, then show it.

    The ground truth is drawn against its own index (0..N-1). Predictions are
    drawn starting at x = len(input_data) + 1, so they appear after the span
    covered by ``input_data`` (e.g. the training targets).

    Args:
        ground_truth_data: Full series to draw in red.
        input_data: Only its length is used, to offset the predictions.
        predictions: Predicted values, drawn in blue.
        figsize: Figure size in inches (keyword-only; default (16, 9)).
    """
    _, ax = plt.subplots(figsize=figsize)
    # Use the Axes API throughout instead of mixing it with pyplot's implicit
    # "current axes" calls.
    ax.plot(ground_truth_data, color="tab:red", label="True Value")
    start = len(input_data) + 1
    ax.plot(
        range(start, start + len(predictions)),
        predictions,
        color="tab:blue",
        label="Predicted Testing Value",
    )
    ax.legend()
    plt.show()
def plot_all_data(input_data, predictions, fig, ax, **label):
    """Draw the ground-truth and predicted series onto an existing Axes.

    Args:
        input_data: Ground-truth series, drawn in red.
        predictions: Predicted series, drawn in blue.
        fig: Figure that owns ``ax``; it is not touched, only handed back.
        ax: Axes to draw on.
        **label: Must provide ``ground_truth`` and ``predict_value``, the
            legend labels for the two lines (a missing key raises KeyError).

    Returns:
        The ``(fig, ax)`` pair that was passed in.
    """
    for series, colour, label_key in (
        (input_data, "tab:red", "ground_truth"),
        (predictions, "tab:blue", "predict_value"),
    ):
        ax.plot(series, color=colour, label=label[label_key])
    return fig, ax
def predict_func(input_data, model_name, scaler):
    """Load a saved Keras model, predict on ``input_data`` and un-scale the output.

    The model is reloaded from disk on every call.

    Args:
        input_data: Model input, passed unchanged to ``model.predict``.
        model_name: Path of the saved model, given to ``load_model``.
        scaler: Object whose ``inverse_transform`` is applied to the raw
            model output (presumably the scaler used on the training data).

    Returns:
        ``scaler.inverse_transform`` of the model's predictions.
    """
    raw_output = load_model(model_name).predict(input_data)
    return scaler.inverse_transform(raw_output)
if __name__ == "__main__":
    # Import explicitly: `pd` previously arrived only through
    # `from load_data import *`, which breaks if load_data stops exporting it.
    import pandas as pd

    csv_file = "TY_climate_2017_2018.csv"
    # Column to forecast: "TT-Avg(℃)" or "MT-Avg(g)".
    column_name = "MT-Avg(g)"
    # "all":  predict over the whole series and plot it against ground truth.
    # "test": predict only the test split returned by load_data.
    predict_mode = "all"
    # Trained model per column (replaces commenting model paths in and out).
    model_paths = {
        "TT-Avg(℃)": "saved_models_tt_avg/LSTM_002.h5",
        "MT-Avg(g)": "saved_models_mt_avg/LSTM_002.h5",
    }
    model_name = model_paths[column_name]

    # Unscaled ground truth for the chosen column, reshaped to (N, 1).
    ground_truth = pd.read_csv(csv_file)[column_name].values.reshape(-1, 1)

    if predict_mode == "test":
        # Scale each split and build (x, y) pairs with create_dataset; only
        # the test inputs and the number of training targets are needed.
        train_data, test_data = load_data(csv_file, column_name)
        train_data, _ = data_preprocessing(train_data)
        test_data, scaler_test = data_preprocessing(test_data)
        _, y_train = create_dataset(train_data)
        x_test, _ = create_dataset(test_data)
        x_test = x_test.reshape(x_test.shape[0], 1, 1)
        predictions = predict_func(x_test, model_name, scaler_test)
        plot_predict_test_data(ground_truth, y_train, predictions)
    else:
        # NOTE(review): inputs are scaled with a scaler fitted on the whole
        # series rather than the training split. If the model was trained on
        # train-split-scaled data the inputs may be shifted; confirm against
        # the training script.
        all_data, scaler_all_data = data_preprocessing(ground_truth)
        all_data_x, _ = create_dataset(all_data)
        all_data_x = all_data_x.reshape(all_data_x.shape[0], 1, 1)
        predictions = predict_func(all_data_x, model_name, scaler_all_data)

        # The plot covers the whole series, not just the test split.
        labels = {
            "ground_truth": "True Value",
            "predict_value": "Predicted Value",
        }
        fig, ax = plt.subplots(figsize=(16, 9))
        # Ground truth is shifted by one step, presumably because each
        # prediction targets the step after its input; confirm against
        # create_dataset's look-back.
        fig, ax = plot_all_data(ground_truth[1:], predictions, fig, ax, **labels)
        ax.legend()
        plt.show()