From 0ead11a53105d3e6db90eac98f8dd78a629f70c0 Mon Sep 17 00:00:00 2001
From: Evan Jones <evan.jones@datadoghq.com>
Date: Fri, 5 Mar 2021 16:55:31 -0500
Subject: [PATCH 1/2] Fix "Go pointer to Go pointer" panics

The Cgo runtime verifies that Go never passes pointers to other Go
pointers, which is required for correct garbage collection.
Unfortunately, its checks are not perfect, and there are occasional
false positives. Our code triggers these false positives if the
slice passed to compression functions is in the same memory
allocation as Go pointers. This happened when trying to use zstd with
another package's Writer type, which has an internal buffer.

The tests added in this PR all fail with the following panic. The
fix is to ensure the expression unsafe.Pointer(&src[0]) is inside the
Cgo call, and not before. This is documented in the following issue:

https://github.com/golang/go/issues/14210#issuecomment-346402945

The remaining uses of the "var srcPtr *byte" pattern are safe: they
all pass the address of a byte slice that is allocated internally by
this library, so I believe they cannot cause this error.

Fixes the following panic:

panic: runtime error: cgo argument has Go pointer to Go pointer

goroutine 30 [running]:
panic(...)
  /usr/local/go/src/runtime/panic.go:969 +0x1b9
github.com/DataDog/zstd.(*ctx).CompressLevel.func1(...)
  /home/circleci/project/zstd_ctx.go:75 +0xd9
github.com/DataDog/zstd.(*ctx).CompressLevel(...)
  /home/circleci/project/zstd_ctx.go:75 +0xce
github.com/DataDog/zstd.TestCtxCompressLevelNoGoPointers.func1(...)
  /home/circleci/project/zstd_ctx_test.go:71 +0x77
github.com/DataDog/zstd.testCompressNoGoPointers(...)
  /home/circleci/project/zstd_test.go:130 +0xad
github.com/DataDog/zstd.TestCtxCompressLevelNoGoPointers(...)
  /home/circleci/project/zstd_ctx_test.go:69 +0x37
testing.tRunner(...)
  /usr/local/go/src/testing/testing.go:1123 +0xef
---
 zstd.go             | 28 ++++++++++++++++++----------
 zstd_ctx.go         | 31 ++++++++++++++++++++-----------
 zstd_ctx_test.go    |  7 +++++++
 zstd_stream.go      |  8 ++------
 zstd_stream_test.go | 16 ++++++++++++++++
 zstd_test.go        | 37 +++++++++++++++++++++++++++++++++++++
 6 files changed, 100 insertions(+), 27 deletions(-)

diff --git a/zstd.go b/zstd.go
index 164a923..634ed65 100644
--- a/zstd.go
+++ b/zstd.go
@@ -58,18 +58,26 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) {
 		dst = make([]byte, bound)
 	}
 
