-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathdata_container.py
91 lines (70 loc) · 2.57 KB
/
data_container.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from typing import Callable, Type, Union
import numpy as np
import torch
def assert_tensor_type(func: Callable) -> Callable:
    """Restrict a method to owners whose ``data`` is a ``torch.Tensor``.

    The returned wrapper looks at the first positional argument (the object
    the method is called on). If its ``data`` attribute is a
    ``torch.Tensor`` the call is forwarded to ``func`` unchanged.

    Args:
        func: The method to guard.

    Returns:
        The guarded method, carrying the metadata of ``func``.

    Raises:
        AttributeError: Raised by the wrapper at call time when ``data`` is
            not a ``torch.Tensor``. The message names the owner's class,
            ``func`` and the owner's ``datatype``.
    """

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        owner = args[0]
        # Happy path first: tensor payloads simply delegate.
        if isinstance(owner.data, torch.Tensor):
            return func(*args, **kwargs)
        # Mimic a missing attribute so that tensor-only methods look absent
        # on owners that hold something else.
        raise AttributeError(
            f'{owner.__class__.__name__} has no attribute '
            f'{func.__name__} for type {owner.datatype}')

    return guarded
class DataContainer:
    """A container for any type of objects.

    Typically tensors will be stacked in the collate function and sliced along
    some dimension in the scatter function. This behavior has some limitations.

    1. All tensors have to be the same size.
    2. Types are limited (numpy array or Tensor).

    We design `DataContainer` and `MMDataParallel` to overcome these
    limitations. The behavior can be either of the following.

    - copy to GPU, pad all tensors to the same size and stack them
    - copy to GPU without stacking
    - leave the objects as is and pass it to the model

    The container itself only stores the payload and the options below and
    exposes them as read-only properties; it does not pad, stack or move
    anything on its own.

    Args:
        data: The wrapped object.
        stack: Stored and exposed as ``stack``. Presumably selects the
            "pad and stack" behavior above -- confirm against the collate
            function.
        padding_value: Stored and exposed as ``padding_value``. Presumably
            the fill value used when padding -- confirm against the collate
            function.
        cpu_only: Stored and exposed as ``cpu_only``. Presumably selects the
            "leave the objects as is" behavior above -- confirm against the
            scatter function.
        pad_dims: The number of last few dimensions to do padding. Must be
            ``None``, 1, 2 or 3.

    Raises:
        AssertionError: If ``pad_dims`` is not one of ``None``, 1, 2 or 3.
    """

    def __init__(self,
                 data: Union[torch.Tensor, np.ndarray],
                 stack: bool = False,
                 padding_value: int = 0,
                 cpu_only: bool = False,
                 pad_dims: Union[int, None] = 2):
        self._data = data
        self._cpu_only = cpu_only
        self._stack = stack
        self._padding_value = padding_value
        # Kept as an ``assert`` so the exception type stays AssertionError
        # for existing callers; the message makes the failure actionable.
        assert pad_dims in [None, 1, 2, 3], (
            f'pad_dims must be None, 1, 2 or 3, but got {pad_dims!r}')
        self._pad_dims = pad_dims

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({repr(self.data)})'

    def __len__(self) -> int:
        # Delegates to the payload, so it raises whatever ``len`` raises for
        # payloads that have no length.
        return len(self._data)

    @property
    def data(self) -> Union[torch.Tensor, np.ndarray]:
        """The wrapped object, returned as stored."""
        return self._data

    @property
    def datatype(self) -> Union[Type, str]:
        """Type of the payload.

        For a ``torch.Tensor`` this is the result of ``Tensor.type()``; for
        anything else it is the Python ``type`` of the payload.
        """
        if isinstance(self.data, torch.Tensor):
            return self.data.type()
        else:
            return type(self.data)

    @property
    def cpu_only(self) -> bool:
        """The ``cpu_only`` flag passed at construction."""
        return self._cpu_only

    @property
    def stack(self) -> bool:
        """The ``stack`` flag passed at construction."""
        return self._stack

    @property
    def padding_value(self) -> int:
        """The ``padding_value`` passed at construction."""
        return self._padding_value

    @property
    def pad_dims(self) -> Union[int, None]:
        """The ``pad_dims`` passed at construction (``None``, 1, 2 or 3)."""
        return self._pad_dims

    @assert_tensor_type
    def size(self, *args, **kwargs) -> torch.Size:
        """Forward to ``Tensor.size``; tensor payloads only."""
        return self.data.size(*args, **kwargs)

    @assert_tensor_type
    def dim(self) -> int:
        """Forward to ``Tensor.dim``; tensor payloads only."""
        return self.data.dim()