Skip to content

Commit

Permalink
Merge branch 'main' into py38-compat
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored May 9, 2023
2 parents 509af04 + 2caa84f commit dc047dc
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 50 deletions.
38 changes: 2 additions & 36 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ jobs:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
repository: pytorch/vision
upload-artifact: docs
script: |
set -euo pipefail
Expand All @@ -40,42 +39,9 @@ jobs:
pip install --progress-bar=off -r requirements.txt
echo '::endgroup::'
echo '::group::Build HTML docs'
# The runner does not have sufficient memory to run with as many processes as their are
# The runner does not have sufficient memory to run with as many processes as there are
# cores (`-j auto`). Thus, we limit to a single process (`-j 1`) here.
sed -i -e 's/-j auto/-j 1/' Makefile
make html
echo '::endgroup::'
mv build/html "${RUNNER_ARTIFACT_DIR}"
upload-preview:
if: github.event_name == 'pull_request'
needs: [build]
runs-on: [self-hosted, linux.2xlarge]
steps:
- uses: actions/download-artifact@v3
with:
name: docs

- name: Upload docs preview
uses: seemethere/upload-artifact-s3@v5
with:
retention-days: 14
s3-bucket: doc-previews
if-no-files-found: error
path: html
s3-prefix: pytorch/vision/${{ github.event.pull_request.number }}

# The upload below duplicates the upload from above, but to a different path. This is needed since we are in the
# process of changing the path, but want to keep the disruption to a minimum.
# See https://github.com/pytorch/test-infra/issues/3894
# After a grace period, we can delete this again
- name: Upload docs preview
uses: seemethere/upload-artifact-s3@v5
with:
retention-days: 14
s3-bucket: doc-previews
if-no-files-found: error
path: html
s3-prefix: pytorch/pytorch/vision/${{ github.event.pull_request.number }}
mv build/html/* "${RUNNER_DOCS_DIR}"
12 changes: 2 additions & 10 deletions .github/workflows/test-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ jobs:
with:
repository: pytorch/vision
runner: ${{ matrix.runner }}
gpu-arch-type: ${{ matrix.gpu-arch-type }}
gpu-arch-version: ${{ matrix.gpu-arch-version }}
timeout: 120
script: |
set -euxo pipefail
Expand All @@ -38,14 +40,4 @@ jobs:
export GPU_ARCH_TYPE=${{ matrix.gpu-arch-type }}
export GPU_ARCH_VERSION=${{ matrix.gpu-arch-version }}
# TODO: Port this to pytorch/test-infra/.github/workflows/windows_job.yml
export PATH="/c/Jenkins/Miniconda3/Scripts:${PATH}"
if [[ $GPU_ARCH_TYPE == 'cuda' ]]; then
# TODO: This should be handled by the generic Windows job the same as its done by the generic Linux job
export CUDA_HOME="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ matrix.gpu-arch-version }}"
export CUDA_PATH="${CUDA_HOME}"
export PATH="${CUDA_PATH}/bin:${PATH}"
fi
./.github/scripts/unittest.sh
35 changes: 35 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2137,3 +2137,38 @@ def test_no_warnings_v1_namespace():
from torchvision.datasets import ImageNet
"""
assert_run_python_script(textwrap.dedent(source))


class TestLambda:
inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])

@inputs
def test_default(self, input):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

transform = transforms.Lambda(was_applied_fn)

transform(input)

assert was_applied

@inputs
def test_with_types(self, input):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

types = (torch.Tensor, np.ndarray)
transform = transforms.Lambda(was_applied_fn, *types)

transform(input)

assert was_applied is isinstance(input, types)
3 changes: 3 additions & 0 deletions test/transforms_v2_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=
old_height, old_width = spatial_size
new_height, new_width = F._geometry._compute_resized_output_size(spatial_size, size=size, max_size=max_size)

if (old_height, old_width) == (new_height, new_width):
return bounding_box, (old_height, old_width)

affine_matrix = np.array(
[
[new_width / old_width, 0, 0],
Expand Down
4 changes: 3 additions & 1 deletion torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ class Lambda(Transform):
lambd (function): Lambda/function to be used for transform.
"""

_transformed_types = (object,)

def __init__(self, lambd: Callable[[Any], Any], *types: Type):
super().__init__()
self.lambd = lambd
self.types = types or (object,)
self.types = types or self._transformed_types

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, self.types):
Expand Down
22 changes: 19 additions & 3 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ def resize_image_tensor(
num_channels, old_height, old_width = shape[-3:]
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)

if image.numel() > 0:
if (new_height, new_width) == (old_height, old_width):
return image
elif image.numel() > 0:
image = image.reshape(-1, num_channels, old_height, old_width)

dtype = image.dtype
Expand Down Expand Up @@ -210,9 +212,19 @@ def resize_image_pil(
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
old_height, old_width = image.height, image.width
new_height, new_width = _compute_resized_output_size(
(old_height, old_width),
size=size, # type: ignore[arg-type]
max_size=max_size,
)

interpolation = _check_interpolation(interpolation)
size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type]
return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])

if (new_height, new_width) == (old_height, old_width):
return image

return image.resize((new_width, new_height), resample=pil_modes_mapping[interpolation])


def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
Expand All @@ -235,6 +247,10 @@ def resize_bounding_box(
) -> Tuple[torch.Tensor, Tuple[int, int]]:
old_height, old_width = spatial_size
new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)

if (new_height, new_width) == (old_height, old_width):
return bounding_box, spatial_size

w_ratio = new_width / old_width
h_ratio = new_height / old_height
ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device)
Expand Down

0 comments on commit dc047dc

Please sign in to comment.