Commit
This reverts commit 720adf7.
Showing 2 changed files with 131 additions and 126 deletions.
```diff
@@ -1,135 +1,142 @@
-import math
-from typing import Tuple
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable
+from typing import Iterable
+from typing import Union
 
 import tensorflow as tf
 
+from ..utils import normalize_data_format
+from ..utils import normalize_tuple
+
 
-class AdaptiveAveragePooling2D(tf.keras.layers.Layer):
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class AdaptivePooling2D(tf.keras.layers.Layer):
+    """Parent class for 2D pooling layers with adaptive kernel size.
+
+    Implementation is based on tensorflow-addons:
+    https://github.com/tensorflow/addons/blob/v0.17.0/tensorflow_addons/layers/adaptive_pooling.py#LL157C1-L234C41
+
+    This class only exists for code reuse. It will never be an exposed API.
+
+    Args:
+        reduce_function: The reduction method to apply, e.g. `tf.reduce_max`.
+        output_size: An integer or tuple/list of 2 integers specifying
+            (pooled_rows, pooled_cols). The new size of output channels.
+        data_format: A string, one of `channels_last` (default) or
+            `channels_first`. The ordering of the dimensions in the inputs.
+            `channels_last` corresponds to inputs with shape
+            `(batch, height, width, channels)` while `channels_first`
+            corresponds to inputs with shape `(batch, channels, height, width)`.
+    """
+
     def __init__(
         self,
-        output_size: Tuple[int, int],
-        input_ordering: str = "NHWC",
-        **kwargs
+        reduce_function: Callable,
+        output_size: Union[int, Iterable[int]],
+        data_format=None,
+        **kwargs,
     ):
+        self.data_format = normalize_data_format(data_format)
+        self.reduce_function = reduce_function
+        self.output_size = normalize_tuple(output_size, 2, "output_size")
         super().__init__(**kwargs)
-        self.output_size = output_size
-        self.input_ordering = input_ordering
-        if input_ordering not in ("NCHW", "NHWC"):
-            raise ValueError(
-                "Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!"
-            )
-        self.h_axis = input_ordering.index("H")
-        self.w_axis = input_ordering.index("W")
```
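For context, the two `..utils` helpers are not shown in this diff; they were copied from tensorflow-addons alongside the layer. A minimal sketch of their assumed behavior follows (names match the imports, but the bodies here are simplified stand-ins, not the package's actual code):

```python
import tensorflow as tf

def normalize_data_format(value):
    # Fall back to the global Keras image data format when None.
    if value is None:
        return tf.keras.backend.image_data_format()
    data_format = value.lower()
    if data_format not in {"channels_first", "channels_last"}:
        raise ValueError(
            'The `data_format` argument must be one of "channels_first", '
            f'"channels_last". Received: {value}'
        )
    return data_format

def normalize_tuple(value, n, name):
    # Accept a single int or an iterable of n ints; always return a tuple.
    if isinstance(value, int):
        return (value,) * n
    return tuple(int(v) for v in value)

print(normalize_tuple(4, 2, "output_size"))  # (4, 4)
```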
```diff
-    def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool):
-        # Figure out which axis we're pooling on
-        if h_pooling:
-            axis = self.h_axis
-            output_dim = self.output_size[0]
-        else:
-            axis = self.w_axis
-            output_dim = self.output_size[1]
-        input_dim = inputs.shape[axis]
-
-        # Figure out the potential pooling windows
-        # This is the key idea - the torch op will always use only two
-        # consecutive pooling window sizes, like 3 and 4. Therefore,
-        # if we pool with both possible sizes, we simply need to gather
-        # the 'correct' pool at each position to reimplement the torch op.
-        small_window = math.ceil(input_dim / output_dim)
-        big_window = small_window + 1
-        if h_pooling:
-            output_dim = self.output_size[0]
-            small_window_shape = (small_window, 1)
-            big_window_shape = (big_window, 1)
-        else:
-            output_dim = self.output_size[1]
-            small_window_shape = (1, small_window)
-            big_window_shape = (1, big_window)
-
-        # For integer resizes, we can take a very quick shortcut
-        if input_dim % output_dim == 0:
-            return tf.nn.avg_pool2d(
-                inputs,
-                ksize=small_window_shape,
-                strides=small_window_shape,
-                padding="VALID",
-                data_format=self.input_ordering,
-            )
```
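The "key idea" in the removed comment is easy to verify by hand: for any (input_dim, output_dim) pair, the adaptive windows only ever come in two consecutive sizes, `small_window = ceil(input_dim / output_dim)` and `small_window + 1`. A standalone check in plain Python, using the same start/end formulas as the removed code:

```python
import math

def window_sizes(input_dim: int, output_dim: int):
    # Window i covers [floor(i * in / out), ceil((i + 1) * in / out)),
    # the same boundaries the removed code computes below.
    sizes = []
    for i in range(output_dim):
        start = math.floor(i * input_dim / output_dim)
        end = math.ceil((i + 1) * input_dim / output_dim)
        sizes.append(end - start)
    return sizes

print(window_sizes(7, 5))   # [2, 2, 3, 2, 2] -> only sizes 2 and 3 occur
print(window_sizes(10, 4))  # [3, 3, 3, 3]    -> integer case, single size
```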
```diff
-        # For non-integer resizes, we pool with both possible window sizes
-        # and concatenate them
-        small_pool = tf.nn.avg_pool2d(
-            inputs,
-            ksize=small_window_shape,
-            strides=1,
-            padding="VALID",
-            data_format=self.input_ordering,
-        )
-        big_pool = tf.nn.avg_pool2d(
-            inputs,
-            ksize=big_window_shape,
-            strides=1,
-            padding="VALID",
-            data_format=self.input_ordering,
-        )
-        both_pool = tf.concat([small_pool, big_pool], axis=axis)
-
-        # We compute vectors of the start and end positions
-        # for each pooling window.
-        # Each (start, end) pair here corresponds to a single output position
-        window_starts = tf.math.floor(
-            (tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim
-        )
-        window_starts = tf.cast(window_starts, tf.int64)
-        window_ends = tf.math.ceil(
-            (tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim)
-            / output_dim
-        )
-        window_ends = tf.cast(window_ends, tf.int64)
-
-        # pool_selector is a boolean array of shape (output_dim,)
-        # where 1 indicates that output position has a big receptive
-        # field and 0 indicates that output position has a small
-        # receptive field
-        pool_selector = tf.cast(
-            window_ends - window_starts - small_window, tf.bool
-        )
-
-        # Since we concatenated the small and big pools, we need to do a bit of
-        # pointer arithmetic to get the indices of the big pools
-        small_indices = window_starts
-        big_indices = window_starts + small_pool.shape[axis]
-
-        # Finally, we use the pool_selector to generate a list of indices,
-        # one per output position
-        gather_indices = tf.where(pool_selector, big_indices, small_indices)
-
-        # Gathering from those indices yields the final, correct pooling
-        return tf.gather(both_pool, gather_indices, axis=axis)
```
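Putting the pieces together, here is the same index arithmetic for a concrete 1D case (length 7 pooled to 5), in NumPy for readability; the variable names mirror the removed code, and the result matches the torch-style windows the comments describe:

```python
import math
import numpy as np

input_dim, output_dim = 7, 5
small_window = math.ceil(input_dim / output_dim)  # 2

window_starts = np.floor(np.arange(output_dim) * input_dim / output_dim).astype(int)
window_ends = np.ceil(np.arange(1, output_dim + 1) * input_dim / output_dim).astype(int)
# window_starts = [0 1 2 4 5], window_ends = [2 3 5 6 7]

# True where the window is the big (small_window + 1) one.
pool_selector = (window_ends - window_starts - small_window).astype(bool)
# [False False True False False] -> only position 2 uses a 3-wide window

x = np.arange(1.0, 8.0)                                    # [1 2 3 4 5 6 7]
small_pool = np.convolve(x, np.ones(2) / 2, mode="valid")  # 6 windows of size 2
big_pool = np.convolve(x, np.ones(3) / 3, mode="valid")    # 5 windows of size 3
both_pool = np.concatenate([small_pool, big_pool])

# Offset big-window indices past the small pool, then gather.
gather_indices = np.where(pool_selector, window_starts + len(small_pool), window_starts)
print(both_pool[gather_indices])  # [1.5 2.5 4.  5.5 6.5]
```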
```diff
-    def call(self, inputs: tf.Tensor):
-        if self.input_ordering == "NHWC":
-            input_shape = inputs.shape[1:3]
-        else:
-            input_shape = inputs.shape[2:]
-
-        if (
-            input_shape[0] % self.output_size[0] == 0
-            and input_shape[1] % self.output_size[1] == 0
-        ):
-            # If we're resizing by an integer factor on both dimensions,
-            # we can take a very quick shortcut.
-            h_resize = int(input_shape[0] // self.output_size[0])
-            w_resize = int(input_shape[1] // self.output_size[1])
-            return tf.nn.avg_pool2d(
-                inputs,
-                ksize=(h_resize, w_resize),
-                strides=(h_resize, w_resize),
-                padding="VALID",
-                data_format=self.input_ordering,
-            )
-        else:
-            # If we can't take the shortcut, we do a 1D pool on each axis
-            h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True)
-            return self.pseudo_1d_pool(h_pooled, h_pooling=False)
+    def call(self, inputs, *args):
+        h_bins = self.output_size[0]
+        w_bins = self.output_size[1]
+        if self.data_format == "channels_last":
+            split_cols = tf.split(inputs, h_bins, axis=1)
+            split_cols = tf.stack(split_cols, axis=1)
+            split_rows = tf.split(split_cols, w_bins, axis=3)
+            split_rows = tf.stack(split_rows, axis=3)
+            out_vect = self.reduce_function(split_rows, axis=[2, 4])
+        else:
+            split_cols = tf.split(inputs, h_bins, axis=2)
+            split_cols = tf.stack(split_cols, axis=2)
+            split_rows = tf.split(split_cols, w_bins, axis=4)
+            split_rows = tf.stack(split_rows, axis=4)
+            out_vect = self.reduce_function(split_rows, axis=[3, 5])
+        return out_vect
```
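The restored addons-style `call` replaces pooling ops with reshapes: each spatial axis is split into `output_size` bins, the bins are stacked on a new axis, and `reduce_function` collapses the within-bin axes. A minimal sketch of the `channels_last` branch with `tf.reduce_mean` (note that `tf.split` with an integer count requires each spatial dim to be divisible by its bin count, which the reverted torch-style implementation did not require):

```python
import tensorflow as tf

x = tf.reshape(tf.range(16, dtype=tf.float32), (1, 4, 4, 1))
h_bins = w_bins = 2

# (1, 4, 4, 1) -> (1, 2, 2, 4, 1): two row bins, each two rows tall.
split_cols = tf.stack(tf.split(x, h_bins, axis=1), axis=1)
# (1, 2, 2, 4, 1) -> (1, 2, 2, 2, 2, 1): likewise for the columns.
split_rows = tf.stack(tf.split(split_cols, w_bins, axis=3), axis=3)
# Average over the within-bin row/column axes -> (1, 2, 2, 1).
out = tf.reduce_mean(split_rows, axis=[2, 4])

print(tf.squeeze(out).numpy())
# [[ 2.5  4.5]
#  [10.5 12.5]]
```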
```diff
+
+    def compute_output_shape(self, input_shape):
+        input_shape = tf.TensorShape(input_shape).as_list()
+        if self.data_format == "channels_last":
+            shape = tf.TensorShape(
+                [
+                    input_shape[0],
+                    self.output_size[0],
+                    self.output_size[1],
+                    input_shape[3],
+                ]
+            )
+        else:
+            shape = tf.TensorShape(
+                [
+                    input_shape[0],
+                    input_shape[1],
+                    self.output_size[0],
+                    self.output_size[1],
+                ]
+            )
+        return shape
+
+    def get_config(self):
+        config = {
+            "output_size": self.output_size,
+            "data_format": self.data_format,
+        }
+        base_config = super().get_config()
+        return {**base_config, **config}
```
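Since the layer is registered via `register_keras_serializable` and `get_config` round-trips the constructor arguments, saved models restore cleanly. A quick sketch (the module path is not shown in this diff, so the class is assumed to already be in scope):

```python
# Hypothetical round trip; exact import path not shown in the diff.
layer = AdaptiveAveragePooling2D(output_size=(7, 7))
config = layer.get_config()
# config includes {"output_size": (7, 7), "data_format": "channels_last", ...}
restored = AdaptiveAveragePooling2D.from_config(config)
assert restored.output_size == layer.output_size
```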
```diff
+
+
+@tf.keras.utils.register_keras_serializable(package="gcvit")
+class AdaptiveAveragePooling2D(AdaptivePooling2D):
+    """Average Pooling with adaptive kernel size.
+
+    Class is borrowed from tensorflow-addons:
+    https://github.com/tensorflow/addons/blob/v0.17.0/tensorflow_addons/layers/adaptive_pooling.py#L238
+
+    Args:
+        output_size: Tuple of integers specifying (pooled_rows, pooled_cols).
+            The new size of output channels.
+        data_format: A string, one of `channels_last` (default) or
+            `channels_first`. The ordering of the dimensions in the inputs.
+            `channels_last` corresponds to inputs with shape
+            `(batch, height, width, channels)` while `channels_first`
+            corresponds to inputs with shape `(batch, channels, height, width)`.
+
+    Input shape:
+        - If `data_format='channels_last'`:
+            4D tensor with shape `(batch_size, height, width, channels)`.
+        - If `data_format='channels_first'`:
+            4D tensor with shape `(batch_size, channels, height, width)`.
+
+    Output shape:
+        - If `data_format='channels_last'`:
+            4D tensor with shape `(batch_size, pooled_rows, pooled_cols, channels)`.
+        - If `data_format='channels_first'`:
+            4D tensor with shape `(batch_size, channels, pooled_rows, pooled_cols)`.
+    """
+
+    def __init__(
+        self, output_size: Union[int, Iterable[int]], data_format=None, **kwargs
+    ):
+        super().__init__(tf.reduce_mean, output_size, data_format, **kwargs)
```
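Finally, a usage sketch matching the docstring's shape contract (again assuming the class is in scope; the input size is chosen divisible by `output_size`, as the split-based `call` requires):

```python
import tensorflow as tf

# Hypothetical usage; the diff does not show the module to import from.
pool = AdaptiveAveragePooling2D(output_size=7)
x = tf.random.normal((2, 28, 28, 96))  # 28 = 4 * 7, so the split is even
y = pool(x)
print(y.shape)                         # (2, 7, 7, 96)
```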