-	var srcPtr *byte // Do not point anywhere, if src is empty
-	if len(src) > 0 {
-		srcPtr = &src[0]
+	// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
+	// This means we need to special case empty input. See:
+	// https://github.com/golang/go/issues/14210#issuecomment-346402945
+	var cWritten C.size_t
+	if len(src) == 0 {
+		cWritten = C.ZSTD_compress(
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(nil),
+			C.size_t(0),
+			C.int(level))
+	} else {
+		cWritten = C.ZSTD_compress(
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(&src[0]),
+			C.size_t(len(src)),
+			C.int(level))
 	}
 
-	cWritten := C.ZSTD_compress(
-		unsafe.Pointer(&dst[0]),
-		C.size_t(len(dst)),
-		unsafe.Pointer(srcPtr),
-		C.size_t(len(src)),
-		C.int(level))
-
 	written := int(cWritten)
 	// Check if the return is an Error code
 	if err := getError(written); err != nil {
diff --git a/zstd_ctx.go b/zstd_ctx.go
index 12e9539..6b98943 100644
--- a/zstd_ctx.go
+++ b/zstd_ctx.go
@@ -63,19 +63,28 @@ func (c *ctx) CompressLevel(dst, src []byte, level int) ([]byte, error) {
 		dst = make([]byte, bound)
 	}
 
-	var srcPtr *byte // Do not point anywhere, if src is empty
-	if len(src) > 0 {
-		srcPtr = &src[0]
+	// We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics.
+	// This means we need to special case empty input. See:
+	// https://github.com/golang/go/issues/14210#issuecomment-346402945
+	var cWritten C.size_t
+	if len(src) == 0 {
+		cWritten = C.ZSTD_compressCCtx(
+			c.cctx,
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(nil),
+			C.size_t(0),
+			C.int(level))
+	} else {
+		cWritten = C.ZSTD_compressCCtx(
+			c.cctx,
+			unsafe.Pointer(&dst[0]),
+			C.size_t(len(dst)),
+			unsafe.Pointer(&src[0]),
+			C.size_t(len(src)),
+			C.int(level))
 	}
 
-	cWritten := C.ZSTD_compressCCtx(
-		c.cctx,
-		unsafe.Pointer(&dst[0]),
-		C.size_t(len(dst)),
-		unsafe.Pointer(srcPtr),
-		C.size_t(len(src)),
-		C.int(level))
-
 	written := int(cWritten)
 	// Check if the return is an Error code
 	if err := getError(written); err != nil {
diff --git a/zstd_ctx_test.go b/zstd_ctx_test.go
index 831a21f..ac82091 100644
--- a/zstd_ctx_test.go
+++ b/zstd_ctx_test.go
@@ -65,6 +65,13 @@ func TestCtxCompressLevel(t *testing.T) {
 	}
 }
 
+func TestCtxCompressLevelNoGoPointers(t *testing.T) {
+	testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
+		cctx := NewCtx()
+		return cctx.CompressLevel(nil, input, BestSpeed)
+	})
+}
+
 func TestCtxEmptySliceCompress(t *testing.T) {
 	ctx := NewCtx()
 
diff --git a/zstd_stream.go b/zstd_stream.go
index 1ed0e98..ac0a23a 100644
--- a/zstd_stream.go
+++ b/zstd_stream.go
@@ -168,17 +168,13 @@ func (w *Writer) Write(p []byte) (int, error) {
 		srcData = w.srcBuffer
 	}
 
-	var srcPtr *byte // Do not point anywhere, if src is empty
-	if len(srcData) > 0 {
-		srcPtr = &srcData[0]
-	}
-
+	// &srcData[0] is safe: it is p or w.srcBuffer but only if len() > 0 checked above
 	C.ZSTD_compressStream2_wrapper(
 		w.resultBuffer,
 		w.ctx,
 		unsafe.Pointer(&w.dstBuffer[0]),
 		C.size_t(len(w.dstBuffer)),
-		unsafe.Pointer(srcPtr),
+		unsafe.Pointer(&srcData[0]),
 		C.size_t(len(srcData)),
 	)
 	ret := int(w.resultBuffer.return_code)
diff --git a/zstd_stream_test.go b/zstd_stream_test.go
index 06acece..79f412e 100644
--- a/zstd_stream_test.go
+++ b/zstd_stream_test.go
@@ -375,6 +375,22 @@ func TestStreamDecompressionChunks(t *testing.T) {
 	}
 }
 
+func TestStreamWriteNoGoPointers(t *testing.T) {
+	testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
+		buf := &bytes.Buffer{}
+		zw := NewWriter(buf)
+		_, err := zw.Write(input)
+		if err != nil {
+			return nil, err
+		}
+		err = zw.Close()
+		if err != nil {
+			return nil, err
+		}
+		return buf.Bytes(), nil
+	})
+}
+
 func BenchmarkStreamCompression(b *testing.B) {
 	if raw == nil {
 		b.Fatal(ErrNoPayloadEnv)
diff --git a/zstd_test.go b/zstd_test.go
index e4d90c8..e5bb2d2 100644
--- a/zstd_test.go
+++ b/zstd_test.go
@@ -109,6 +109,43 @@ func TestCompressLevel(t *testing.T) {
 	}
 }
 
+// structWithGoPointers contains a byte buffer and a pointer to Go objects (slice). This means
+// Cgo checks can fail when passing a pointer to buffer:
+// "panic: runtime error: cgo argument has Go pointer to Go pointer"
+// https://github.com/golang/go/issues/14210#issuecomment-346402945
+type structWithGoPointers struct {
+	buffer [1]byte
+	slice  []byte
+}
+
+// testCompressDecompressByte ensures that functions use the correct unsafe.Pointer assignment
+// to avoid "Go pointer to Go pointer" panics.
+func testCompressNoGoPointers(t *testing.T, compressFunc func(input []byte) ([]byte, error)) {
+	t.Helper()
+
+	s := structWithGoPointers{}
+	s.buffer[0] = 0x42
+	s.slice = s.buffer[:1]
+
+	compressed, err := compressFunc(s.slice)
+	if err != nil {
+		t.Fatal(err)
+	}
+	decompressed, err := Decompress(nil, compressed)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(decompressed, s.slice) {
+		t.Errorf("decompressed=%#v input=%#v", decompressed, s.slice)
+	}
+}
+
+func TestCompressLevelNoGoPointers(t *testing.T) {
+	testCompressNoGoPointers(t, func(input []byte) ([]byte, error) {
+		return CompressLevel(nil, input, BestSpeed)
+	})
+}
+
 func doCompressLevel(payload []byte, out []byte) error {
 	out, err := CompressLevel(out, payload, DefaultCompression)
 	if err != nil {

From ee095180119c4771a30d3f88ef6d5d03fa07cac3 Mon Sep 17 00:00:00 2001
From: Evan Jones <evan.jones@datadoghq.com>
Date: Mon, 8 Mar 2021 15:22:36 -0500
Subject: [PATCH 2/2] zstd_stream.go: add explicit len() check

---
 zstd_stream.go | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/zstd_stream.go b/zstd_stream.go
index ac0a23a..f9eb2de 100644
--- a/zstd_stream.go
+++ b/zstd_stream.go
@@ -168,7 +168,11 @@ func (w *Writer) Write(p []byte) (int, error) {
 		srcData = w.srcBuffer
 	}
 
-	// &srcData[0] is safe: it is p or w.srcBuffer but only if len() > 0 checked above
+	if len(srcData) == 0 {
+		// this is technically unnecessary: srcData is p or w.srcBuffer, and len() > 0 checked above
+		// but this ensures the code can change without dereferencing an srcData[0]
+		return 0, nil
+	}
 	C.ZSTD_compressStream2_wrapper(
 		w.resultBuffer,
 		w.ctx,