diff --git a/large_objects.go b/large_objects.go index c238ab9c2..a3028b638 100644 --- a/large_objects.go +++ b/large_objects.go @@ -6,6 +6,11 @@ import ( "io" ) +// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of +// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data +// in the message,maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB. +var maxLargeObjectMessageLength = 1024*1024*1024 - 1024 + // LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it // was created. // @@ -68,32 +73,64 @@ type LargeObject struct { // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. func (o *LargeObject) Write(p []byte) (int, error) { - var n int - err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) - if err != nil { - return n, err - } - - if n < 0 { - return 0, errors.New("failed to write to large object") + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + var n int + err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n) + if err != nil { + return nTotal, err + } + + if n < 0 { + return nTotal, errors.New("failed to write to large object") + } + + nTotal += n + + if n < expected { + return nTotal, errors.New("short write to large object") + } else if n > expected { + return nTotal, errors.New("invalid write to large object") + } } - return n, nil + return nTotal, nil } // Read reads up to len(p) bytes into p returning the number of bytes read. func (o *LargeObject) Read(p []byte) (int, error) { - var res []byte - err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) - copy(p, res) - if err != nil { - return len(res), err + nTotal := 0 + for { + expected := len(p) - nTotal + if expected == 0 { + break + } else if expected > maxLargeObjectMessageLength { + expected = maxLargeObjectMessageLength + } + + var res []byte + err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res) + copy(p[nTotal:], res) + nTotal += len(res) + if err != nil { + return nTotal, err + } + + if len(res) < expected { + return nTotal, io.EOF + } else if len(res) > expected { + return nTotal, errors.New("invalid read of large object") + } } - if len(res) < len(p) { - err = io.EOF - } - return len(res), err + return nTotal, nil } // Seek moves the current location pointer to the new location specified by offset. diff --git a/large_objects_private_test.go b/large_objects_private_test.go new file mode 100644 index 000000000..36eca8f06 --- /dev/null +++ b/large_objects_private_test.go @@ -0,0 +1,20 @@ +package pgx + +import ( + "testing" +) + +// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable +// to the given length for the duration of the test. +// +// Tests using this helper should not use t.Parallel(). +func SetMaxLargeObjectMessageLength(t *testing.T, length int) { + t.Helper() + + original := maxLargeObjectMessageLength + t.Cleanup(func() { + maxLargeObjectMessageLength = original + }) + + maxLargeObjectMessageLength = length +} diff --git a/large_objects_test.go b/large_objects_test.go index 25611bf67..de2eed0d8 100644 --- a/large_objects_test.go +++ b/large_objects_test.go @@ -13,7 +13,8 @@ import ( ) func TestLargeObjects(t *testing.T) { - t.Parallel() + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -34,7 +35,8 @@ func TestLargeObjects(t *testing.T) { } func TestLargeObjectsSimpleProtocol(t *testing.T) { - t.Parallel() + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -160,7 +162,8 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) { } func TestLargeObjectsMultipleTransactions(t *testing.T) { - t.Parallel() + // We use a very short limit to test chunking logic. + pgx.SetMaxLargeObjectMessageLength(t, 2) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel()