From 308be34798920973b8cc800c400a374ac5a31da5 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Mon, 6 May 2024 11:33:24 +0530 Subject: [PATCH] Improvements --- velox/common/encode/Base32.cpp | 224 +++++++++++------------ velox/common/encode/tests/Base32Test.cpp | 50 +++-- velox/common/encode/tests/Base64Test.cpp | 28 +-- 3 files changed, 149 insertions(+), 153 deletions(-) diff --git a/velox/common/encode/Base32.cpp b/velox/common/encode/Base32.cpp index dcbfd46725c5c..39d7ccb8494f5 100644 --- a/velox/common/encode/Base32.cpp +++ b/velox/common/encode/Base32.cpp @@ -180,117 +180,117 @@ template } size_t Base32::calculateDecodedSize(const char* data, size_t& size) { - if (size == 0) { - return 0; - } - - // Check if the input data is padded - if (isPadded(data, size)) { - /// If padded, ensure that the string length is a multiple of the encoded - /// block size. - if (size % kEncodedBlockSize != 0) { - throw EncoderException( - "Base32::decode() - invalid input string: " - "string length is not a multiple of 8."); - } - - auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize; - auto padding = countPadding(data, size); - size -= padding; - - // Adjust the needed size for padding. - return needed - - ceil((padding * kBinaryBlockSize) / - static_cast(kEncodedBlockSize)); - } else { - // If not padded, calculate extra bytes, if any. - auto extra = size % kEncodedBlockSize; - auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize; - - // Adjust the needed size for extra bytes, if present. - if (extra) { - if ((extra == 6) || (extra == 3) || (extra == 1)) { - throw EncoderException( - "Base32::decode() - invalid input string: " - "string length cannot be 6, 3 or 1 more than a multiple of 8."); - } - needed += (extra * kBinaryBlockSize) / kEncodedBlockSize; - } - - return needed; - } - } - - size_t - Base32::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { - return decodeImpl(src, src_len, dst, dst_len, kBase32ReverseIndexTable); - } - - size_t Base32::decodeImpl( - const char* src, - size_t src_len, - char* dst, - size_t dst_len, - const ReverseIndex& reverse_lookup) { - if (!src_len) { - return 0; - } - - auto needed = calculateDecodedSize(src, src_len); - if (dst_len < needed) { - throw EncoderException( - "Base32::decode() - invalid output string: " - "output string is too small."); - } - - // Handle full groups of 8 characters. - for (; src_len > 8; src_len -= 8, src += 8, dst += 5) { - /// Each character of the 8 bytes encode 5 bits of the original, grab each - /// with the appropriate shifts to rebuild the original and then split that - /// back into the original 8 bit bytes. - uint64_t last = - (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) | - (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30) | - (baseReverseLookup(kBase, src[2], reverse_lookup) << 25) | - (baseReverseLookup(kBase, src[3], reverse_lookup) << 20) | - (baseReverseLookup(kBase, src[4], reverse_lookup) << 15) | - (baseReverseLookup(kBase, src[5], reverse_lookup) << 10) | - (baseReverseLookup(kBase, src[6], reverse_lookup) << 5) | - baseReverseLookup(kBase, src[7], reverse_lookup); - dst[0] = (last >> 32) & 0xff; - dst[1] = (last >> 24) & 0xff; - dst[2] = (last >> 16) & 0xff; - dst[3] = (last >> 8) & 0xff; - dst[4] = last & 0xff; - } - - /// Handle the last 2, 4, 5, 7 or 8 characters. This is similar to the above, - /// but the last characters may or may not exist. - DCHECK(src_len >= 2); - uint64_t last = - (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) | - (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30); - dst[0] = (last >> 32) & 0xff; - if (src_len > 2) { - last |= baseReverseLookup(kBase, src[2], reverse_lookup) << 25; - last |= baseReverseLookup(kBase, src[3], reverse_lookup) << 20; - dst[1] = (last >> 24) & 0xff; - if (src_len > 4) { - last |= baseReverseLookup(kBase, src[4], reverse_lookup) << 15; - dst[2] = (last >> 16) & 0xff; - if (src_len > 5) { - last |= baseReverseLookup(kBase, src[5], reverse_lookup) << 10; - last |= baseReverseLookup(kBase, src[6], reverse_lookup) << 5; - dst[3] = (last >> 8) & 0xff; - if (src_len > 7) { - last |= baseReverseLookup(kBase, src[7], reverse_lookup); - dst[4] = last & 0xff; - } - } - } - } - - return needed; - } + if (size == 0) { + return 0; + } + + // Check if the input data is padded + if (isPadded(data, size)) { + /// If padded, ensure that the string length is a multiple of the encoded + /// block size. + if (size % kEncodedBlockSize != 0) { + throw EncoderException( + "Base32::decode() - invalid input string: " + "string length is not a multiple of 8."); + } + + auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize; + auto padding = countPadding(data, size); + size -= padding; + + // Adjust the needed size for padding. + return needed - + ceil((padding * kBinaryBlockSize) / + static_cast(kEncodedBlockSize)); + } else { + // If not padded, calculate extra bytes, if any. + auto extra = size % kEncodedBlockSize; + auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize; + + // Adjust the needed size for extra bytes, if present. + if (extra) { + if ((extra == 6) || (extra == 3) || (extra == 1)) { + throw EncoderException( + "Base32::decode() - invalid input string: " + "string length cannot be 6, 3 or 1 more than a multiple of 8."); + } + needed += (extra * kBinaryBlockSize) / kEncodedBlockSize; + } + + return needed; + } +} + +size_t +Base32::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { + return decodeImpl(src, src_len, dst, dst_len, kBase32ReverseIndexTable); +} + +size_t Base32::decodeImpl( + const char* src, + size_t src_len, + char* dst, + size_t dst_len, + const ReverseIndex& reverse_lookup) { + if (!src_len) { + return 0; + } + + auto needed = calculateDecodedSize(src, src_len); + if (dst_len < needed) { + throw EncoderException( + "Base32::decode() - invalid output string: " + "output string is too small."); + } + + // Handle full groups of 8 characters. + for (; src_len > 8; src_len -= 8, src += 8, dst += 5) { + /// Each character of the 8 bytes encode 5 bits of the original, grab each + /// with the appropriate shifts to rebuild the original and then split that + /// back into the original 8 bit bytes. + uint64_t last = + (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) | + (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30) | + (baseReverseLookup(kBase, src[2], reverse_lookup) << 25) | + (baseReverseLookup(kBase, src[3], reverse_lookup) << 20) | + (baseReverseLookup(kBase, src[4], reverse_lookup) << 15) | + (baseReverseLookup(kBase, src[5], reverse_lookup) << 10) | + (baseReverseLookup(kBase, src[6], reverse_lookup) << 5) | + baseReverseLookup(kBase, src[7], reverse_lookup); + dst[0] = (last >> 32) & 0xff; + dst[1] = (last >> 24) & 0xff; + dst[2] = (last >> 16) & 0xff; + dst[3] = (last >> 8) & 0xff; + dst[4] = last & 0xff; + } + + /// Handle the last 2, 4, 5, 7 or 8 characters. This is similar to the above, + /// but the last characters may or may not exist. + DCHECK(src_len >= 2); + uint64_t last = + (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) | + (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30); + dst[0] = (last >> 32) & 0xff; + if (src_len > 2) { + last |= baseReverseLookup(kBase, src[2], reverse_lookup) << 25; + last |= baseReverseLookup(kBase, src[3], reverse_lookup) << 20; + dst[1] = (last >> 24) & 0xff; + if (src_len > 4) { + last |= baseReverseLookup(kBase, src[4], reverse_lookup) << 15; + dst[2] = (last >> 16) & 0xff; + if (src_len > 5) { + last |= baseReverseLookup(kBase, src[5], reverse_lookup) << 10; + last |= baseReverseLookup(kBase, src[6], reverse_lookup) << 5; + dst[3] = (last >> 8) & 0xff; + if (src_len > 7) { + last |= baseReverseLookup(kBase, src[7], reverse_lookup); + dst[4] = last & 0xff; + } + } + } + } + + return needed; +} } // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/Base32Test.cpp b/velox/common/encode/tests/Base32Test.cpp index bf1973c56a39a..ee711e11dd1f1 100644 --- a/velox/common/encode/tests/Base32Test.cpp +++ b/velox/common/encode/tests/Base32Test.cpp @@ -40,39 +40,35 @@ TEST_F(Base32Test, calculateEncodedSizeProperSize) { } TEST_F(Base32Test, calculateDecodedSizeProperSize) { - size_t encoded_size{0}; + struct TestCase { + std::string encoded; + size_t initial_size; + int expected_decoded; + size_t expected_size; + }; - encoded_size = 8; - EXPECT_EQ(1, Base32::calculateDecodedSize("ME======", encoded_size)); - EXPECT_EQ(2, encoded_size); + std::vector test_cases = { + {"ME======", 8, 1, 2}, + {"ME", 2, 1, 2}, + {"MFRA====", 8, 2, 4}, + {"MFRGG===", 8, 3, 5}, + {"NBSWY3DPEB3W64TMMQ======", 24, 11, 18}, + {"NBSWY3DPEB3W64TMMQ", 18, 11, 18}}; - encoded_size = 2; - EXPECT_EQ(1, Base32::calculateDecodedSize("ME", encoded_size)); - EXPECT_EQ(2, encoded_size); + for (const auto& test : test_cases) { + size_t encoded_size = test.initial_size; + EXPECT_EQ( + test.expected_decoded, + Base32::calculateDecodedSize(test.encoded.c_str(), encoded_size)); + EXPECT_EQ(test.expected_size, encoded_size); + } +} - encoded_size = 9; +TEST_F(Base32Test, errorWhenDecodedStringPartiallyPadded) { + size_t encoded_size = 9; EXPECT_THROW( Base32::calculateDecodedSize("MFRA====", encoded_size), facebook::velox::encoding::EncoderException); - - encoded_size = 8; - EXPECT_EQ(2, Base32::calculateDecodedSize("MFRA====", encoded_size)); - EXPECT_EQ(4, encoded_size); - - encoded_size = 8; - EXPECT_EQ(3, Base32::calculateDecodedSize("MFRGG===", encoded_size)); - EXPECT_EQ(5, encoded_size); - - encoded_size = 24; - EXPECT_EQ( - 11, - Base32::calculateDecodedSize("NBSWY3DPEB3W64TMMQ======", encoded_size)); - EXPECT_EQ(18, encoded_size); - - encoded_size = 18; - EXPECT_EQ( - 11, Base32::calculateDecodedSize("NBSWY3DPEB3W64TMMQ", encoded_size)); - EXPECT_EQ(18, encoded_size); } } // namespace facebook::velox::encoding \ No newline at end of file diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 998b4d92540bf..4192431539af1 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -45,22 +45,22 @@ TEST_F(Base64Test, fromBase64) { EXPECT_EQ("1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA"))); } -struct TestCase { - std::string inputBase64; - size_t initialEncodedSize; - size_t expectedDecodedSize; - size_t expectedEncodedSizeAfter; -}; +TEST_F(Base64Test, calculateDecodedSizeProperSize) { + struct TestCase { + std::string inputBase64; + size_t initialEncodedSize; + size_t expectedDecodedSize; + size_t expectedEncodedSizeAfter; + }; -std::vector testCases{ - {"SGVsbG8sIFdvcmxkIQ==", 20, 13, 18}, - {"SGVsbG8sIFdvcmxkIQ", 18, 13, 18}, - {"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 32, 23, 31}, - {"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 23, 31}, - {"MTIzNDU2Nzg5MA==", 16, 10, 14}, - {"MTIzNDU2Nzg5MA", 14, 10, 14}}; + std::vector testCases{ + {"SGVsbG8sIFdvcmxkIQ==", 20, 13, 18}, + {"SGVsbG8sIFdvcmxkIQ", 18, 13, 18}, + {"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 32, 23, 31}, + {"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 23, 31}, + {"MTIzNDU2Nzg5MA==", 16, 10, 14}, + {"MTIzNDU2Nzg5MA", 14, 10, 14}}; -TEST_F(Base64Test, calculateDecodedSizeProperSize) { for (const auto& testCase : testCases) { size_t encodedSize = testCase.initialEncodedSize; size_t decodedSize =