-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_functions.py
349 lines (281 loc) · 11.1 KB
/
train_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from typing import Dict, List, Tuple
def accuracy(output_logits, true_labels) -> float:
    """Compute the fraction of correct predictions from raw model logits.

    Args:
        output_logits: Raw logits produced by the model. A 1-D tensor, or a
            tensor whose second dimension has size 1, is treated as binary
            logits and thresholded at 0. Anything else is treated as
            multi-class logits with the class scores along dimension 1.
        true_labels: Ground-truth labels, one per prediction. Any shape is
            accepted as long as the element count matches the predictions.

    Returns:
        The number of correct predictions divided by the number of
        predictions, or 0.0 when there are no predictions.

    Raises:
        ValueError: If the number of labels differs from the number of
            predictions.
    """
    if output_logits.dim() == 1 or output_logits.shape[1] == 1:
        # Binary classification: a positive logit means class 1.
        predicted = (output_logits > 0).float()
        targets = true_labels.float()
    else:
        # Multi-class classification: pick the highest scoring class.
        predicted = torch.argmax(output_logits, dim=1)
        targets = true_labels

    total = predicted.numel()
    if total == 0:
        return 0.0
    if targets.numel() != total:
        raise ValueError(
            f"expected {total} labels to match the predictions, "
            f"got {targets.numel()}"
        )

    # Flatten both sides before comparing: comparing (N, 1) predictions with
    # (N,) labels would broadcast to an (N, N) matrix and inflate the count.
    correct = (predicted.reshape(-1) == targets.reshape(-1)).sum().item()
    return correct / total
# Function to train the model for a single epoch
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               criterion: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
    """Run one training epoch over ``dataloader``.

    Switches the model to training mode, then for every batch performs a
    forward pass, computes the loss, back-propagates and applies one
    optimizer update.

    Args:
        model: The PyTorch model to train.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(predictions, targets)``.
        optimizer: Optimizer that updates the model parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A ``(loss, accuracy)`` tuple, each averaged over the number of batches.
    """
    model.train()

    running_loss, running_acc = 0, 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        running_loss += batch_loss.item()

        # Clear old gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_acc += accuracy(logits, targets)

    # Both metrics are per-batch averages.
    num_batches = len(dataloader)
    return running_loss / num_batches, running_acc / num_batches
def batch_train(model: torch.nn.Module,
                dataloader: torch.utils.data.DataLoader,
                criterion: torch.nn.Module,
                optimizer: torch.optim.Optimizer,
                device: torch.device,
                accumulation_steps: int = 1):
    """Run one training epoch with gradient accumulation.

    Puts the model into training mode and back-propagates every batch, but
    only applies an optimizer update once ``accumulation_steps`` batches have
    contributed gradients. Batches left over at the end of the epoch still
    trigger a final update so their gradients are not dropped.

    Args:
        model: The PyTorch model to train.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(predictions, targets)``.
        optimizer: Optimizer that updates the model parameters.
        device: Device the batches are moved to before the forward pass.
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer step. Must be at least 1.

    Returns:
        A ``(loss, accuracy)`` tuple, each averaged over the number of batches.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
        ZeroDivisionError: If ``dataloader`` contains no batches.
    """
    # Validate up front so a bad value fails before any work is done.
    if accumulation_steps < 1:
        raise ValueError(
            f"accumulation_steps must be >= 1, got {accumulation_steps}"
        )

    # put the model into training mode
    model.train()

    # Setup train loss and train accuracy values
    train_loss, train_acc = 0, 0

    # Start from clean gradients so nothing left over from an earlier
    # backward pass leaks into the first accumulated update.
    optimizer.zero_grad()

    # Batches back-propagated since the last optimizer step. Tracking this
    # explicitly avoids reading the loop variable after the loop, which is
    # unbound when the dataloader is empty.
    pending = 0

    for X, y in dataloader:
        # send the data to target device
        X, y = X.to(device), y.to(device)

        # forward pass through the model
        y_pred = model(X)

        # calculate and accumulate the loss
        loss = criterion(y_pred, y)
        train_loss += loss.item()

        # NOTE(review): the loss is not divided by accumulation_steps, so the
        # accumulated gradient is the sum over the window rather than the
        # mean - confirm this matches the criterion's reduction and the
        # chosen learning rate.
        loss.backward()
        pending += 1

        # Apply the update once a full accumulation window is collected.
        if pending == accumulation_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        # Calculate and accumulate accuracy metric across all batches
        train_acc += accuracy(y_pred, y)

    # Flush gradients from a final, incomplete accumulation window.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    # Adjust metrics to get average loss and accuracy per batch
    num_batches = len(dataloader)
    train_loss = train_loss / num_batches
    train_acc = train_acc / num_batches
    return train_loss, train_acc
