Skip to content

Commit

Permalink
update ceil_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Jan 23, 2025
1 parent 1ec9cad commit c84a858
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
26 changes: 17 additions & 9 deletions onnxruntime/core/providers/cpu/nn/pool_attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,30 +150,30 @@ struct PoolAttributes {
case AutoPadType::VALID:
*pad_head = 0;
*pad_tail = 0;
*out_size = ComputeOutputSize(in_size, stride, kernel, 0, dilation);
*out_size = ComputeOutputSize(in_size, stride, kernel, 0, 0, dilation);
break;
case AutoPadType::SAME_LOWER: {
int64_t legacy_target_size = (in_size + stride - 1) / stride;
int64_t pad_needed = (legacy_target_size - 1) * stride + kernel - in_size;
*pad_head = (pad_needed + 1) / 2;
*pad_tail = pad_needed - *pad_head;
*out_size = ComputeOutputSize(in_size, stride, kernel, pad_needed, dilation);
*out_size = ComputeOutputSize(in_size, stride, kernel, *pad_head, *pad_tail, dilation);
break;
}
case AutoPadType::SAME_UPPER: {
int64_t legacy_target_size = (in_size + stride - 1) / stride;
int64_t pad_needed = (legacy_target_size - 1) * stride + kernel - in_size;
*pad_head = pad_needed / 2;
*pad_tail = pad_needed - *pad_head;
*out_size = ComputeOutputSize(in_size, stride, kernel, pad_needed, dilation);
*out_size = ComputeOutputSize(in_size, stride, kernel, *pad_head, *pad_tail, dilation);
break;
}
default: {
ORT_THROW("Unsupported AutoPad Type.");
}
}
} else {
*out_size = ComputeOutputSize(in_size, stride, kernel, *pad_head + *pad_tail, dilation);
*out_size = ComputeOutputSize(in_size, stride, kernel, *pad_head, *pad_tail, dilation);
}
}
#if defined(_MSC_VER) && !defined(__clang__)
Expand All @@ -184,13 +184,21 @@ struct PoolAttributes {
int64_t ComputeOutputSize(int64_t in_size,
int64_t stride,
int64_t kernel,
int64_t pad_needed,
int64_t pad_head,
int64_t pad_tail,
int64_t dilation) const {
if (ceil_mode == 0) {
return static_cast<int64_t>(static_cast<float>(in_size + pad_needed - dilation * (kernel - 1) - 1) / stride + 1);
int64_t numerator = in_size + pad_head + pad_tail - dilation * (kernel - 1) - 1;
int64_t out_size = numerator / stride + 1;

if (ceil_mode == 1) {
out_size = static_cast<int64_t>(std::ceil(static_cast<float>(numerator) / stride)) + 1;
// Ensure that the last pooling starts inside the image (at least 1 pixel)
// Reference: https://github.com/onnx/onnx/pull/5741
if ((out_size - 1) * stride >= in_size + pad_head) {
--out_size;
}
}
return static_cast<int64_t>(
std::ceil(static_cast<float>(in_size + pad_needed - dilation * (kernel - 1) - 1) / stride + 1));
return out_size;
}
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/onnx/TestCase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,7 @@ std::unique_ptr<std::set<BrokenTest>> GetBrokenTests(const std::string& provider
"output=Y:expected 1 (3f800000), got 4 (40800000), diff: 3, tol=0.002 idx=24. 13 of 49 differ. CPU test passed."});
broken_tests->insert({"convtranspose_group_2", "Segmentation fault (core dumped). CPU test passed."});
broken_tests->insert({"convtranspose_group_2_image_3", "Segmentation fault (core dumped). CPU test passed."});
broken_tests->insert({"averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True", "TODO: enable this in the next ONNX release."});
}

#ifdef DISABLE_CONTRIB_OPS
Expand Down

0 comments on commit c84a858

Please sign in to comment.