From 36d5b8c0cee70585c3c97a9b2e2f480c79fe787b Mon Sep 17 00:00:00 2001
From: Sen Huang <senhuang96@fb.com>
Date: Mon, 8 Mar 2021 12:23:57 -0800
Subject: [PATCH] Add seekable roundtrip fuzzer

---
 build/cmake/lib/CMakeLists.txt                |  1 +
 contrib/seekable_format/examples/Makefile     |  2 +-
 .../examples/parallel_compression.c           |  1 -
 contrib/seekable_format/tests/Makefile        |  2 +-
 contrib/seekable_format/zstdseek_compress.c   |  7 +-
 contrib/seekable_format/zstdseek_decompress.c |  5 +-
 tests/fuzz/.gitignore                         |  1 +
 tests/fuzz/Makefile                           | 12 ++-
 tests/fuzz/fuzz.py                            |  1 +
 tests/fuzz/seekable_roundtrip.c               | 88 +++++++++++++++++++
 10 files changed, 108 insertions(+), 12 deletions(-)
 create mode 100644 tests/fuzz/seekable_roundtrip.c

diff --git a/build/cmake/lib/CMakeLists.txt b/build/cmake/lib/CMakeLists.txt
index d58c652a1af..5d89b6e1f0a 100644
--- a/build/cmake/lib/CMakeLists.txt
+++ b/build/cmake/lib/CMakeLists.txt
@@ -91,6 +91,7 @@ endif ()
 if (ZSTD_BUILD_STATIC)
     add_library(libzstd_static STATIC ${Sources} ${Headers})
     list(APPEND library_targets libzstd_static)
+    target_include_directories(libzstd_static PUBLIC ../../../lib)
     if (ZSTD_MULTITHREAD_SUPPORT)
         set_property(TARGET libzstd_static APPEND PROPERTY COMPILE_DEFINITIONS "ZSTD_MULTITHREAD")
         if (UNIX)
diff --git a/contrib/seekable_format/examples/Makefile b/contrib/seekable_format/examples/Makefile
index 543780f75d3..9df6b75fb84 100644
--- a/contrib/seekable_format/examples/Makefile
+++ b/contrib/seekable_format/examples/Makefile
@@ -13,7 +13,7 @@ ZSTDLIB_PATH = ../../../lib
 ZSTDLIB_NAME = libzstd.a
 ZSTDLIB = $(ZSTDLIB_PATH)/$(ZSTDLIB_NAME)
 
-CPPFLAGS += -I../ -I../../../lib -I../../../lib/common
+CPPFLAGS += -DXXH_NAMESPACE=ZSTD_ -I../ -I../../../lib -I../../../lib/common
 
 CFLAGS ?= -O3
 CFLAGS += -g
diff --git a/contrib/seekable_format/examples/parallel_compression.c b/contrib/seekable_format/examples/parallel_compression.c
index 69644d2b3c8..4118b0ad762 100644
--- a/contrib/seekable_format/examples/parallel_compression.c
+++ b/contrib/seekable_format/examples/parallel_compression.c
@@ -21,7 +21,6 @@
 #  define SLEEP(x) usleep(x * 1000)
 #endif
 
-#define XXH_NAMESPACE ZSTD_
 #include "xxhash.h"
 
 #include "pool.h"      // use zstd thread pool for demo
diff --git a/contrib/seekable_format/tests/Makefile b/contrib/seekable_format/tests/Makefile
index 15eadb40e2d..d51deb3ea82 100644
--- a/contrib/seekable_format/tests/Makefile
+++ b/contrib/seekable_format/tests/Makefile
@@ -13,7 +13,7 @@ ZSTDLIB_PATH = ../../../lib
 ZSTDLIB_NAME = libzstd.a
 ZSTDLIB = $(ZSTDLIB_PATH)/$(ZSTDLIB_NAME)
 
-CPPFLAGS += -I../ -I$(ZSTDLIB_PATH) -I$(ZSTDLIB_PATH)/common
+CPPFLAGS += -DXXH_NAMESPACE=ZSTD_ -I../ -I$(ZSTDLIB_PATH) -I$(ZSTDLIB_PATH)/common
 
 CFLAGS ?= -O3
 CFLAGS += -g -Wall -Wextra -Wcast-qual -Wcast-align -Wconversion \
diff --git a/contrib/seekable_format/zstdseek_compress.c b/contrib/seekable_format/zstdseek_compress.c
index d92917a6250..242bd2ac3a1 100644
--- a/contrib/seekable_format/zstdseek_compress.c
+++ b/contrib/seekable_format/zstdseek_compress.c
@@ -12,7 +12,6 @@
 #include <assert.h>
 
 #define XXH_STATIC_LINKING_ONLY