def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              criterion: torch.nn.Module,
              device: torch.device):
    """Evaluate a PyTorch model over one pass of ``dataloader``.

    Switches the model to eval mode and runs forward passes inside
    ``torch.inference_mode()``; no backward pass or optimizer step is made.

    Args:
        model: The PyTorch model to evaluate.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(predictions, targets)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A ``(loss, accuracy)`` tuple, each averaged over the number of batches.
    """
    model.eval()

    loss_total, acc_total = 0, 0
    with torch.inference_mode():
        for features, labels in dataloader:
            features = features.to(device)
            labels = labels.to(device)

            logits = model(features)
            loss_total += criterion(logits, labels).item()
            acc_total += accuracy(logits, labels)

    # Both metrics are per-batch averages.
    num_batches = len(dataloader)
    return loss_total / num_batches, acc_total / num_batches
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          criterion: torch.nn.Module,
          epochs: int,
          device: torch.device,
          batched_acc: bool = False,
          accumulation_steps: int = 1) -> Dict[str, List]:
    """Train and evaluate a PyTorch model for a number of epochs.

    Each epoch runs one training pass (``batch_train()`` when ``batched_acc``
    is true, otherwise ``train_step()``) followed by one ``test_step()`` pass.
    The metrics of every epoch are printed and recorded.

    Args:
        model: The PyTorch model to train and test.
        train_dataloader: DataLoader used for the training pass.
        test_dataloader: DataLoader used for the evaluation pass.
        optimizer: Optimizer that updates the model parameters.
        criterion: Loss function used on both datasets.
        epochs: Number of epochs to run.
        device: Device to compute on (e.g. "cuda" or "cpu").
        batched_acc: Whether to train with gradient accumulation.
        accumulation_steps: Number of batches to accumulate gradients over
            before each optimizer step; only used when ``batched_acc`` is true.

    Returns:
        A dictionary with one list per metric and one entry per epoch:
        {"train_loss": [...], "train_acc": [...],
         "test_loss": [...], "test_acc": [...]}
    """
    metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
    history = {name: [] for name in metric_names}

    # Arguments shared by both training variants.
    training_args = dict(model=model,
                         dataloader=train_dataloader,
                         criterion=criterion,
                         optimizer=optimizer,
                         device=device)

    for epoch in tqdm(range(epochs)):
        if batched_acc:
            train_loss, train_acc = batch_train(
                accumulation_steps=accumulation_steps, **training_args)
        else:
            train_loss, train_acc = train_step(**training_args)

        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        criterion=criterion,
                                        device=device)

        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # Record this epoch's metrics in the same order as metric_names.
        epoch_metrics = (train_loss, train_acc, test_loss, test_acc)
        for name, value in zip(metric_names, epoch_metrics):
            history[name].append(value)

    return history
def plot_loss(loss_results:dict, save_path:str= None):
    """Plot the train and test loss curves produced by the train function.

    Args:
        loss_results: A dict containing the keys 'train_loss' and
            'test_loss', each mapping to a list with one value per epoch.
        save_path: Optional path. When given, the figure is also written
            there in JPEG format.

    Returns:
        None. The curves are drawn on the current matplotlib figure, saved
        if requested, and then shown.
    """
    # Get the number of epochs or data points
    epochs = len(loss_results['train_loss'])
    epoch_axis = range(1, epochs + 1)

    # Plotting the train and test loss
    plt.plot(epoch_axis, loss_results['train_loss'], label='Train Loss')
    plt.plot(epoch_axis, loss_results['test_loss'], label='Test Loss')

    # Adding labels and title
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Train and Test Loss Over Epochs')

    # Adding a legend
    plt.legend()

    # Save before showing: once show() returns the figure may already be
    # closed, and saving afterwards would write an empty image.
    if save_path:
        plt.savefig(save_path, format='jpeg')

    # Show the plot
    plt.show()
def plot_accuracy(loss_results:dict, save_path:str= None):
    """Plot the train and test accuracy curves produced by the train function.

    Args:
        loss_results: A dict containing the keys 'train_acc' and 'test_acc',
            each mapping to a list with one value per epoch.
        save_path: Optional path. When given, the figure is also written
            there in JPEG format.

    Returns:
        None. The curves are drawn on the current matplotlib figure, saved
        if requested, and then shown.
    """
    # Take the epoch count from the series that is actually plotted.
    epochs = len(loss_results['train_acc'])
    epoch_axis = range(1, epochs + 1)

    # Plotting the train and test accuracy
    plt.plot(epoch_axis, loss_results['train_acc'], label='Train Accuracy')
    plt.plot(epoch_axis, loss_results['test_acc'], label='Test Accuracy')

    # Adding labels and title (these previously said "Loss" by mistake)
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Train and Test Accuracy Over Epochs')

    # Adding a legend
    plt.legend()

    # Save before showing: once show() returns the figure may already be
    # closed, and saving afterwards would write an empty image.
    if save_path:
        plt.savefig(save_path, format='jpeg')

    # Show the plot
    plt.show()