Skip to content

Commit

Permalink
ARROW-3707: [C++] Fix test regression with zstd 1.3.7
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Nov 7, 2018
1 parent c303cc9 commit 8a2488d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
56 changes: 31 additions & 25 deletions cpp/src/arrow/util/compression_zstd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <zstd.h>

#include "arrow/status.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"

using std::size_t;
Expand All @@ -34,6 +35,12 @@ namespace util {
// XXX level = 1 probably doesn't compress very much
constexpr int kZSTDDefaultCompressionLevel = 1;

static Status ZSTDError(size_t ret, const char* prefix_msg) {
std::stringstream ss;
ss << prefix_msg << ZSTD_getErrorName(ret);
return Status::IOError(ss.str());
}

// ----------------------------------------------------------------------
// ZSTD decompressor implementation

Expand All @@ -47,7 +54,7 @@ class ZSTDDecompressor : public Decompressor {
finished_ = false;
size_t ret = ZSTD_initDStream(stream_);
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "zstd init failed: ");
return ZSTDError(ret, "ZSTD init failed: ");
} else {
return Status::OK();
}
Expand All @@ -69,7 +76,7 @@ class ZSTDDecompressor : public Decompressor {
size_t ret;
ret = ZSTD_decompressStream(stream_, &out_buf, &in_buf);
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "zstd decompress failed: ");
return ZSTDError(ret, "ZSTD decompress failed: ");
}
*bytes_read = static_cast<int64_t>(in_buf.pos);
*bytes_written = static_cast<int64_t>(out_buf.pos);
Expand All @@ -81,12 +88,6 @@ class ZSTDDecompressor : public Decompressor {
bool IsFinished() override { return finished_; }

protected:
Status ZSTDError(size_t ret, const char* prefix_msg) {
std::stringstream ss;
ss << prefix_msg << ZSTD_getErrorName(ret);
return Status::IOError(ss.str());
}

ZSTD_DStream* stream_;
bool finished_;
};
Expand All @@ -103,7 +104,7 @@ class ZSTDCompressor : public Compressor {
Status Init() {
size_t ret = ZSTD_initCStream(stream_, kZSTDDefaultCompressionLevel);
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "zstd init failed: ");
return ZSTDError(ret, "ZSTD init failed: ");
} else {
return Status::OK();
}
Expand All @@ -119,12 +120,6 @@ class ZSTDCompressor : public Compressor {
bool* should_retry) override;

protected:
Status ZSTDError(size_t ret, const char* prefix_msg) {
std::stringstream ss;
ss << prefix_msg << ZSTD_getErrorName(ret);
return Status::IOError(ss.str());
}

ZSTD_CStream* stream_;
};

Expand All @@ -144,7 +139,7 @@ Status ZSTDCompressor::Compress(int64_t input_len, const uint8_t* input,
size_t ret;
ret = ZSTD_compressStream(stream_, &out_buf, &in_buf);
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "zstd compress failed: ");
return ZSTDError(ret, "ZSTD compress failed: ");
}
*bytes_read = static_cast<int64_t>(in_buf.pos);
*bytes_written = static_cast<int64_t>(out_buf.pos);
Expand All @@ -162,7 +157,7 @@ Status ZSTDCompressor::Flush(int64_t output_len, uint8_t* output, int64_t* bytes
size_t ret;
ret = ZSTD_flushStream(stream_, &out_buf);
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "zstd flush failed: ");
return ZSTDError(ret, "ZSTD flush failed: ");
}
*bytes_written = static_cast<int64_t>(out_buf.pos);
*should_retry = ret > 0;
Expand All @@ -180,7 +175,7 @@ Status ZSTDCompressor::End(int64_t output_len, uint8_t* output, int64_t* bytes_w
size_t ret;
ret = ZSTD_endStream(stream_, &out_buf);
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "zstd end failed: ");
return ZSTDError(ret, "ZSTD end failed: ");
}
*bytes_written = static_cast<int64_t>(out_buf.pos);
*should_retry = ret > 0;
Expand All @@ -206,10 +201,20 @@ Status ZSTDCodec::MakeDecompressor(std::shared_ptr<Decompressor>* out) {

Status ZSTDCodec::Decompress(int64_t input_len, const uint8_t* input, int64_t output_len,
uint8_t* output_buffer) {
int64_t decompressed_size =
ZSTD_decompress(output_buffer, static_cast<size_t>(output_len), input,
static_cast<size_t>(input_len));
if (decompressed_size != output_len) {
if (output_buffer == nullptr) {
// We may pass a NULL 0-byte output buffer but some zstd versions demand
// a valid pointer: https://github.com/facebook/zstd/issues/1385
static uint8_t empty_buffer[1];
DCHECK_EQ(output_len, 0);
output_buffer = empty_buffer;
}

size_t ret = ZSTD_decompress(output_buffer, static_cast<size_t>(output_len), input,
static_cast<size_t>(input_len));
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "ZSTD decompression failed: ");
}
if (static_cast<int64_t>(ret) != output_len) {
return Status::IOError("Corrupt ZSTD compressed data.");
}
return Status::OK();
Expand All @@ -223,12 +228,13 @@ int64_t ZSTDCodec::MaxCompressedLen(int64_t input_len,
Status ZSTDCodec::Compress(int64_t input_len, const uint8_t* input,
int64_t output_buffer_len, uint8_t* output_buffer,
int64_t* output_length) {
*output_length =
size_t ret =
ZSTD_compress(output_buffer, static_cast<size_t>(output_buffer_len), input,
static_cast<size_t>(input_len), kZSTDDefaultCompressionLevel);
if (ZSTD_isError(*output_length)) {
return Status::IOError("ZSTD compression failure.");
if (ZSTD_isError(ret)) {
return ZSTDError(ret, "ZSTD compression failed: ");
}
*output_length = static_cast<int64_t>(ret);
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/thirdparty/versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ RAPIDJSON_VERSION=v1.1.0
SNAPPY_VERSION=1.1.3
THRIFT_VERSION=0.11.0
ZLIB_VERSION=1.2.8
ZSTD_VERSION=v1.2.0
ZSTD_VERSION=v1.3.7
RE2_VERSION=2018-10-01

0 comments on commit 8a2488d

Please sign in to comment.