-#define XXH_NAMESPACE ZSTD_
 #include "xxhash.h"
 
 #define ZSTD_STATIC_LINKING_ONLY
@@ -83,7 +82,7 @@ static size_t ZSTD_seekable_frameLog_freeVec(ZSTD_frameLog* fl)
 
 ZSTD_frameLog* ZSTD_seekable_createFrameLog(int checksumFlag)
 {
-    ZSTD_frameLog* const fl = malloc(sizeof(ZSTD_frameLog));
+    ZSTD_frameLog* const fl = (ZSTD_frameLog*)malloc(sizeof(ZSTD_frameLog));
     if (fl == NULL) return NULL;
 
     if (ZSTD_isError(ZSTD_seekable_frameLog_allocVec(fl))) {
@@ -108,7 +107,7 @@ size_t ZSTD_seekable_freeFrameLog(ZSTD_frameLog* fl)
 
 ZSTD_seekable_CStream* ZSTD_seekable_createCStream(void)
 {
-    ZSTD_seekable_CStream* const zcs = malloc(sizeof(ZSTD_seekable_CStream));
+    ZSTD_seekable_CStream* const zcs = (ZSTD_seekable_CStream*)malloc(sizeof(ZSTD_seekable_CStream));
     if (zcs == NULL) return NULL;
 
     memset(zcs, 0, sizeof(*zcs));
@@ -177,7 +176,7 @@ size_t ZSTD_seekable_logFrame(ZSTD_frameLog* fl,
     if (fl->size == fl->capacity) {
         /* exponential size increase for constant amortized runtime */
         size_t const newCapacity = fl->capacity * 2;
-        framelogEntry_t* const newEntries = realloc(fl->entries,
+        framelogEntry_t* const newEntries = (framelogEntry_t*)realloc(fl->entries,
                 sizeof(framelogEntry_t) * newCapacity);
 
         if (newEntries == NULL) return ERROR(memory_allocation);
diff --git a/contrib/seekable_format/zstdseek_decompress.c b/contrib/seekable_format/zstdseek_decompress.c
index ecf816c172f..5eed024950b 100644
--- a/contrib/seekable_format/zstdseek_decompress.c
+++ b/contrib/seekable_format/zstdseek_decompress.c
@@ -60,7 +60,6 @@
 #include <assert.h>
 
 #define XXH_STATIC_LINKING_ONLY
-#define XXH_NAMESPACE ZSTD_
 #include "xxhash.h"
 
 #define ZSTD_STATIC_LINKING_ONLY
@@ -176,7 +175,7 @@ struct ZSTD_seekable_s {
 
 ZSTD_seekable* ZSTD_seekable_create(void)
 {
-    ZSTD_seekable* const zs = malloc(sizeof(ZSTD_seekable));
+    ZSTD_seekable* const zs = (ZSTD_seekable*)malloc(sizeof(ZSTD_seekable));
     if (zs == NULL) return NULL;
 
     /* also initializes stage to zsds_init */
@@ -202,7 +201,7 @@ size_t ZSTD_seekable_free(ZSTD_seekable* zs)
 
 ZSTD_seekTable* ZSTD_seekTable_create_fromSeekable(const ZSTD_seekable* zs)
 {
-    ZSTD_seekTable* const st = malloc(sizeof(ZSTD_seekTable));
+    ZSTD_seekTable* const st = (ZSTD_seekTable*)malloc(sizeof(ZSTD_seekTable));
     if (st==NULL) return NULL;
 
     st->checksumFlag = zs->seekTable.checksumFlag;
diff --git a/tests/fuzz/.gitignore b/tests/fuzz/.gitignore
index 8ef3a3efdf8..93d935a85e8 100644
--- a/tests/fuzz/.gitignore
+++ b/tests/fuzz/.gitignore
@@ -16,6 +16,7 @@ zstd_frame_info
 decompress_dstSize_tooSmall
 fse_read_ncount
 sequence_compression_api
+seekable_roundtrip
 fuzz-*.log
 rt_lib_*
 d_lib_*
diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile
index 24b9f346816..ccb574b79da 100644
--- a/tests/fuzz/Makefile
+++ b/tests/fuzz/Makefile
@@ -25,10 +25,11 @@ CORPORA_URL_PREFIX:=https://github.com/facebook/zstd/releases/download/fuzz-corp
 
 ZSTDDIR = ../../lib
 PRGDIR = ../../programs
+CONTRIBDIR = ../../contrib
 
 FUZZ_CPPFLAGS := -I$(ZSTDDIR) -I$(ZSTDDIR)/common -I$(ZSTDDIR)/compress \
 	-I$(ZSTDDIR)/dictBuilder -I$(ZSTDDIR)/deprecated -I$(ZSTDDIR)/legacy \
-	-I$(PRGDIR) -DZSTD_MULTITHREAD -DZSTD_LEGACY_SUPPORT=1 $(CPPFLAGS)
+	-I$(CONTRIBDIR)/seekable_format -I$(PRGDIR) -DZSTD_MULTITHREAD -DZSTD_LEGACY_SUPPORT=1 $(CPPFLAGS)
 FUZZ_EXTRA_FLAGS := -Wall -Wextra -Wcast-qual -Wcast-align -Wshadow \
 	-Wstrict-aliasing=1 -Wswitch-enum -Wdeclaration-after-statement \
 	-Wstrict-prototypes -Wundef \
@@ -46,6 +47,9 @@ FUZZ_ROUND_TRIP_FLAGS := -DFUZZING_ASSERT_VALID_SEQUENCE
 FUZZ_HEADERS := fuzz_helpers.h fuzz.h zstd_helpers.h fuzz_data_producer.h
 FUZZ_SRC := $(PRGDIR)/util.c ./fuzz_helpers.c ./zstd_helpers.c ./fuzz_data_producer.c
 
+SEEKABLE_HEADERS = $(CONTRIBDIR)/seekable_format/zstd_seekable.h
+SEEKABLE_OBJS = $(CONTRIBDIR)/seekable_format/zstdseek_compress.c $(CONTRIBDIR)/seekable_format/zstdseek_decompress.c
+
 ZSTDCOMMON_SRC := $(ZSTDDIR)/common/*.c
 ZSTDCOMP_SRC   := $(ZSTDDIR)/compress/*.c
 ZSTDDECOMP_SRC := $(ZSTDDIR)/decompress/*.c
@@ -98,7 +102,8 @@ FUZZ_TARGETS :=       \
 	dictionary_stream_round_trip \
 	decompress_dstSize_tooSmall \
 	fse_read_ncount \
-	sequence_compression_api
+	sequence_compression_api \
+	seekable_roundtrip
 
 all: libregression.a $(FUZZ_TARGETS)
 
@@ -192,6 +197,9 @@ fse_read_ncount: $(FUZZ_HEADERS) $(FUZZ_ROUND_TRIP_OBJ) rt_fuzz_fse_read_ncount.
 sequence_compression_api: $(FUZZ_HEADERS) $(FUZZ_ROUND_TRIP_OBJ) rt_fuzz_sequence_compression_api.o
 	$(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_ROUND_TRIP_OBJ) rt_fuzz_sequence_compression_api.o $(LIB_FUZZING_ENGINE) -o $@
 
+seekable_roundtrip: $(FUZZ_HEADERS) $(SEEKABLE_HEADERS) $(FUZZ_ROUND_TRIP_OBJ) $(SEEKABLE_OBJS)  rt_fuzz_seekable_roundtrip.o
+	$(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_ROUND_TRIP_OBJ) $(SEEKABLE_OBJS) rt_fuzz_seekable_roundtrip.o $(LIB_FUZZING_ENGINE) -o $@
+
 libregression.a: $(FUZZ_HEADERS) $(PRGDIR)/util.h $(PRGDIR)/util.c d_fuzz_regression_driver.o
 	$(AR) $(FUZZ_ARFLAGS) $@ d_fuzz_regression_driver.o
 
diff --git a/tests/fuzz/fuzz.py b/tests/fuzz/fuzz.py
index a3431f365c8..d8dfa7782e5 100755
--- a/tests/fuzz/fuzz.py
+++ b/tests/fuzz/fuzz.py
@@ -62,6 +62,7 @@ def __init__(self, input_type, frame_type=FrameType.ZSTD):
     'decompress_dstSize_tooSmall': TargetInfo(InputType.RAW_DATA),
     'fse_read_ncount': TargetInfo(InputType.RAW_DATA),
     'sequence_compression_api': TargetInfo(InputType.RAW_DATA),
+    'seekable_roundtrip': TargetInfo(InputType.RAW_DATA),
 }
 TARGETS = list(TARGET_INFO.keys())
 ALL_TARGETS = TARGETS + ['all']
diff --git a/tests/fuzz/seekable_roundtrip.c b/tests/fuzz/seekable_roundtrip.c
new file mode 100644
index 00000000000..dcdcaae116a
--- /dev/null
+++ b/tests/fuzz/seekable_roundtrip.c
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ * You may select, at your option, one of the above-listed licenses.
+ */
+
+#include "zstd.h"
+#include "zstd_seekable.h"
+#include "fuzz_helpers.h"
+#include "fuzz_data_producer.h"
+
+static ZSTD_seekable *stream = NULL;          /* decompression handle; created lazily, kept across inputs only when STATEFUL_FUZZING is defined */
+static ZSTD_seekable_CStream *zscs = NULL;    /* compression handle; same lazy-create / free policy as `stream` */
+static const size_t kSeekableOverheadSize = ZSTD_seekTableFooterSize;   /* extra room added on top of ZSTD_compressBound() when sizing the compressed buffer */
+
+/* Fuzz target: seekable-compress the data part of src, then decompress a producer-chosen
+ * [offset, offset + uncompressedSize) window and compare it to the input. Always returns 0. */
+int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
+{
+    /* Give a random portion of src data to the producer, to use for
+    parameter generation. The rest (srcSize bytes) will be used for (de)compression.
+    srcSize is a fresh const so declarations stay ahead of statements (-Wdeclaration-after-statement). */
+    FUZZ_dataProducer_t *producer = FUZZ_dataProducer_create(src, size);
+    size_t const srcSize = FUZZ_dataProducer_reserveDataPrefix(producer);
+    size_t const compressedBufferSize = ZSTD_compressBound(srcSize) + kSeekableOverheadSize;
+    uint8_t* compressedBuffer = (uint8_t*)malloc(compressedBufferSize);
+    uint8_t* decompressedBuffer = (uint8_t*)malloc(srcSize);
+    int const cLevel = FUZZ_dataProducer_int32Range(producer, ZSTD_minCLevel(), ZSTD_maxCLevel());
+    unsigned const checksumFlag = FUZZ_dataProducer_int32Range(producer, 0, 1);
+    size_t const uncompressedSize = FUZZ_dataProducer_uint32Range(producer, 0, srcSize);
+    size_t const offset = FUZZ_dataProducer_uint32Range(producer, 0, srcSize - uncompressedSize);
+    size_t seekSize;
+
+    FUZZ_ASSERT(compressedBuffer);
+    FUZZ_ASSERT(decompressedBuffer || srcSize == 0);   /* malloc(0) may legitimately return NULL */
+    if (!zscs) {
+        zscs = ZSTD_seekable_createCStream();
+        FUZZ_ASSERT(zscs);
+    }
+    if (!stream) {
+        stream = ZSTD_seekable_create();
+        FUZZ_ASSERT(stream);
+    }
+
+    {   /* Perform a compression */
+        size_t const initStatus = ZSTD_seekable_initCStream(zscs, cLevel, checksumFlag, srcSize);
+        size_t endStatus;
+        ZSTD_outBuffer out = { .dst=compressedBuffer, .pos=0, .size=compressedBufferSize };
+        ZSTD_inBuffer  in  = { .src=src, .pos=0, .size=srcSize };
+        FUZZ_ASSERT(!ZSTD_isError(initStatus));
+
+        do {
+            size_t const cSize = ZSTD_seekable_compressStream(zscs, &out, &in);
+            FUZZ_ASSERT(!ZSTD_isError(cSize));
+        } while (in.pos != in.size);   /* exits only once all input is consumed */
+        endStatus = ZSTD_seekable_endStream(zscs, &out);
+        FUZZ_ASSERT(!ZSTD_isError(endStatus));
+        seekSize = out.pos;
+    }
+
+    {   /* Decompress at an offset */
+        size_t const initStatus = ZSTD_seekable_initBuff(stream, compressedBuffer, seekSize);
+        size_t decompressedBytesTotal = 0;
+        size_t dSize;
+        FUZZ_ZASSERT(initStatus);
+        do {   /* on a short read, resume right after the bytes already produced */
+            dSize = ZSTD_seekable_decompress(stream, decompressedBuffer + decompressedBytesTotal,
+                        uncompressedSize - decompressedBytesTotal, offset + decompressedBytesTotal);
+            FUZZ_ASSERT(!ZSTD_isError(dSize));
+            decompressedBytesTotal += dSize;
+        } while (decompressedBytesTotal < uncompressedSize && dSize > 0);
+        FUZZ_ASSERT(decompressedBytesTotal == uncompressedSize);
+    }
+
+    FUZZ_ASSERT_MSG(!FUZZ_memcmp(src+offset, decompressedBuffer, uncompressedSize), "Corruption!");
+    free(decompressedBuffer);
+    free(compressedBuffer);
+    FUZZ_dataProducer_free(producer);
+#ifndef STATEFUL_FUZZING
+    ZSTD_seekable_free(stream); stream = NULL;
+    ZSTD_seekable_freeCStream(zscs); zscs = NULL;
+#endif
+    return 0;
+}