Skip to content

Commit

Permalink
fix loss
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene123tw committed Apr 8, 2024
1 parent 727cf2b commit 22f8f81
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions src/otx/algo/detection/losses/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@

from torch import nn


if TYPE_CHECKING:
from torch import Tensor


def accuracy(pred: Tensor, target: Tensor, topk: int = 1, thresh: float | None = None):
def accuracy(
pred: Tensor,
target: Tensor,
topk: int | tuple[int] = 1,
thresh: float | None = None,
) -> list:
"""Calculate accuracy according to the prediction and target.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/otx/algo/detection/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from .accuracy import accuracy
from .utils import weight_reduce_loss


if TYPE_CHECKING:
from torch import Tensor

Expand Down Expand Up @@ -79,7 +78,8 @@ def py_focal_loss_with_prob(
reduction: str = "mean",
avg_factor: int | None = None,
):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`.
Different from `py_sigmoid_focal_loss`, this function accepts probability
as input.
Expand Down

0 comments on commit 22f8f81

Please sign in to comment.