Skip to content

Commit

Permalink
Fix a bug where training gets stuck while a detection model is trained on a …
Browse files Browse the repository at this point in the history
…distributed environment (#3904)

* specify data type

* add docstring to notify need to update FMeasure for distributed training

* update label dtype from int32 to long
  • Loading branch information
eunwoosh authored Aug 29, 2024
1 parent 0a395b2 commit bc5b7d0
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/otx/core/data/dataset/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def _get_item_impl(self, index: int) -> DetDataEntity | None:
bboxes,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=img_shape,
dtype=torch.float32,
),
labels=torch.as_tensor([ann.label for ann in bbox_anns]),
labels=torch.as_tensor([ann.label for ann in bbox_anns], dtype=torch.long),
)

return self._apply_transforms(entity)
Expand Down
1 change: 1 addition & 0 deletions src/otx/core/data/dataset/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None:
bboxes,
format=tv_tensors.BoundingBoxFormat.XYXY,
canvas_size=img_shape,
dtype=torch.float32,
),
masks=tv_tensors.Mask(masks, dtype=torch.uint8),
labels=torch.as_tensor(labels),
Expand Down
2 changes: 2 additions & 0 deletions src/otx/core/metrics/fmeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,8 @@ class FMeasure(Metric):
IoU > threshold are reduced to one. This threshold can be determined automatically by setting `vary_nms_threshold`
to True.
# TODO(someone): need to update for distributed training. Refer to https://lightning.ai/docs/torchmetrics/stable/pages/implement.html
Args:
label_info (int): Dataclass including label information.
vary_nms_threshold (bool): if True the maximal F-measure is determined by optimizing for different NMS threshold
Expand Down
2 changes: 1 addition & 1 deletion src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _convert_pred_entity_to_compute_metric(
"preds": [
{
"boxes": bboxes.data,
"scores": scores,
"scores": scores.type(torch.float32),
"labels": labels,
}
for bboxes, scores, labels in zip(
Expand Down

0 comments on commit bc5b7d0

Please sign in to comment.