Tensor Iterator loop unrolling (#17667)
Summary:
Modified the Tensor Iterator GPU reduction kernel.
Creating multiple accumulators during the thread reduce removes the data dependency
between unrolled loop iterations and exposes instruction-level parallelism, which benefits
latency-bound kernels (e.g. the Welford reduction used by `torch.std`).
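As a rough illustration of the idea (a minimal standalone sketch, not the PyTorch kernel; `thread_sum_multi_acc` and `sum_kernel` are made-up names), a strided per-thread reduction that keeps `vt` independent accumulators lets the unrolled loads and adds overlap, whereas a single accumulator serializes every iteration on the previous one:

```cuda
// Minimal sketch (assumed names, not the PyTorch API): per-thread strided sum
// with vt independent accumulators to break the serial dependency chain.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <int vt, typename T>
__device__ T thread_sum_multi_acc(const T* data, int begin, int end, int stride) {
  T acc[vt];
  #pragma unroll
  for (int i = 0; i < vt; i++) acc[i] = T(0);

  int idx = begin;
  // Main loop: vt loads and vt adds per iteration; each add depends only on
  // its own accumulator, so the scheduler can overlap the unrolled iterations.
  while (idx + (vt - 1) * stride < end) {
    #pragma unroll
    for (int i = 0; i < vt; i++) {
      acc[i] += data[idx + i * stride];
    }
    idx += vt * stride;
  }
  // Tail: leftover elements go into the first accumulator.
  for (; idx < end; idx += stride) {
    acc[0] += data[idx];
  }
  // Final combine of the partial accumulators.
  #pragma unroll
  for (int i = 1; i < vt; i++) acc[0] += acc[i];
  return acc[0];
}

__global__ void sum_kernel(const float* data, float* out, int n) {
  // One partial sum per thread; block-level reduction is elided for brevity.
  float v = thread_sum_multi_acc<4, float>(data, threadIdx.x, n, blockDim.x);
  atomicAdd(out, v);
}

int main() {
  const int n = 1 << 20;
  std::vector<float> h(n, 1.0f);
  float *d_data = nullptr, *d_out = nullptr;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_data, h.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, sizeof(float));
  sum_kernel<<<1, 256>>>(d_data, d_out, n);
  float result = 0.0f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.0f (expected %d)\n", result, n);
  cudaFree(d_data);
  cudaFree(d_out);
  return 0;
}
```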

This approach increases register usage, so the unrolling factor needs to be tuned
to prevent register spilling.
The current implementation lowers the unrolling factor to 2 for the Welford kernel
(which is register-heavy) while keeping it at 4 for the rest of the reduction kernels.
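To see why Welford is the register-heavy case, here is a hedged sketch (illustrative names, not the PyTorch `WelfordOps`) of a per-thread Welford state used in the same multi-accumulator loop shape as above: each extra accumulator keeps three more values (mean, M2, count) live, so unrolling by 4 costs noticeably more live state than a plain sum, and dropping to 2 trades some instruction-level parallelism for fewer spills.

```cuda
// Sketch only (illustrative names): a Welford-style accumulator carries three
// values, so each additional unrolled copy costs more registers than a plain
// scalar accumulator does.
struct WelfordState {
  float mean = 0.0f;
  float m2 = 0.0f;
  float n = 0.0f;
};

// Online update with one new sample.
__device__ inline WelfordState welford_reduce(WelfordState s, float x) {
  s.n += 1.0f;
  float delta = x - s.mean;
  s.mean += delta / s.n;
  s.m2 += delta * (x - s.mean);
  return s;
}

// Merge two partial states (Chan et al. parallel variance formula).
__device__ inline WelfordState welford_combine(WelfordState a, WelfordState b) {
  if (b.n == 0.0f) return a;
  if (a.n == 0.0f) return b;
  WelfordState c;
  c.n = a.n + b.n;
  float delta = b.mean - a.mean;
  c.mean = a.mean + delta * (b.n / c.n);
  c.m2 = a.m2 + b.m2 + delta * delta * (a.n * b.n / c.n);
  return c;
}

// Same loop shape as the sum sketch above, but with vt Welford states live at
// once; the commit lowers the unroll factor to 2 for the real Welford kernel
// to keep register pressure down.
template <int vt>
__device__ WelfordState thread_welford(const float* data, int begin, int end, int stride) {
  WelfordState acc[vt];
  int idx = begin;
  while (idx + (vt - 1) * stride < end) {
    #pragma unroll
    for (int i = 0; i < vt; i++) {
      acc[i] = welford_reduce(acc[i], data[idx + i * stride]);
    }
    idx += vt * stride;
  }
  for (; idx < end; idx += stride) {
    acc[0] = welford_reduce(acc[0], data[idx]);
  }
  #pragma unroll
  for (int i = 1; i < vt; i++) {
    acc[0] = welford_combine(acc[0], acc[i]);
  }
  return acc[0];
}
```

Compiling with `nvcc -Xptxas -v` reports per-kernel register counts and spill bytes, which is one way a choice like 2 vs. 4 can be checked.
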
Pull Request resolved: pytorch/pytorch#17667

Differential Revision: D14368325

Pulled By: umanwizard

fbshipit-source-id: 9d64c0dccabdb1b7c3922a6557224af704a1974e
jjsjann123 authored and facebook-github-bot committed Mar 14, 2019
1 parent 135d683 commit 6343113
Showing 2 changed files with 43 additions and 50 deletions.
85 changes: 37 additions & 48 deletions aten/src/ATen/native/cuda/Reduce.cuh
@@ -235,20 +235,6 @@ __device__ void strided_iterate(func_t f, index_t begin, index_t end, index_t st
}
}

template <int vt, typename index_t, typename type_t, typename foo_t>
__device__ Array<type_t, vt> load_memory(const type_t* in, index_t begin, index_t end, index_t stride, foo_t foo) {
Array<type_t, vt> res;
strided_iterate<vt>([&](index_t i, index_t idx) {
res[i] = in[foo(idx)];
}, begin, end, stride);
return res;
}

template <int vt, typename index_t, typename type_t>
__device__ Array<type_t, vt> load_memory(const type_t* in, index_t begin, index_t end, index_t stride) {
return load_memory<vt, index_t>(in, begin, end, stride, [](index_t idx) { return idx; });
}

template <typename out_scalar_t, typename func_t>
struct func_wrapper_t {
using arg_t = typename binary_function_traits<func_t>::arg2_t;
@@ -271,15 +257,14 @@ func_wrapper_t<scalar_t, func_t> func_wrapper(const func_t& op) {
return func_wrapper_t<scalar_t, func_t> { op };
}

template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t>
template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t, int vt0=4>
struct ReduceOp {
using traits = binary_function_traits<decltype(&ops_t::reduce)>;
using arg_t = typename std::remove_const<typename std::remove_reference<typename traits::arg1_t>::type>::type;

using InputCalculator = OffsetCalculator<1, index_t>;
using OutputCalculator = OffsetCalculator<2, index_t>;

static constexpr int vt0 = 4;
static constexpr bool can_accumulate_in_output =
std::is_convertible<arg_t, out_scalar_t>::value
&& std::is_convertible<out_scalar_t, arg_t>::value;
@@ -365,40 +350,44 @@ struct ReduceOp {
}
}

C10_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, index_t offset) const {
index_t end = config.num_inputs;
index_t stride = input_calc.strides_[0][0] / sizeof(scalar_t);
if (input_calc.dims == 1) {
return load_memory<vt0, index_t>(data, offset, end, config.step_input, [&](index_t idx) {
return idx * stride;
});
} else {
return load_memory<vt0, index_t>(data, offset, end, config.step_input, [&](index_t idx) {
return input_calc.get(idx)[0] / sizeof(scalar_t);
});
}
}

C10_DEVICE arg_t thread_reduce_once(const scalar_t* data, index_t offset) const {
auto values = load_inputs(data, offset);

arg_t value = ident;
strided_iterate<vt0, index_t>([&](index_t i, index_t idx) {
value = ops.reduce(value, values[i]);
}, offset, config.num_inputs, config.step_input);

return value;
}

C10_DEVICE arg_t thread_reduce(const scalar_t* data) const {
arg_t value = ident;
index_t idx = config.input_idx();
// Multiple accumulators to remove dependency between unrolled loops.
arg_t value_list[vt0];
#pragma unroll
for (int i = 0; i < vt0; i++) {
value_list[i] = ident;
}
index_t end = config.num_inputs;
index_t stride = config.step_input;
index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);

// Reducing layers of function calls so compiler could do proper loop unroll
// that exposes instruction level parallelism.
while (idx < config.num_inputs) {
arg_t next = thread_reduce_once(data, idx);
value = ops.combine(value, next);
// load input
Array<scalar_t, vt0> values;
if (input_calc.dims == 1) {
strided_iterate<vt0>([&](index_t i, index_t idx) {
values[i] = data[idx * element_stride];
}, idx, end, stride);
} else {
strided_iterate<vt0>([&](index_t i, index_t idx) {
values[i] = data[input_calc.get(idx)[0] / sizeof(scalar_t)];
}, idx, end, stride);
}
// compute
strided_iterate<vt0, index_t>([&](index_t i, index_t idx) {
value_list[i] = ops.reduce(value_list[i], values[i]);
}, idx, config.num_inputs, config.step_input);
// step offset
idx += config.step_input * vt0;
}
return value;
#pragma unroll
for (int i = 1; i < vt0; i++) {
value_list[0] = ops.combine(value_list[0], value_list[i]);
}
return value_list[0];
}

C10_DEVICE arg_t block_x_reduce(arg_t value, char* shared_memory) const {
@@ -598,7 +587,7 @@ struct AccumulationBuffer {
at::DataPtr buffer_;
};

template <typename scalar_t, typename out_scalar_t, typename ops_t, typename ident_t=double>
template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
AccumulationBuffer* acc_buf_ptr=nullptr) {
AT_ASSERT(iter.numel() > 0 && iter.ntensors() == 2);
@@ -633,7 +622,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id

if (!can_use_32bit_indexing) {
for (auto& sub_iter : iter.with_32bit_indexing()) {
gpu_reduce_kernel<scalar_t, out_scalar_t>(sub_iter, ops, ident, acc_buf_ptr);
gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident, acc_buf_ptr);
}
return;
}
@@ -710,7 +699,7 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
AT_ASSERT(can_use_32bit_indexing);
auto output_calc = make_output_calculator<uint32_t>(iter);
auto input_calc = make_input_calculator<uint32_t>(iter);
auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t>(
auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t, vt0>(
ops,
config,
input_calc,
8 changes: 6 additions & 2 deletions aten/src/ATen/native/cuda/ReduceOpsKernel.cu
@@ -25,12 +25,16 @@ void sum_kernel_impl(TensorIterator& iter) {

template <typename scalar_t>
void std_var_kernel_impl(TensorIterator& iter, bool unbiased, bool take_sqrt) {
gpu_reduce_kernel<scalar_t, scalar_t>(iter, WelfordOps<scalar_t, scalar_t, int32_t, float> { unbiased, take_sqrt }, WelfordData<scalar_t, int32_t, float> {});
// reducing unrolling factor to 2 for welford kernel
// This is necessary to lower register usage that leads to register spills.
gpu_reduce_kernel<scalar_t, scalar_t, 2>(iter, WelfordOps<scalar_t, scalar_t, int32_t, float> { unbiased, take_sqrt }, WelfordData<scalar_t, int32_t, float> {});
}

template <>
void std_var_kernel_impl<at::Half>(TensorIterator& iter, bool unbiased, bool take_sqrt) {
gpu_reduce_kernel<at::Half, at::Half>(iter, WelfordOps<at::Half, float, int32_t, float> { unbiased, take_sqrt }, WelfordData<float, int32_t, float> {});
// reducing unrolling factor to 2 for welford kernel
// This is necessary to lower register usage that leads to register spills.
gpu_reduce_kernel<at::Half, at::Half, 2>(iter, WelfordOps<at::Half, float, int32_t, float> { unbiased, take_sqrt }, WelfordData<float, int32_t, float> {});
}

template <typename scalar_t, typename acc_t=scalar_t>
