-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
110 lines (85 loc) · 3.65 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
def pixel_wise_accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
    """
    Calculate pixel-wise accuracy.

    Args:
        pred (Tensor): Output (logits) of the model, shape
            (batch_size, num_classes, width, height).
        target (Tensor): Ground-truth trimap with class values starting at 1,
            shape (batch_size, width, height).

    Returns:
        float: Fraction of pixels whose predicted class matches the target.
    """
    # Most likely class per pixel; `dim=` is the idiomatic torch keyword
    # (the original used the numpy alias `axis=`).
    pred_tri = torch.argmax(pred, dim=1)
    # Shift labels from 1-based (1..num_classes) to 0-based to match argmax output.
    target = target - 1
    accuracy = torch.sum(pred_tri == target) / pred_tri.numel()
    return accuracy.item()
def iou_score(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> float:
    """
    Calculate the soft Intersection over Union (IoU) score.

    "Soft" means the IoU is computed from softmax probabilities rather than
    hard argmax predictions, so the score is differentiable-friendly.

    Args:
        pred (torch.Tensor): Model output logits. Shape: (batch_size, num_classes, width, height).
        target (torch.Tensor): Ground-truth labels per pixel, values starting
            at 1. Shape: (batch_size, width, height).
        smooth (float): A small value to prevent division by zero. Default to `1e-6`.

    Returns:
        float: The average IoU score across all classes and all batches.

    Raises:
        ValueError: If `target` contains class indices outside 1..num_classes.
    """
    _, num_classes, _, _ = pred.shape
    target = target - 1  # Convert values from 1-3 to 0-2, assuming target values start from 1
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # so input validation must not rely on them.
    if target.min() < 0 or target.max() >= num_classes:
        raise ValueError('Target contains invalid class indices')
    # Convert logits to probabilities
    pred = torch.nn.functional.softmax(pred, dim=1)
    # Convert targets to one-hot encoding
    # Shape change: (batch_size, width, height) -> (batch_size, num_classes, width, height)
    target_one_hot = torch.nn.functional.one_hot(
        target, num_classes).permute(0, 3, 1, 2)
    target_one_hot = target_one_hot.type_as(pred)
    # Calculate intersection and union for each batch and class
    # (sum over width and height).
    intersection = torch.sum(pred * target_one_hot, dim=(2, 3))
    # Subtract intersection once so the overlap is not double-counted.
    union = torch.sum(pred + target_one_hot, dim=(2, 3)) - intersection
    # Smoothing guards against 0/0 for classes absent from both pred and target.
    iou = (intersection + smooth) / (union + smooth)
    # Average over all classes and batches
    return iou.mean().item()
def evaluate_model_performance(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    mask: torch.Tensor,
    model_description: str
) -> None:
    """
    Evaluate the model on given dataloader to compute accuracy and IoU score.

    Args:
        model (nn.Module): The PyTorch model to evaluate.
        dataloader (DataLoader): The DataLoader containing the test dataset.
        device (torch.device): The device on which the computations are performed.
        mask (Tensor): The mask tensor passed to the model alongside the inputs.
        model_description (str): Description of the model phase for output clarity.

    Returns:
        None: Prints the accuracy and IoU directly.
    """
    model.eval()  # Set the model to evaluation mode (dropout/batch-norm behave deterministically)
    acc = 0.0
    iou_total = 0.0
    # no_grad: gradients are never needed for evaluation, so skip building the
    # autograd graph — same results, much lower memory use.
    with torch.no_grad():
        for x, y in dataloader:
            inputs, targets = x.to(device), y.to(device)
            preds = model(inputs, mask)
            batch_size = preds.shape[0]
            # Reshape flat model output to (batch, num_classes, H, W); spatial
            # dims come from the targets instead of being hard-coded to 224x224,
            # so the function generalizes to other image sizes.
            # NOTE(review): num_classes is still fixed at 3 (trimap) — confirm.
            preds = preds.reshape(
                batch_size, 3, targets.shape[-2], targets.shape[-1])
            # Calculate accuracy
            acc += pixel_wise_accuracy(preds, targets)
            # Calculate IoU
            iou_total += iou_score(preds, targets)
    # Calculate the average scores over all batches
    accuracy = acc / len(dataloader)
    average_iou = iou_total / len(dataloader)
    # Print the results
    print(
        f'Accuracy of {model_description} on the fine-tuning test dataset: {accuracy:.2f}')
    print(
        f'IoU score of {model_description} on the fine-tuning test dataset: {average_iou:.2f}')