Skip to content

Commit

Permalink
implement test code of torch automatic_bs file
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Apr 17, 2023
1 parent e09b70e commit 37ec5d7
Showing 1 changed file with 51 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest

from otx.algorithms.common.adapters.torch.utils import adapt_batch_size
from otx.algorithms.common.adapters.torch.utils import automatic_bs


def test_adapt_batch_size(mocker):
mocker_torch = mocker.patch.object(automatic_bs, "torch")
mocker_torch.cuda.mem_get_info.return_value = (1, 10000)

def mock_train_func(batch_size):
if batch_size > 100:
mocker_torch.cuda.max_memory_allocated.return_value = 10000
raise RuntimeError("CUDA out of memory.")
elif batch_size > 80:
mocker_torch.cuda.max_memory_allocated.return_value = 9000
else:
mocker_torch.cuda.max_memory_allocated.return_value = 1000

adapted_bs = adapt_batch_size(mock_train_func, 128, 1000)

assert adapted_bs == 80


def test_adapt_batch_size_gpu_memory_too_small(mocker):
mocker_torch = mocker.patch.object(automatic_bs, "torch")
mocker_torch.cuda.mem_get_info.return_value = (1, 10000)

def mock_train_func(batch_size):
if batch_size > 4:
mocker_torch.cuda.max_memory_allocated.return_value = 10000
raise RuntimeError("CUDA out of memory.")
elif batch_size >= 2:
mocker_torch.cuda.max_memory_allocated.return_value = 9000
else:
mocker_torch.cuda.max_memory_allocated.return_value = 1000

with pytest.raises(RuntimeError):
adapt_batch_size(mock_train_func, 128, 1000)


@pytest.mark.parametrize("default_bs", [-1, 0])
def test_adapt_batch_size_wrong_default_bs(mocker, default_bs):
with pytest.raises(ValueError):
adapt_batch_size(mocker.MagicMock(), default_bs, 1000)


@pytest.mark.parametrize("trainset_size", [-1, 0])
def test_adapt_batch_size_wrong_trainset_size(mocker, trainset_size):
with pytest.raises(ValueError):
adapt_batch_size(mocker.MagicMock(), 8, trainset_size)

0 comments on commit 37ec5d7

Please sign in to comment.