Skip to content

Commit

Permalink
Fix num_trials calculation on dataset length less than num_class (#4014)
Browse files Browse the repository at this point in the history
Fix balanced sampler
  • Loading branch information
harimkang authored Oct 11, 2024
1 parent 7040faf commit 81829a3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/otx/algo/samplers/balanced_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0}
self.num_cls = len(self.img_indices.keys())
self.data_length = len(self.dataset)
self.num_trials = int(self.data_length / self.num_cls)
self.num_trials = max(int(self.data_length / self.num_cls), 1)

if efficient_mode:
# Reduce the # of sampling (sampling data for a single epoch)
Expand Down

0 comments on commit 81829a3

Please sign in to comment.