diff --git a/spanner/batch.go b/spanner/batch.go index 69399d0fecce..0259d71a669e 100644 --- a/spanner/batch.go +++ b/spanner/batch.go @@ -309,7 +309,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R var ( sh *sessionHandle err error - rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) + rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) ) if sh, _, err = t.acquire(ctx); err != nil { return &RowIterator{err: err} @@ -322,7 +322,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R sh.updateLastUseTime() // Read or query partition. if p.rreq != nil { - rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) { + rpc = func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { client, err := client.StreamingRead(ctx, &sppb.ReadRequest{ Session: p.rreq.Session, Transaction: p.rreq.Transaction, @@ -335,7 +335,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R ResumeToken: resumeToken, DataBoostEnabled: p.rreq.DataBoostEnabled, DirectedReadOptions: p.rreq.DirectedReadOptions, - }) + }, opts...) if err != nil { return client, err } @@ -351,7 +351,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R return client, err } } else { - rpc = func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) { + rpc = func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { client, err := client.ExecuteStreamingSql(ctx, &sppb.ExecuteSqlRequest{ Session: p.qreq.Session, Transaction: p.qreq.Transaction, @@ -364,7 +364,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R ResumeToken: resumeToken, DataBoostEnabled: p.qreq.DataBoostEnabled, DirectedReadOptions: p.qreq.DirectedReadOptions, - }) + }, opts...) if err != nil { return client, err } @@ -387,7 +387,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R t.sp.sc.metricsTracerFactory, rpc, t.setTimestamp, - t.release) + t.release, client.(*grpcSpannerClient)) } // MarshalBinary implements BinaryMarshaler. diff --git a/spanner/client.go b/spanner/client.go index f95740bed97d..ef53bafdb42c 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -433,6 +433,14 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf } else { // Create gtransport ConnPool as usual if MultiEndpoint is not used. // gRPC options. + + // Add a unaryClientInterceptor and streamClientInterceptor. + reqIDInjector := new(requestIDHeaderInjector) + opts = append(opts, + option.WithGRPCDialOption(grpc.WithChainStreamInterceptor(reqIDInjector.interceptStream)), + option.WithGRPCDialOption(grpc.WithChainUnaryInterceptor(reqIDInjector.interceptUnary)), + ) + allOpts := allClientOpts(config.NumChannels, config.Compression, opts...) pool, err = gtransport.DialPool(ctx, allOpts...) 
if err != nil { diff --git a/spanner/client_test.go index e353cc69b999..d7eb035f7267 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -4187,13 +4187,17 @@ func TestReadWriteTransaction_ContextTimeoutDuringCommit(t *testing.T) { if se.GRPCStatus().Code() != w.GRPCStatus().Code() { t.Fatalf("Error status mismatch:\nGot: %v\nWant: %v", se.GRPCStatus(), w.GRPCStatus()) } - if se.Error() != w.Error() { - t.Fatalf("Error message mismatch:\nGot %s\nWant: %s", se.Error(), w.Error()) + if !testEqual(se, w) { + t.Fatalf("Error message mismatch:\nGot: %s\nWant: %s", se.Error(), w.Error()) } var outcome *TransactionOutcomeUnknownError if !errors.As(err, &outcome) { t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error") } + + if se.RequestID == "" { + t.Fatal("Missing .RequestID") + } } func TestFailedCommit_NoRollback(t *testing.T) { diff --git a/spanner/cmp_test.go index 374fb827ef9c..b39c0a8f2f8f 100644 --- a/spanner/cmp_test.go +++ b/spanner/cmp_test.go @@ -64,6 +64,9 @@ func testEqual(a, b interface{}) bool { if strings.Contains(path.GoString(), "{*spanner.Error}.err") { return true } + if strings.Contains(path.GoString(), "{*spanner.Error}.RequestID") { + return true + } return false }, cmp.Ignore())) } diff --git a/spanner/errors.go index edb52d26a47f..d99c6d3775af 100644 --- a/spanner/errors.go +++ b/spanner/errors.go @@ -58,6 +58,10 @@ type Error struct { // additionalInformation optionally contains any additional information // about the error. additionalInformation string + + // RequestID is the associated ID that was sent to Google Cloud Spanner's + // backend, as the value in the "x-goog-spanner-request-id" gRPC header. + RequestID string } // TransactionOutcomeUnknownError is wrapped in a Spanner error when the error @@ -85,10 +89,17 @@ func (e *Error) Error() string { return "spanner: OK" } code := ErrCode(e) + + var s string if e.additionalInformation == "" { - return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc) + s = fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc) + } else { + s = fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation) } - return fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation) + if e.RequestID != "" { + s = fmt.Sprintf("%s, requestID = %q", s, e.RequestID) + } + return s } // Unwrap returns the wrapped error (if any). @@ -123,6 +134,10 @@ func (e *Error) decorate(info string) { // APIError error having given error code as its status. func spannerErrorf(code codes.Code, format string, args ...interface{}) error { msg := fmt.Sprintf(format, args...)
+ return spannerError(code, msg) +} + +func spannerError(code codes.Code, msg string) error { wrapped, _ := apierror.FromError(status.Error(code, msg)) return &Error{ Code: code, @@ -172,9 +187,9 @@ func toSpannerErrorWithCommitInfo(err error, errorDuringCommit bool) error { desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg) wrapped = &TransactionOutcomeUnknownError{err: wrapped} } - return &Error{status.FromContextError(err).Code(), toAPIError(wrapped), desc, ""} + return &Error{status.FromContextError(err).Code(), toAPIError(wrapped), desc, "", ""} case status.Code(err) == codes.Unknown: - return &Error{codes.Unknown, toAPIError(err), err.Error(), ""} + return &Error{codes.Unknown, toAPIError(err), err.Error(), "", ""} default: statusErr := status.Convert(err) code, desc := statusErr.Code(), statusErr.Message() @@ -183,7 +198,7 @@ func toSpannerErrorWithCommitInfo(err error, errorDuringCommit bool) error { desc = fmt.Sprintf("%s, %s", desc, transactionOutcomeUnknownMsg) wrapped = &TransactionOutcomeUnknownError{err: wrapped} } - return &Error{code, toAPIError(wrapped), desc, ""} + return &Error{code, toAPIError(wrapped), desc, "", ""} } } diff --git a/spanner/grpc_client.go b/spanner/grpc_client.go index 9b7f1bca4ca6..207db266f079 100644 --- a/spanner/grpc_client.go +++ b/spanner/grpc_client.go @@ -19,6 +19,7 @@ package spanner import ( "context" "strings" + "sync/atomic" vkit "cloud.google.com/go/spanner/apiv1" "cloud.google.com/go/spanner/apiv1/spannerpb" @@ -67,6 +68,15 @@ type spannerClient interface { type grpcSpannerClient struct { raw *vkit.Client metricsTracerFactory *builtinMetricsTracerFactory + + // These fields are used to uniquely track x-goog-spanner-request-id where: + // raw(*vkit.Client) is the channel, and channelID is derived from the ordinal + // count of unique *vkit.Client as retrieved from the session pool. + channelID uint64 + // id is derived from the SpannerClient. + id int + // nthRequest is incremented for each new request (but not for retries of requests). + nthRequest *atomic.Uint32 } var ( @@ -76,13 +86,16 @@ var ( // newGRPCSpannerClient initializes a new spannerClient that uses the gRPC // Spanner API. -func newGRPCSpannerClient(ctx context.Context, sc *sessionClient, opts ...option.ClientOption) (spannerClient, error) { +func newGRPCSpannerClient(ctx context.Context, sc *sessionClient, channelID uint64, opts ...option.ClientOption) (spannerClient, error) { raw, err := vkit.NewClient(ctx, opts...) if err != nil { return nil, err } g := &grpcSpannerClient{raw: raw, metricsTracerFactory: sc.metricsTracerFactory} + clientID := sc.nthClient + g.prepareRequestIDTrackers(clientID, channelID, sc.nthRequest) + clientInfo := []string{"gccl", internal.Version} if sc.userAgent != "" { agentWithVersion := strings.SplitN(sc.userAgent, "/", 2) @@ -118,7 +131,7 @@ func (g *grpcSpannerClient) CreateSession(ctx context.Context, req *spannerpb.Cr mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.CreateSession(ctx, req, opts...) + resp, err := g.raw.CreateSession(ctx, req, g.optsWithNextRequestID(opts)...) 
statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -128,7 +141,7 @@ func (g *grpcSpannerClient) BatchCreateSessions(ctx context.Context, req *spanne mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.BatchCreateSessions(ctx, req, opts...) + resp, err := g.raw.BatchCreateSessions(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -138,21 +151,21 @@ func (g *grpcSpannerClient) GetSession(ctx context.Context, req *spannerpb.GetSe mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.GetSession(ctx, req, opts...) + resp, err := g.raw.GetSession(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) ListSessions(ctx context.Context, req *spannerpb.ListSessionsRequest, opts ...gax.CallOption) *vkit.SessionIterator { - return g.raw.ListSessions(ctx, req, opts...) + return g.raw.ListSessions(ctx, req, g.optsWithNextRequestID(opts)...) } func (g *grpcSpannerClient) DeleteSession(ctx context.Context, req *spannerpb.DeleteSessionRequest, opts ...gax.CallOption) error { mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - err := g.raw.DeleteSession(ctx, req, opts...) + err := g.raw.DeleteSession(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return err @@ -162,13 +175,15 @@ func (g *grpcSpannerClient) ExecuteSql(ctx context.Context, req *spannerpb.Execu mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.ExecuteSql(ctx, req, opts...) + resp, err := g.raw.ExecuteSql(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) ExecuteStreamingSql(ctx context.Context, req *spannerpb.ExecuteSqlRequest, opts ...gax.CallOption) (spannerpb.Spanner_ExecuteStreamingSqlClient, error) { + // Note: This method does not add g.optsWithNextRequestID to inject x-goog-spanner-request-id + // as it is already manually added when creating Stream iterators for ExecuteStreamingSql. return g.raw.ExecuteStreamingSql(peer.NewContext(ctx, &peer.Peer{}), req, opts...) } @@ -176,7 +191,7 @@ func (g *grpcSpannerClient) ExecuteBatchDml(ctx context.Context, req *spannerpb. mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.ExecuteBatchDml(ctx, req, opts...) + resp, err := g.raw.ExecuteBatchDml(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -186,13 +201,15 @@ func (g *grpcSpannerClient) Read(ctx context.Context, req *spannerpb.ReadRequest mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.Read(ctx, req, opts...) + resp, err := g.raw.Read(ctx, req, g.optsWithNextRequestID(opts)...) 
statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) StreamingRead(ctx context.Context, req *spannerpb.ReadRequest, opts ...gax.CallOption) (spannerpb.Spanner_StreamingReadClient, error) { + // Note: This method does not add g.optsWithNextRequestID, as it is already + // manually added when creating Stream iterators for StreamingRead. return g.raw.StreamingRead(peer.NewContext(ctx, &peer.Peer{}), req, opts...) } @@ -200,7 +217,7 @@ func (g *grpcSpannerClient) BeginTransaction(ctx context.Context, req *spannerpb mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.BeginTransaction(ctx, req, opts...) + resp, err := g.raw.BeginTransaction(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -210,7 +227,7 @@ func (g *grpcSpannerClient) Commit(ctx context.Context, req *spannerpb.CommitReq mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.Commit(ctx, req, opts...) + resp, err := g.raw.Commit(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -220,7 +237,7 @@ func (g *grpcSpannerClient) Rollback(ctx context.Context, req *spannerpb.Rollbac mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - err := g.raw.Rollback(ctx, req, opts...) + err := g.raw.Rollback(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return err @@ -230,7 +247,7 @@ func (g *grpcSpannerClient) PartitionQuery(ctx context.Context, req *spannerpb.P mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.PartitionQuery(ctx, req, opts...) + resp, err := g.raw.PartitionQuery(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err @@ -240,12 +257,12 @@ func (g *grpcSpannerClient) PartitionRead(ctx context.Context, req *spannerpb.Pa mt := g.newBuiltinMetricsTracer(ctx) defer recordOperationCompletion(mt) ctx = context.WithValue(ctx, metricsTracerKey, mt) - resp, err := g.raw.PartitionRead(ctx, req, opts...) + resp, err := g.raw.PartitionRead(ctx, req, g.optsWithNextRequestID(opts)...) statusCode, _ := status.FromError(err) mt.currOp.setStatus(statusCode.Code().String()) return resp, err } func (g *grpcSpannerClient) BatchWrite(ctx context.Context, req *spannerpb.BatchWriteRequest, opts ...gax.CallOption) (spannerpb.Spanner_BatchWriteClient, error) { - return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, opts...) + return g.raw.BatchWrite(peer.NewContext(ctx, &peer.Peer{}), req, g.optsWithNextRequestID(opts)...) 
} diff --git a/spanner/internal/testutil/inmem_spanner_server.go index ae73b82230a1..08be3b21742c 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -90,6 +90,7 @@ const ( MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" MethodStreamingRead string = "EXECUTE_STREAMING_READ" MethodBatchWrite string = "BATCH_WRITE" + MethodPartitionQuery string = "PARTITION_QUERY" ) // StatementResult represents a mocked result on the test server. The result is @@ -1107,6 +1108,9 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba } func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.PartitionQueryRequest) (*spannerpb.PartitionResponse, error) { + if err := s.simulateExecutionTime(MethodPartitionQuery, req); err != nil { + return nil, err + } s.mu.Lock() if s.stopped { s.mu.Unlock() diff --git a/spanner/read.go index eefd44b4843a..32c3f488050a 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -53,9 +53,10 @@ func stream( ctx context.Context, logger *log.Logger, meterTracerFactory *builtinMetricsTracerFactory, - rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error), + rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error), setTimestamp func(time.Time), release func(error), + gsc *grpcSpannerClient, ) *RowIterator { return streamWithReplaceSessionFunc( ctx, @@ -69,6 +70,7 @@ func stream( }, setTimestamp, release, + gsc, ) } @@ -79,18 +81,19 @@ func streamWithReplaceSessionFunc( ctx context.Context, logger *log.Logger, meterTracerFactory *builtinMetricsTracerFactory, - rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error), + rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error), replaceSession func(ctx context.Context) error, setTransactionID func(transactionID), updateTxState func(err error) error, setTimestamp func(time.Time), release func(error), + gsc *grpcSpannerClient, ) *RowIterator { ctx, cancel := context.WithCancel(ctx) ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.RowIterator") return &RowIterator{ meterTracerFactory: meterTracerFactory, - streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession), + streamd: newResumableStreamDecoder(ctx, logger, rpc, replaceSession, gsc), rowd: &partialResultSetDecoder{}, setTransactionID: setTransactionID, updateTxState: updateTxState, @@ -395,7 +398,7 @@ type resumableStreamDecoder struct { // a previous stream from the point encoded in restartToken. // rpc is always a wrapper of a Cloud Spanner query which is // resumable. - rpc func(ctx context.Context, restartToken []byte) (streamingReceiver, error) + rpc func(ctx context.Context, restartToken []byte, opts ...gax.CallOption) (streamingReceiver, error) // replaceSessionFunc is a function that can be used to replace the session // that is being used to execute the read operation. This function should @@ -437,12 +440,21 @@ type resumableStreamDecoder struct { // backoff is used for the retry settings backoff gax.Backoff + + gsc *grpcSpannerClient + + // reqIDInjector is generated once per stream, unless the stream + // gets broken, in which case a fresh one is generated. + reqIDInjector *requestIDWrap + // retryAttempt is incremented whenever a retry happens, and it is + // reset whenever a new reqIDInjector is created afresh.
+ retryAttempt uint32 } // newResumableStreamDecoder creates a new resumeableStreamDecoder instance. // Parameter rpc should be a function that creates a new stream beginning at the // restartToken if non-nil. -func newResumableStreamDecoder(ctx context.Context, logger *log.Logger, rpc func(ct context.Context, restartToken []byte) (streamingReceiver, error), replaceSession func(ctx context.Context) error) *resumableStreamDecoder { +func newResumableStreamDecoder(ctx context.Context, logger *log.Logger, rpc func(ct context.Context, restartToken []byte, opts ...gax.CallOption) (streamingReceiver, error), replaceSession func(ctx context.Context) error, gsc *grpcSpannerClient) *resumableStreamDecoder { return &resumableStreamDecoder{ ctx: ctx, logger: logger, @@ -450,9 +462,18 @@ func newResumableStreamDecoder(ctx context.Context, logger *log.Logger, rpc func replaceSessionFunc: replaceSession, maxBytesBetweenResumeTokens: atomic.LoadInt32(&maxBytesBetweenResumeTokens), backoff: DefaultRetryBackoff, + gsc: gsc, } } +func (d *resumableStreamDecoder) reqIDInjectorOrNew() *requestIDWrap { + if d.reqIDInjector == nil { + d.reqIDInjector = d.gsc.generateRequestIDHeaderInjector() + d.retryAttempt = 0 + } + return d.reqIDInjector +} + // changeState fulfills state transition for resumableStateDecoder. func (d *resumableStreamDecoder) changeState(target resumableStreamDecoderState) { if d.state == queueingRetryable && d.state != target { @@ -533,11 +554,16 @@ var ( func (d *resumableStreamDecoder) next(mt *builtinMetricsTracer) bool { retryer := onCodes(d.backoff, codes.Unavailable, codes.ResourceExhausted, codes.Internal) + + // Setup and track x-goog-request-id in the manual retries for ExecuteStreamingSql. + riw := d.reqIDInjectorOrNew() + for { switch d.state { case unConnected: + d.retryAttempt++ // If no gRPC stream is available, try to initiate one. - d.stream, d.err = d.rpc(context.WithValue(d.ctx, metricsTracerKey, mt), d.resumeToken) + d.stream, d.err = d.rpc(context.WithValue(d.ctx, metricsTracerKey, mt), d.resumeToken, riw.withNextRetryAttempt(d.retryAttempt)) if d.err == nil { d.changeState(queueingRetryable) continue @@ -615,6 +641,7 @@ func (d *resumableStreamDecoder) next(mt *builtinMetricsTracer) bool { return false case finished: // If query has finished, check if there are still buffered messages. + d.reqIDInjector = nil if d.q.empty() { // No buffered PartialResultSet. 
return false diff --git a/spanner/read_test.go b/spanner/read_test.go index 9f6f3602595b..47a9e396b072 100644 --- a/spanner/read_test.go +++ b/spanner/read_test.go @@ -640,7 +640,7 @@ func TestRsdNonblockingStates(t *testing.T) { name string resumeTokens [][]byte prsErrors []PartialResultSetExecutionTime - rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) + rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) sql string // Expected values want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller @@ -713,7 +713,7 @@ func TestRsdNonblockingStates(t *testing.T) { queueingRetryable, // got foo-02 aborted, // got error }, - wantErr: status.Errorf(codes.Unknown, "I quit"), + wantErr: ToSpannerError(status.Errorf(codes.Unknown, "I quit")), }, { // unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable @@ -778,7 +778,7 @@ func TestRsdNonblockingStates(t *testing.T) { s = append(s, aborted) // Error happens return s }(), - wantErr: status.Errorf(codes.Unknown, "Just Abort It"), + wantErr: ToSpannerError(status.Errorf(codes.Unknown, "Just Abort It")), }, } for _, test := range tests { @@ -796,12 +796,12 @@ func TestRsdNonblockingStates(t *testing.T) { } if test.rpc == nil { - test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + test.rpc = func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: test.sql, ResumeToken: resumeToken, - }) + }, opts...) } } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -812,6 +812,7 @@ func TestRsdNonblockingStates(t *testing.T) { nil, test.rpc, nil, + mc.(*grpcSpannerClient), ) st := []resumableStreamDecoderState{} var lastErr error @@ -879,7 +880,7 @@ func TestRsdNonblockingStates(t *testing.T) { } // Verify error message. if !testEqual(lastErr, test.wantErr) { - t.Fatalf("got error %v, want %v", lastErr, test.wantErr) + t.Fatalf("Error mismatch\n\tGot: %v\n\tWant: %v", lastErr, test.wantErr) } return } @@ -905,7 +906,7 @@ func TestRsdBlockingStates(t *testing.T) { for _, test := range []struct { name string resumeTokens [][]byte - rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error) + rpc func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) sql string // Expected values want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller @@ -917,7 +918,7 @@ func TestRsdBlockingStates(t *testing.T) { { // unConnected -> unConnected name: "unConnected -> unConnected", - rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + rpc: func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable") }, sql: "SELECT * from t_whatever", @@ -1094,12 +1095,12 @@ func TestRsdBlockingStates(t *testing.T) { // Avoid using test.sql directly in closure because for loop changes // test. sql := test.sql - test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + test.rpc = func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: sql, ResumeToken: resumeToken, - }) + }, opts...) 
} } ctx, cancel := context.WithCancel(context.Background()) @@ -1110,6 +1111,7 @@ func TestRsdBlockingStates(t *testing.T) { nil, test.rpc, nil, + mc.(*grpcSpannerClient), ) // Override backoff to make the test run faster. r.backoff = gax.Backoff{ @@ -1274,16 +1276,17 @@ func TestQueueBytes(t *testing.T) { decoder := newResumableStreamDecoder( ctx, nil, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: "SELECT t.key key, t.value value FROM t_mock t", ResumeToken: resumeToken, - }) + }, opts...) sr.rpcReceiver = r return sr, err }, nil, + mc.(*grpcSpannerClient), ) sizeOfPRS := proto.Size(&sppb.PartialResultSet{ @@ -1372,17 +1375,17 @@ func TestResumeToken(t *testing.T) { streaming := func() *RowIterator { return stream(context.Background(), nil, c.metricsTracerFactory, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: query, ResumeToken: resumeToken, - }) + }, opts...) sr.rpcReceiver = r return sr, err }, nil, - func(error) {}) + func(error) {}, mc.(*grpcSpannerClient)) } // Establish a stream to mock cloud spanner server. @@ -1517,17 +1520,17 @@ func TestGrpcReconnect(t *testing.T) { r := -1 // Establish a stream to mock cloud spanner server. iter := stream(context.Background(), nil, c.metricsTracerFactory, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { r++ return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: SelectSingerIDAlbumIDAlbumTitleFromAlbums, ResumeToken: resumeToken, - }) + }, opts...) }, nil, - func(error) {}) + func(error) {}, mc.(*grpcSpannerClient)) defer iter.Stop() for { _, err := iter.Next() @@ -1570,15 +1573,15 @@ func TestCancelTimeout(t *testing.T) { go func() { // Establish a stream to mock cloud spanner server. iter := stream(ctx, nil, c.metricsTracerFactory, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: SelectSingerIDAlbumIDAlbumTitleFromAlbums, ResumeToken: resumeToken, - }) + }, opts...) }, nil, - func(error) {}) + func(error) {}, mc.(*grpcSpannerClient)) defer iter.Stop() for { _, err = iter.Next() @@ -1607,15 +1610,15 @@ func TestCancelTimeout(t *testing.T) { go func() { // Establish a stream to mock cloud spanner server. iter := stream(ctx, nil, c.metricsTracerFactory, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: SelectSingerIDAlbumIDAlbumTitleFromAlbums, ResumeToken: resumeToken, - }) + }, opts...) 
}, nil, - func(error) {}) + func(error) {}, mc.(*grpcSpannerClient)) defer iter.Stop() for { _, err = iter.Next() @@ -1687,15 +1690,15 @@ func TestRowIteratorDo(t *testing.T) { nRows := 0 iter := stream(context.Background(), nil, c.metricsTracerFactory, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: SelectSingerIDAlbumIDAlbumTitleFromAlbums, ResumeToken: resumeToken, - }) + }, opts...) }, nil, - func(error) {}) + func(error) {}, mc.(*grpcSpannerClient)) err = iter.Do(func(r *Row) error { nRows++; return nil }) if err != nil { t.Errorf("Using Do: %v", err) @@ -1722,15 +1725,15 @@ func TestRowIteratorDoWithError(t *testing.T) { } iter := stream(context.Background(), nil, c.metricsTracerFactory, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: SelectSingerIDAlbumIDAlbumTitleFromAlbums, ResumeToken: resumeToken, - }) + }, opts...) }, nil, - func(error) {}) + func(error) {}, mc.(*grpcSpannerClient)) injected := errors.New("Failed iterator") err = iter.Do(func(r *Row) error { return injected }) if err != injected { @@ -1756,15 +1759,15 @@ func TestIteratorStopEarly(t *testing.T) { } iter := stream(ctx, nil, c.metricsTracerFactory, - func(ct context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ct context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{ Session: session.Name, Sql: SelectSingerIDAlbumIDAlbumTitleFromAlbums, ResumeToken: resumeToken, - }) + }, opts...) }, nil, - func(error) {}) + func(error) {}, mc.(*grpcSpannerClient)) _, err = iter.Next() if err != nil { t.Fatalf("before Stop: %v", err) diff --git a/spanner/request_id_header.go b/spanner/request_id_header.go new file mode 100644 index 000000000000..9fcd9377f432 --- /dev/null +++ b/spanner/request_id_header.go @@ -0,0 +1,263 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spanner + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "io" + "math" + "math/big" + "sync/atomic" + "time" + + "github.com/googleapis/gax-go/v2" + "google.golang.org/api/iterator" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// randIDForProcess is a strongly randomly generated value derived +// from a uint64, and in the range [0, maxUint64]. 
+var randIDForProcess string + +func init() { + bigMaxInt64, _ := new(big.Int).SetString(fmt.Sprintf("%d", uint64(math.MaxUint64)), 10) + if g, w := bigMaxInt64.Uint64(), uint64(math.MaxUint64); g != w { + panic(fmt.Sprintf("mismatch in randIDForProcess.maxUint64:\n\tGot: %d\n\tWant: %d", g, w)) + } + r64, err := rand.Int(rand.Reader, bigMaxInt64) + if err != nil { + panic(err) + } + randIDForProcess = r64.String() +} + +// Please bump this version whenever this implementation +// executes on the plans of a new specification. +const xSpannerRequestIDVersion uint8 = 1 + +const xSpannerRequestIDHeader = "x-goog-spanner-request-id" + +// optsWithNextRequestID bundles priors with a new header "x-goog-spanner-request-id" +func (g *grpcSpannerClient) optsWithNextRequestID(priors []gax.CallOption) []gax.CallOption { + return append(priors, &retryerWithRequestID{g}) +} + +func (g *grpcSpannerClient) prepareRequestIDTrackers(clientID int, channelID uint64, nthRequest *atomic.Uint32) { + g.id = clientID // The ID derived from the SpannerClient. + g.channelID = channelID + g.nthRequest = nthRequest +} + +// retryerWithRequestID is a gax.CallOption that injects "x-goog-spanner-request-id" +// into every RPC, and it appropriately increments the RPC's ordinal number per retry. +type retryerWithRequestID struct { + gsc *grpcSpannerClient +} + +var _ gax.CallOption = (*retryerWithRequestID)(nil) + +func (g *grpcSpannerClient) appendRequestIDToGRPCOptions(priors []grpc.CallOption, nthRequest, attempt uint32) []grpc.CallOption { + // Each value should be added in Decimal, unpadded. + requestID := fmt.Sprintf("%d.%s.%d.%d.%d.%d", xSpannerRequestIDVersion, randIDForProcess, g.id, g.channelID, nthRequest, attempt) + md := metadata.MD{xSpannerRequestIDHeader: []string{requestID}} + return append(priors, grpc.Header(&md)) +} + +type requestID string + +// augmentErrorWithRequestID introspects error converting it to an *.Error and +// attaching the subject requestID, unless it is one of the following: +// * nil +// * context.Canceled +// * io.EOF +// * iterator.Done +// of which in this case, the original error will be attached as is, since those +// are sentinel errors used to break sensitive conditions like ending iterations. +func (r requestID) augmentErrorWithRequestID(err error) error { + if err == nil { + return nil + } + + switch err { + case iterator.Done, io.EOF, context.Canceled: + return err + + default: + potentialCommit := errors.Is(err, context.DeadlineExceeded) + if code := status.Code(err); code == codes.DeadlineExceeded { + potentialCommit = true + } + sErr := toSpannerErrorWithCommitInfo(err, potentialCommit) + if sErr == nil { + return err + } + + spErr := sErr.(*Error) + spErr.RequestID = string(r) + return spErr + } +} + +func gRPCCallOptionsToRequestID(opts []grpc.CallOption) (reqID requestID, found bool) { + for _, opt := range opts { + hdrOpt, ok := opt.(grpc.HeaderCallOption) + if !ok { + continue + } + + metadata := hdrOpt.HeaderAddr + reqIDs := metadata.Get(xSpannerRequestIDHeader) + if len(reqIDs) != 0 && len(reqIDs[0]) != 0 { + reqID = requestID(reqIDs[0]) + found = true + break + } + } + return +} + +func (wr *requestIDHeaderInjector) interceptUnary(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + // It is imperative to search for the requestID before the call + // because gRPC's internals will consume the headers. 
+ reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + err := invoker(ctx, method, req, reply, cc, opts...) + if !foundRequestID { + return err + } + return reqID.augmentErrorWithRequestID(err) +} + +type requestIDErrWrappingClientStream struct { + grpc.ClientStream + reqID requestID +} + +func (rew *requestIDErrWrappingClientStream) processFromOutgoingContext(err error) error { + if err == nil { + return nil + } + return rew.reqID.augmentErrorWithRequestID(err) +} + +func (rew *requestIDErrWrappingClientStream) SendMsg(msg any) error { + err := rew.ClientStream.SendMsg(msg) + return rew.processFromOutgoingContext(err) +} + +func (rew *requestIDErrWrappingClientStream) RecvMsg(msg any) error { + err := rew.ClientStream.RecvMsg(msg) + return rew.processFromOutgoingContext(err) +} + +var _ grpc.ClientStream = (*requestIDErrWrappingClientStream)(nil) + +type requestIDHeaderInjector int + +func (wr *requestIDHeaderInjector) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // It is imperative to search for the requestID before the call + // because gRPC's internals will consume the headers. + reqID, foundRequestID := gRPCCallOptionsToRequestID(opts) + cs, err := streamer(ctx, desc, cc, method, opts...) + if !foundRequestID { + return cs, err + } + wcs := &requestIDErrWrappingClientStream{cs, reqID} + if err == nil { + return wcs, nil + } + + return wcs, reqID.augmentErrorWithRequestID(err) +} + +func (wr *retryerWithRequestID) Resolve(cs *gax.CallSettings) { + nthRequest := wr.gsc.nextNthRequest() + attempt := uint32(1) + // Inject the first request-id header. + // Note: after every gax.Invoke call, all the gRPC option headers are cleared out + // and nullified, but yet cs.GRPC still contains a reference to the inserted *metadata.MD + // just that it got cleared out and nullified. However, for retries we need to retain control + // of the entry to re-insert the updated request-id on every call, hence why we are creating + // and retaining a pointer reference to the metadata and shall be re-inserting the header value + // on every retry. + md := new(metadata.MD) + wr.generateAndInsertRequestID(md, nthRequest, attempt) + // Insert our grpc.CallOption that'll be updated by reference on every retry attempt. + cs.GRPC = append(cs.GRPC, grpc.Header(md)) + + if cs.Retry == nil { + // If there was no retry manager, our journey has ended. + return + } + + originalRetryer := cs.Retry() + newRetryer := func() gax.Retryer { + return (wrapRetryFn)(func(err error) (pause time.Duration, shouldRetry bool) { + attempt++ + wr.generateAndInsertRequestID(md, nthRequest, attempt) + return originalRetryer.Retry(err) + }) + } + cs.Retry = newRetryer +} + +func (wr *retryerWithRequestID) generateAndInsertRequestID(md *metadata.MD, nthRequest, attempt uint32) { + wr.gsc.generateAndInsertRequestID(md, nthRequest, attempt) +} + +func (gsc *grpcSpannerClient) generateAndInsertRequestID(md *metadata.MD, nthRequest, attempt uint32) { + // Google Engineering has requested that each value be added in Decimal unpadded. + // Should we have a standardized endianness: Little Endian or Big Endian? 
+ reqID := fmt.Sprintf("%d.%s.%d.%d.%d.%d", xSpannerRequestIDVersion, randIDForProcess, gsc.id, gsc.channelID, nthRequest, attempt) + if *md == nil { + *md = metadata.MD{} + } + md.Set(xSpannerRequestIDHeader, reqID) +} + +type wrapRetryFn func(err error) (time.Duration, bool) + +var _ gax.Retryer = (wrapRetryFn)(nil) + +func (fn wrapRetryFn) Retry(err error) (time.Duration, bool) { + return fn(err) +} + +func (g *grpcSpannerClient) nextNthRequest() uint32 { + return g.nthRequest.Add(1) +} + +type requestIDWrap struct { + md *metadata.MD + nthRequest uint32 + gsc *grpcSpannerClient +} + +func (gsc *grpcSpannerClient) generateRequestIDHeaderInjector() *requestIDWrap { + // Setup and track x-goog-request-id. + md := new(metadata.MD) + return &requestIDWrap{md: md, nthRequest: gsc.nextNthRequest(), gsc: gsc} +} + +func (riw *requestIDWrap) withNextRetryAttempt(attempt uint32) gax.CallOption { + riw.gsc.generateAndInsertRequestID(riw.md, riw.nthRequest, attempt) + // If no gRPC stream is available, try to initiate one. + return gax.WithGRPCOptions(grpc.Header(riw.md)) +} diff --git a/spanner/request_id_header_test.go b/spanner/request_id_header_test.go new file mode 100644 index 000000000000..7c50430c3e4d --- /dev/null +++ b/spanner/request_id_header_test.go @@ -0,0 +1,1867 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package spanner + +import ( + "context" + "encoding/json" + "fmt" + "math" + "regexp" + "slices" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + sppb "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/go-cmp/cmp" + "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/structpb" + + "cloud.google.com/go/spanner/internal/testutil" +) + +var regRequestID = regexp.MustCompile(`^(?P<version>\d+)\.(?P<randProcessId>[a-z0-9]+)\.(?P<clientId>\d+)\.(?P<channelId>\d+)\.(?P<reqId>\d+)\.(?P<rpcId>\d+)$`) + +type requestIDSegments struct { + Version uint8 `json:"vers"` + ProcessID string `json:"proc_id"` + ClientID uint32 `json:"c_id"` + RequestNo uint32 `json:"req_id"` + ChannelID uint32 `json:"ch_id"` + RPCNo uint32 `json:"rpc_id"` +} + +func (ris *requestIDSegments) String() string { + return fmt.Sprintf("%d.%s.%d.%d.%d.%d", ris.Version, ris.ProcessID, ris.ClientID, ris.ChannelID, ris.RequestNo, ris.RPCNo) +} + +func checkForMissingSpannerRequestIDHeader(opts []grpc.CallOption) (*requestIDSegments, error) { + requestID := "" + for _, opt := range opts { + if hdrOpt, ok := opt.(grpc.HeaderCallOption); ok { + hdrs := hdrOpt.HeaderAddr.Get(xSpannerRequestIDHeader) + gotRequestID := len(hdrs) != 0 && len(hdrs[0]) != 0 + if gotRequestID { + requestID = hdrs[0] + break + } + } + } + + if requestID == "" { + return nil, status.Errorf(codes.InvalidArgument, "missing %q header", xSpannerRequestIDHeader) + } + if !regRequestID.MatchString(requestID) { + return nil, status.Errorf(codes.InvalidArgument, "requestID does not conform to pattern=%q", regRequestID.String()) + } + + // Now extract the respective fields and validate that they match our rubric.
+ template := `{"vers":$version,"proc_id":"$randProcessId","c_id":$clientId,"req_id":$reqId,"ch_id":$channelId,"rpc_id":$rpcId}` + asJSONBytes := []byte{} + for _, submatch := range regRequestID.FindAllStringSubmatchIndex(requestID, -1) { + asJSONBytes = regRequestID.ExpandString(asJSONBytes, template, requestID, submatch) + } + recv := new(requestIDSegments) + if err := json.Unmarshal(asJSONBytes, recv); err != nil { + return nil, status.Error(codes.InvalidArgument, "could not correctly parse requestID segements: "+string(asJSONBytes)) + } + if g, w := recv.ProcessID, randIDForProcess; g != w { + return nil, status.Errorf(codes.InvalidArgument, "invalid processId, got=%q want=%q", g, w) + } + return recv, validateRequestIDSegments(recv) +} + +func validateRequestIDSegments(recv *requestIDSegments) error { + if recv == nil || recv.ProcessID == "" { + return status.Errorf(codes.InvalidArgument, "unset processId") + } + if len(recv.ProcessID) == 0 || len(recv.ProcessID) > 20 { + return status.Errorf(codes.InvalidArgument, "processId must be in the range (0, maxUint64), got %d", len(recv.ProcessID)) + } + if g := recv.ClientID; g < 1 { + return status.Errorf(codes.InvalidArgument, "clientID must be >= 1, got=%d", g) + } + if g := recv.RequestNo; g < 1 { + return status.Errorf(codes.InvalidArgument, "requestNumber must be >= 1, got=%d", g) + } + if g := recv.ChannelID; g < 1 { + return status.Errorf(codes.InvalidArgument, "channelID must be >= 1, got=%d", g) + } + if g := recv.RPCNo; g < 1 { + return status.Errorf(codes.InvalidArgument, "rpcID must be >= 1, got=%d", g) + } + return nil +} + +func TestRequestIDHeader_sentOnEveryClientCall(t *testing.T) { + interceptorTracker := newInterceptorTracker() + + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + sqlSELECT1 := "SELECT 1" + resultSet := &sppb.ResultSet{ + Rows: []*structpb.ListValue{ + {Values: []*structpb.Value{ + {Kind: &structpb.Value_NumberValue{NumberValue: 1}}, + }}, + }, + Metadata: &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {Name: "Int", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}, + }, + }, + }, + } + result := &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + } + server.TestSpanner.PutStatementResult(sqlSELECT1, result) + + txn := sc.ReadOnlyTransaction() + defer txn.Close() + + ctx := context.Background() + stmt := NewStatement(sqlSELECT1) + rowIter := txn.Query(ctx, stmt) + defer rowIter.Stop() + for { + rows, err := rowIter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + _ = rows + } + + if interceptorTracker.unaryCallCount() < 1 { + t.Error("unaryClientCall was not invoked") + } + if interceptorTracker.streamCallCount() < 1 { + t.Error("streamClientCall was not invoked") + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +type interceptorTracker struct { + nUnaryClientCalls *atomic.Uint64 + nStreamClientCalls *atomic.Uint64 + + mu sync.Mutex // mu protects 
the fields down below. + unaryClientRequestIDSegments []*requestIDSegments + streamClientRequestIDSegments []*requestIDSegments +} + +func (it *interceptorTracker) unaryCallCount() uint64 { + return it.nUnaryClientCalls.Load() +} + +func (it *interceptorTracker) streamCallCount() uint64 { + return it.nStreamClientCalls.Load() +} + +func (it *interceptorTracker) validateRequestIDsMonotonicity() error { + if err := ensureMonotonicityOfRequestIDs(it.unaryClientRequestIDSegments); err != nil { + return fmt.Errorf("unaryClientRequestIDs: %w", err) + } + if err := ensureMonotonicityOfRequestIDs(it.streamClientRequestIDSegments); err != nil { + return fmt.Errorf("streamClientRequestIDs: %w", err) + } + return nil +} + +type interceptSummary struct { + ProcIDs []string `json:"proc_ids"` + MaxChannelID uint32 `json:"max_ch_id"` + MinChannelID uint32 `json:"min_ch_id"` + MaxClientID uint32 `json:"max_c_id"` + MinClientID uint32 `json:"min_c_id"` + MaxRPCID uint32 `json:"max_rpc_id"` + MinRPCID uint32 `json:"min_rpc_id"` +} + +func (it *interceptorTracker) summarize() (unarySummary, streamSummary *interceptSummary) { + return computeSummary(it.unaryClientRequestIDSegments), computeSummary(it.streamClientRequestIDSegments) +} + +func computeSummary(segments []*requestIDSegments) *interceptSummary { + summary := new(interceptSummary) + summary.MinRPCID = math.MaxUint32 + summary.MaxRPCID = 0 + summary.MinClientID = math.MaxUint32 + summary.MaxClientID = 0 + summary.MinChannelID = math.MaxUint32 + summary.MaxChannelID = 0 + for _, segment := range segments { + if len(summary.ProcIDs) == 0 || summary.ProcIDs[len(summary.ProcIDs)-1] != segment.ProcessID { + summary.ProcIDs = append(summary.ProcIDs, segment.ProcessID) + } + if segment.ClientID < summary.MinClientID { + summary.MinClientID = segment.ClientID + } + if segment.ClientID > summary.MaxClientID { + summary.MaxClientID = segment.ClientID + } + if segment.RPCNo < summary.MinRPCID { + summary.MinRPCID = segment.RPCNo + } + if segment.RPCNo > summary.MaxRPCID { + summary.MaxRPCID = segment.RPCNo + } + if segment.ChannelID < summary.MinChannelID { + summary.MinChannelID = segment.ChannelID + } + if segment.ChannelID > summary.MaxChannelID { + summary.MaxChannelID = segment.ChannelID + } + if segment.RPCNo > summary.MaxRPCID { + summary.MaxRPCID = segment.RPCNo + } + if segment.RPCNo < summary.MinRPCID { + summary.MinRPCID = segment.RPCNo + } + } + return summary +} + +func ensureMonotonicityOfRequestIDs(requestIDs []*requestIDSegments) error { + for _, segment := range requestIDs { + if err := validateRequestIDSegments(segment); err != nil { + return err + } + } + + // 2. Compare the current against previous requestID which requires at least 2 elements. + for i := 1; i < len(requestIDs); i++ { + rCurr, rPrev := requestIDs[i], requestIDs[i-1] + if rPrev.ProcessID != rCurr.ProcessID { + return fmt.Errorf("processID mismatch: #[%d].ProcessID=%q, #[%d].ProcessID=%q", i, rCurr.ProcessID, i-1, rPrev.ProcessID) + } + if rPrev.ClientID == rCurr.ClientID { + if rPrev.ChannelID == rCurr.ChannelID { + if rPrev.RequestNo == rCurr.RequestNo { + if rPrev.RPCNo >= rCurr.RPCNo { + return fmt.Errorf("sameChannelID, sameRequestNo yet #[%d].RPCNo=%d >= #[%d].RPCNo=%d\n\n\t%s\n\t%s", i-1, rPrev.RPCNo, i, rCurr.RPCNo, rPrev, rCurr) + } + } + } + + // In the case of retries, we might have the same request + // number, but rpc id must be monotonically increasing.
+ if false && rPrev.RequestNo == rCurr.RequestNo { + if rPrev.RPCNo >= rCurr.RPCNo { + return fmt.Errorf("sameClientID but rpcNo mismatch: #[%d].RPCNo=%d >= #[%d].RPCNo=%d", i-1, rPrev.RPCNo, i, rCurr.RPCNo) + } + } + } else if rPrev.ClientID > rCurr.ClientID { + // For requests that execute in parallel such as with PartitionQuery, + // we could have requests from previous clients executing slower than + // the newest client, hence this is not an error. + } + } + + // All checks passed so good to go. + return nil +} + +func TestRequestIDHeader_ensureMonotonicityOfRequestIDs(t *testing.T) { + procID := randIDForProcess + tests := []struct { + name string + in []*requestIDSegments + wantErr string + }{ + {name: "no values", wantErr: ""}, + {name: "1 value", in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, RequestNo: 1, ChannelID: 3, RPCNo: 1}, + }, wantErr: ""}, + {name: "Different processIDs", in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, RequestNo: 1, RPCNo: 1, ChannelID: 1}, + {ProcessID: strings.Repeat("a", len(procID)), ClientID: 1, RequestNo: 1, RPCNo: 2, ChannelID: 1}, + }, wantErr: "processID mismatch"}, + { + name: "Different clientID, prev has higher value", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 2, RequestNo: 1, RPCNo: 1, ChannelID: 1}, + {ProcessID: procID, ClientID: 1, RequestNo: 1, RPCNo: 1, ChannelID: 1}, + }, + wantErr: "", // Requests can occur in parallel. + }, + { + name: "Different clientID, prev has lower value", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, RPCNo: 1, ChannelID: 1, RequestNo: 1}, + {ProcessID: procID, ClientID: 2, RPCNo: 1, ChannelID: 1, RequestNo: 1}, + }, + wantErr: "", + }, + { + name: "Same channelID, prev has same RequestNo", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, ChannelID: 1, RPCNo: 1, RequestNo: 8}, + {ProcessID: procID, ClientID: 1, ChannelID: 1, RPCNo: 1, RequestNo: 8}, + }, + wantErr: "sameChannelID, sameRequestNo yet #[0].RPCNo=1 >= #[1].RPCNo=1", + }, + { + name: "Same clientID, different ChannelID prev has same RequestNo", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, ChannelID: 1, RequestNo: 1, RPCNo: 1}, + {ProcessID: procID, ClientID: 1, ChannelID: 2, RequestNo: 1, RPCNo: 1}, + }, + wantErr: "", + }, + { + name: "Same clientID, same ChannelID, same RequestNo, different RPCNo", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, ChannelID: 4, RequestNo: 3, RPCNo: 1}, + {ProcessID: procID, ClientID: 1, ChannelID: 4, RequestNo: 3, RPCNo: 4}, + }, + wantErr: "", + }, + { + name: "Same clientID, prev has higher RPCNo", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, ChannelID: 1, RequestNo: 1, RPCNo: 2}, + {ProcessID: procID, ClientID: 1, ChannelID: 1, RequestNo: 1, RPCNo: 1}, + }, + wantErr: "sameRequestNo yet #[0].RPCNo=2 >= #[1].RPCNo=1", + }, + { + name: "Same clientID, same channelID, prev has lower RPCNo", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 1, RequestNo: 2, ChannelID: 1, RPCNo: 1}, + {ProcessID: procID, ClientID: 1, RequestNo: 2, ChannelID: 1, RPCNo: 2}, + }, + wantErr: "", + }, + { + name: "Same clientID, prev has higher clientID", + in: []*requestIDSegments{ + {ProcessID: procID, ClientID: 2, RequestNo: 1, RPCNo: 1, ChannelID: 1}, + {ProcessID: procID, ClientID: 1, RequestNo: 1, RPCNo: 1, ChannelID: 1}, + }, + wantErr: "", // Requests can execute in parallel. + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 1. Each segment but be valid! 
+ err := ensureMonotonicityOfRequestIDs(tt.in) + if tt.wantErr != "" { + if err == nil { + t.Fatal("Expected a non-nil error") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Error mismatch\n\t%q\ncould not be found in\n\t%q", tt.wantErr, err) + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + }) + } +} + +func (it *interceptorTracker) unaryClientInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + it.nUnaryClientCalls.Add(1) + reqID, err := checkForMissingSpannerRequestIDHeader(opts) + if err != nil { + return err + } + + it.mu.Lock() + it.unaryClientRequestIDSegments = append(it.unaryClientRequestIDSegments, reqID) + it.mu.Unlock() + + // fmt.Printf("unary.method=%q\n", method) + // fmt.Printf("method=%q\nReq: %#v\nRes: %#v\n", method, req, reply) + // Otherwise proceed with the call. + return invoker(ctx, method, req, reply, cc, opts...) +} + +func (it *interceptorTracker) streamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + it.nStreamClientCalls.Add(1) + reqID, err := checkForMissingSpannerRequestIDHeader(opts) + if err != nil { + return nil, err + } + + it.mu.Lock() + it.streamClientRequestIDSegments = append(it.streamClientRequestIDSegments, reqID) + it.mu.Unlock() + + // fmt.Printf("stream.method=%q\n", method) + // Otherwise proceed with the call. + return streamer(ctx, desc, cc, method, opts...) +} + +func newInterceptorTracker() *interceptorTracker { + return &interceptorTracker{ + nUnaryClientCalls: new(atomic.Uint64), + nStreamClientCalls: new(atomic.Uint64), + } +} + +func TestRequestIDHeader_onRetriesWithFailedTransactionCommit(t *testing.T) { + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // First commit will fail, and the retry will begin a new transaction. + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + + ctx := context.Background() + ms := []*Mutation{ + Insert("Accounts", []string{"AccountId"}, []any{int64(1)}), + } + + if _, err := sc.Apply(ctx, ms); err != nil { + t.Fatalf("ReadWriteTransaction retry on abort, got %v, want nil.", err) + } + + if _, err := shouldHaveReceived(server.TestSpanner, []any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.CommitRequest{}, // First commit fails. + &sppb.BeginTransactionRequest{}, + &sppb.CommitRequest{}, // Second commit succeeds. 
+ }); err != nil { + t.Fatal(err) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(5); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g := interceptorTracker.streamCallCount(); g > 0 { + t.Errorf("streamClientCall was unexpectedly invoked %d times", g) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +// Tests that SessionNotFound errors are retried. +func TestRequestIDHeader_retriesOnSessionNotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + serverErr := newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s") + server.TestSpanner.PutExecutionTime(testutil.MethodBeginTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{serverErr, serverErr, serverErr}, + }) + server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, + testutil.SimulatedExecutionTime{ + Errors: []error{serverErr}, + }) + + txn := sc.ReadOnlyTransaction() + defer txn.Close() + + var wantErr error + if _, _, got := txn.acquire(ctx); !testEqual(wantErr, got) { + t.Fatalf("Expect acquire to succeed, got %v, want %v.", got, wantErr) + } + + // The server error should lead to a retry of the BeginTransaction call and + // a valid session handle to be returned that will be used by the following + // requests. Note that calling txn.Query(...) does not actually send the + // query to the (mock) server. That is done at the first call to + // RowIterator.Next. The following statement only verifies that the + // transaction is in a valid state and received a valid session handle. 
+ if got := txn.Query(ctx, NewStatement("SELECT 1")); !testEqual(wantErr, got.err) { + t.Fatalf("Expect Query to succeed, got %v, want %v.", got.err, wantErr) + } + + if got := txn.Read(ctx, "Users", KeySets(Key{"alice"}, Key{"bob"}), []string{"name", "email"}); !testEqual(wantErr, got.err) { + t.Fatalf("Expect Read to succeed, got %v, want %v.", got.err, wantErr) + } + + wantErr = ToSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")) + ms := []*Mutation{ + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []any{int64(1), "Foo", int64(50)}), + Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []any{int64(2), "Bar", int64(1)}), + } + _, got := sc.Apply(ctx, ms, ApplyAtLeastOnce()) + if !testEqual(wantErr, got) { + t.Fatalf("Expect Apply to fail\nGot: %v\nWant: %v\n", got, wantErr) + } + gotSErr := got.(*Error) + if gotSErr.RequestID == "" { + t.Fatal("Expected a non-blank requestID") + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(8); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g := interceptorTracker.streamCallCount(); g > 0 { + t.Errorf("streamClientCall was unexpectedly invoked %d times", g) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_BatchDMLWithMultipleDML(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + ctx := context.Background() + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + updateBarSetFoo := testutil.UpdateBarSetFoo + _, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { + if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { + return err + } + if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}, {SQL: updateBarSetFoo}}); err != nil { + return err + } + if _, err = tx.Update(ctx, Statement{SQL: updateBarSetFoo}); err != nil { + return err + } + _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateBarSetFoo}}) + return err + }) + if err != nil { + t.Fatal(err) + } + + gotReqs, err := shouldHaveReceived(server.TestSpanner, []any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteBatchDmlRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteBatchDmlRequest{}, + &sppb.CommitRequest{}, + }) + if err != nil { + t.Fatal(err) + } + muxCreateBuffer := 0 + if isMultiplexEnabled { + muxCreateBuffer = 1 + } + if got, want := gotReqs[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want { + 
t.Errorf("got %d, want %d", got, want) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(6); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g := interceptorTracker.streamCallCount(); g > 0 { + t.Errorf("streamClientCall was unexpectedly invoked %d times", g) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_clientBatchWrite(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []any{"foo1", 1}}, + }}, + } + iter := sc.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", responseCount, len(mutationGroups)) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g, w := interceptorTracker.streamCallCount(), uint64(1); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_ClientBatchWriteWithSessionNotFound(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + server.TestSpanner.PutExecutionTime( + testutil.MethodBatchWrite, + testutil.SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []any{"foo1", 1}}, + }}, + } + iter := sc.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", 
responseCount, len(mutationGroups)) + } + + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a retry for BatchWrite after the first SessionNotFound error, hence expecting 2 calls. + if g, w := interceptorTracker.streamCallCount(), uint64(2); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_ClientBatchWriteWithError(t *testing.T) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + injectedErr := status.Error(codes.InvalidArgument, "Invalid argument") + server.TestSpanner.PutExecutionTime( + testutil.MethodBatchWrite, + testutil.SimulatedExecutionTime{Errors: []error{injectedErr}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []any{"foo1", 1}}, + }}, + } + iter := sc.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + err := iter.Do(doFunc) + if err == nil { + t.Fatal("Expected an error") + } + + gotSErr := err.(*Error) + if gotSErr.RequestID == "" { + t.Fatal("Expected a non-blank requestID") + } + + if responseCount != 0 { + t.Fatalf("Do unexpectedly called %d times", responseCount) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a straight-up failure after the first BatchWrite call so only 1 call. 
+ if g, w := interceptorTracker.streamCallCount(), uint64(1); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDHeader_PartitionQueryWithoutError(t *testing.T) { + testRequestIDHeaderPartitionQuery(t, false) +} + +func TestRequestIDHeader_PartitionQueryWithError(t *testing.T) { + testRequestIDHeaderPartitionQuery(t, true) +} + +func testRequestIDHeaderPartitionQuery(t *testing.T, mustErrorOnPartitionQuery bool) { + t.Parallel() + + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // The request will initially fail, and be retried. + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + if mustErrorOnPartitionQuery { + server.TestSpanner.PutExecutionTime(testutil.MethodPartitionQuery, + testutil.SimulatedExecutionTime{ + Errors: []error{newAbortedErrorWithMinimalRetryDelay()}, + }) + } + + sqlFromSingers := "SELECT * FROM Singers" + resultSet := &sppb.ResultSet{ + Rows: []*structpb.ListValue{ + { + Values: []*structpb.Value{ + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 1}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Bruce"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "Wayne"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 2}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Robin"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "SideKick"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 3}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Gordon"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "Commissioner"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 4}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Joker"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "None"}}, + }, + }), + structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "SingerId": {Kind: &structpb.Value_NumberValue{NumberValue: 5}}, + "FirstName": {Kind: &structpb.Value_StringValue{StringValue: "Riddler"}}, + "LastName": {Kind: &structpb.Value_StringValue{StringValue: "None"}}, + }, + }), + }}, + }, + Metadata: &sppb.ResultSetMetadata{ + RowType: &sppb.StructType{ + Fields: []*sppb.StructType_Field{ + {Name: "SingerId", Type: &sppb.Type{Code: sppb.TypeCode_INT64}}, + {Name: "FirstName", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + {Name: 
"LastName", Type: &sppb.Type{Code: sppb.TypeCode_STRING}}, + }, + }, + }, + } + result := &testutil.StatementResult{ + Type: testutil.StatementResultResultSet, + ResultSet: resultSet, + } + server.TestSpanner.PutStatementResult(sqlFromSingers, result) + + ctx := context.Background() + txn, err := sc.BatchReadOnlyTransaction(ctx, StrongRead()) + + if err != nil { + t.Fatal(err) + } + defer txn.Close() + + // Singer represents the elements in a row from the Singers table. + type Singer struct { + SingerID int64 + FirstName string + LastName string + SingerInfo []byte + } + stmt := Statement{SQL: "SELECT * FROM Singers;"} + partitions, err := txn.PartitionQuery(ctx, stmt, PartitionOptions{}) + + if mustErrorOnPartitionQuery { + // The methods invoked should be: ['/BatchCreateSessions', '/CreateSession', '/BeginTransaction', '/PartitionQuery'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a straight-up failure after the first BatchWrite call so only 1 call. + if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } + return + } + + if err != nil { + t.Fatal(err) + } + + wg := new(sync.WaitGroup) + for i, p := range partitions { + wg.Add(1) + go func(i int, p *Partition) { + defer wg.Done() + iter := txn.Execute(ctx, p) + defer iter.Stop() + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + var s Singer + if err := row.ToStruct(&s); err != nil { + _ = err + } + _ = s + } + }(i, p) + } + wg.Wait() + + // The methods invoked should be: ['/BatchCreateSessions', '/CreateSession', '/BeginTransaction', '/PartitionQuery'] + if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // We had a straight-up failure after the first BatchWrite call so only 1 call. 
+ if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w {
+ t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestRequestIDHeader_ReadWriteTransactionUpdate(t *testing.T) {
+ t.Parallel()
+
+ interceptorTracker := newInterceptorTracker()
+ clientOpts := []option.ClientOption{
+ option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+ option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+ }
+ clientConfig := ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 2,
+ MaxOpened: 10,
+ WriteSessions: 0.2,
+ incStep: 2,
+ },
+ }
+
+ server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+ t.Cleanup(tearDown)
+ defer sc.Close()
+
+ ctx := context.Background()
+ updateSQL := testutil.UpdateBarSetFoo
+ _, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+ if _, err = tx.Update(ctx, Statement{SQL: updateSQL}); err != nil {
+ return err
+ }
+ if _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateSQL}, {SQL: updateSQL}}); err != nil {
+ return err
+ }
+ if _, err = tx.Update(ctx, Statement{SQL: updateSQL}); err != nil {
+ return err
+ }
+ _, err = tx.BatchUpdate(ctx, []Statement{{SQL: updateSQL}})
+ return err
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ gotReqs, err := shouldHaveReceived(server.TestSpanner, []any{
+ &sppb.BatchCreateSessionsRequest{},
+ &sppb.ExecuteSqlRequest{},
+ &sppb.ExecuteBatchDmlRequest{},
+ &sppb.ExecuteSqlRequest{},
+ &sppb.ExecuteBatchDmlRequest{},
+ &sppb.CommitRequest{},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ muxCreateBuffer := 0
+ if isMultiplexEnabled {
+ muxCreateBuffer = 1
+ }
+ if got, want := gotReqs[1+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(1); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+ if got, want := gotReqs[2+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(2); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+ if got, want := gotReqs[3+muxCreateBuffer].(*sppb.ExecuteSqlRequest).Seqno, int64(3); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+ if got, want := gotReqs[4+muxCreateBuffer].(*sppb.ExecuteBatchDmlRequest).Seqno, int64(4); got != want {
+ t.Errorf("got %d, want %d", got, want)
+ }
+
+ // The methods invoked should be: ['/BatchCreateSessions', '/ExecuteSql', '/ExecuteBatchDml', '/ExecuteSql', '/ExecuteBatchDml', '/Commit']
+ if g, w := interceptorTracker.unaryCallCount(), uint64(6); g != w {
+ t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ // All of the calls above are unary, so no streaming RPCs are expected.
+ if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w {
+ t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestRequestIDHeader_ReadWriteTransactionBatchUpdateWithOptions(t *testing.T) {
+ t.Parallel()
+
+ interceptorTracker := newInterceptorTracker()
+ clientOpts := []option.ClientOption{
+ option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+ option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+ }
+ clientConfig := ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 2,
+ MaxOpened: 10,
+ WriteSessions: 0.2,
+ incStep: 2,
+ },
+ }
+
+ _, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+ t.Cleanup(tearDown)
+ defer sc.Close()
+
+ ctx := context.Background()
+ selectSQL := testutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums
+ updateSQL := testutil.UpdateBarSetFoo
+ _, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) {
+ iter := tx.QueryWithOptions(ctx, NewStatement(selectSQL), QueryOptions{})
+ iter.Next()
+ iter.Stop()
+
+ qo := QueryOptions{}
+ iter = tx.ReadWithOptions(ctx, "FOO", AllKeys(), []string{"BAR"}, &ReadOptions{Priority: qo.Priority})
+ iter.Next()
+ iter.Stop()
+
+ tx.UpdateWithOptions(ctx, NewStatement(updateSQL), qo)
+ tx.BatchUpdateWithOptions(ctx, []Statement{
+ NewStatement(updateSQL),
+ }, qo)
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // The methods invoked should be: ['/BatchCreateSessions', '/ExecuteSql', '/ExecuteBatchDml', '/Commit']
+ if g, w := interceptorTracker.unaryCallCount(), uint64(4); g != w {
+ t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ // The methods invoked should be: ['/ExecuteStreamingSql', '/StreamingRead']
+ if g, w := interceptorTracker.streamCallCount(), uint64(2); g != w {
+ t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestRequestIDHeader_multipleParallelCallsWithConventionalCustomerCalls(t *testing.T) {
+ t.Parallel()
+
+ interceptorTracker := newInterceptorTracker()
+ clientOpts := []option.ClientOption{
+ option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+ option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+ }
+ clientConfig := ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 2,
+ MaxOpened: 10,
+ WriteSessions: 0.2,
+ incStep: 2,
+ },
+ }
+
+ // We created exactly 1 client.
+ _, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+ t.Cleanup(tearDown)
+ defer sc.Close()
+
+ beginningClientID := uint32(sc.sc.nthClient)
+
+ ctx := context.Background()
+ selectSQL := testutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums
+ updateSQL := testutil.UpdateBarSetFoo
+
+ // We are going to invoke 80 calls in parallel.
+ n := uint64(80) + wg := new(sync.WaitGroup) + semaCh := make(chan bool) + semaWg := new(sync.WaitGroup) + semaWg.Add(int(n)) + for i := uint64(0); i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + semaWg.Done() + <-semaCh + _, err := sc.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { + iter := tx.QueryWithOptions(ctx, NewStatement(selectSQL), QueryOptions{}) + iter.Next() + iter.Stop() + + qo := QueryOptions{} + iter = tx.ReadWithOptions(ctx, "FOO", AllKeys(), []string{"BAR"}, &ReadOptions{Priority: qo.Priority}) + iter.Next() + iter.Stop() + + tx.UpdateWithOptions(ctx, NewStatement(updateSQL), qo) + tx.BatchUpdateWithOptions(ctx, []Statement{ + NewStatement(updateSQL), + }, qo) + return nil + }) + if err != nil { + panic(err) + } + }() + } + + go func() { + semaWg.Wait() + close(semaCh) + }() + + wg.Wait() + + maxChannelID := uint32(sc.sc.connPool.Num()) + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } + + gotUnarySummary, gotStreamSummary := interceptorTracker.summarize() + wantUnarySummary := &interceptSummary{ + ProcIDs: []string{randIDForProcess}, + MaxClientID: beginningClientID, + MinClientID: beginningClientID, + MaxRPCID: 1, + MinRPCID: 1, + MaxChannelID: maxChannelID, + MinChannelID: 1, + } + if diff := cmp.Diff(gotUnarySummary, wantUnarySummary); diff != "" { + t.Errorf("UnarySummary mismatch: got - want +\n%s", diff) + } + wantStreamSummary := &interceptSummary{ + ProcIDs: []string{randIDForProcess}, + MaxClientID: beginningClientID, + MinClientID: beginningClientID, + MaxRPCID: 1, + MinRPCID: 1, + MaxChannelID: maxChannelID, + MinChannelID: 1, + } + if diff := cmp.Diff(gotStreamSummary, wantStreamSummary); diff != "" { + t.Errorf("StreamSummary mismatch: got - want +\n%s", diff) + } + + // The methods invoked should be: ['/BatchCreateSessions', '/ExecuteSql', '/ExecuteBatchDml', '/Commit'] + if g, w := interceptorTracker.unaryCallCount(), uint64(245); g != w && false { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + // The methods invoked should be: ['/ExecuteStreamingSql', '/StreamingRead'] + if g, w := interceptorTracker.streamCallCount(), uint64(2)*n; g != w && false { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } +} + +func newUnavailableErrorWithMinimalRetryDelay() error { + st := status.New(codes.Unavailable, "Please try again") + retry := &errdetails.RetryInfo{ + RetryDelay: durationpb.New(time.Nanosecond), + } + st, _ = st.WithDetails(retry) + return st.Err() +} + +func newInvalidArgumentError() error { + st := status.New(codes.InvalidArgument, "Invalid argument") + return st.Err() +} + +func TestRequestIDHeader_RetryOnAbortAndValidate(t *testing.T) { + t.Parallel() + + ctx := context.Background() + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // First commit will fail, and the retry will begin a new transaction. 
+ server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction,
+ testutil.SimulatedExecutionTime{
+ Errors: []error{
+ newUnavailableErrorWithMinimalRetryDelay(),
+ newUnavailableErrorWithMinimalRetryDelay(),
+ newUnavailableErrorWithMinimalRetryDelay(),
+ },
+ })
+
+ ms := []*Mutation{
+ Insert("Accounts", []string{"AccountId"}, []interface{}{int64(1)}),
+ }
+
+ if _, e := sc.Apply(ctx, ms); e != nil {
+ t.Fatalf("Apply with retries on UNAVAILABLE, got %v, want nil.", e)
+ }
+
+ if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
+ &sppb.BatchCreateSessionsRequest{},
+ &sppb.BeginTransactionRequest{},
+ &sppb.CommitRequest{},
+ &sppb.CommitRequest{},
+ &sppb.CommitRequest{},
+ &sppb.CommitRequest{},
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // The method CommitTransaction is retried 3 times due to the 3 injected errors, so we expect 4 invocations of it,
+ // plus BatchCreateSessions + BeginTransaction, hence a total of 6 unary calls.
+ if g, w := interceptorTracker.unaryCallCount(), uint64(6); g != w {
+ t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ // All of these calls are unary, so no streaming RPCs are expected.
+ if g, w := interceptorTracker.streamCallCount(), uint64(0); g != w {
+ t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil {
+ t.Fatal(err)
+ }
+
+ clientID := uint32(sc.sc.nthClient)
+ procID := randIDForProcess
+ version := xSpannerRequestIDVersion
+ wantUnarySegments := []*requestIDSegments{
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 1, RPCNo: 1}, // BatchCreateSession
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 1}, // BeginTransaction
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 3, RPCNo: 1}, // Commit: failed on 1st attempt
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 3, RPCNo: 2}, // Commit: failed on 2nd attempt
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 3, RPCNo: 3}, // Commit: failed on 3rd attempt
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 3, RPCNo: 4}, // Commit: success on 4th attempt
+ }
+
+ if diff := cmp.Diff(interceptorTracker.unaryClientRequestIDSegments, wantUnarySegments); diff != "" {
+ t.Fatalf("RequestID segments mismatch: got - want +\n%s", diff)
+ }
+}
+
+func TestRequestIDHeader_BatchCreateSessions_Unavailable(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ interceptorTracker := newInterceptorTracker()
+ clientOpts := []option.ClientOption{
+ option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+ option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+ }
+ clientConfig := ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 2,
+ MaxOpened: 10,
+ WriteSessions: 0.2,
+ incStep: 2,
+ },
+ }
+
+ server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+ t.Cleanup(tearDown)
+ defer sc.Close()
+
+ // BatchCreateSessions returns UNAVAILABLE and should be retried.
+ server.TestSpanner.PutExecutionTime(testutil.MethodBatchCreateSession, + testutil.SimulatedExecutionTime{ + Errors: []error{ + newUnavailableErrorWithMinimalRetryDelay(), + }, + }) + iter := sc.Single().Query(ctx, Statement{SQL: testutil.SelectFooFromBar}) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + } + + if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + }); err != nil { + t.Fatal(err) + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(2); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + if g, w := interceptorTracker.streamCallCount(), uint64(1); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } + + clientID := uint32(sc.sc.nthClient) + procID := randIDForProcess + version := xSpannerRequestIDVersion + wantUnarySegments := []*requestIDSegments{ + {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 1, RPCNo: 1}, // BatchCreateSession (initial attempt) + {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 1, RPCNo: 2}, // BatchCreateSession (retry) + } + wantStreamingSegments := []*requestIDSegments{ + {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 1}, // ExecuteStreamingSql + } + + if diff := cmp.Diff(interceptorTracker.unaryClientRequestIDSegments, wantUnarySegments); diff != "" { + t.Fatalf("RequestID unary segments mismatch: got - want +\n%s", diff) + } + if diff := cmp.Diff(interceptorTracker.streamClientRequestIDSegments, wantStreamingSegments); diff != "" { + t.Fatalf("RequestID streaming segments mismatch: got - want +\n%s", diff) + } +} + +func TestRequestIDHeader_SingleUseReadOnly_ExecuteStreamingSql_Unavailable(t *testing.T) { + t.Parallel() + + ctx := context.Background() + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // ExecuteStreamingSql returns UNAVAILABLE and should be retried. 
+ server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql,
+ testutil.SimulatedExecutionTime{
+ Errors: []error{
+ newUnavailableErrorWithMinimalRetryDelay(),
+ newUnavailableErrorWithMinimalRetryDelay(),
+ },
+ })
+ iter := sc.Single().Query(ctx, Statement{SQL: testutil.SelectFooFromBar})
+ defer iter.Stop()
+ for {
+ _, err := iter.Next()
+ if err == iterator.Done {
+ break
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ if _, err := shouldHaveReceived(server.TestSpanner, []any{
+ &sppb.BatchCreateSessionsRequest{},
+ &sppb.ExecuteSqlRequest{},
+ &sppb.ExecuteSqlRequest{},
+ &sppb.ExecuteSqlRequest{},
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w {
+ t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ if g, w := interceptorTracker.streamCallCount(), uint64(3); g != w {
+ t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w)
+ }
+
+ if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil {
+ t.Fatal(err)
+ }
+
+ clientID := uint32(sc.sc.nthClient)
+ procID := randIDForProcess
+ version := xSpannerRequestIDVersion
+ wantUnarySegments := []*requestIDSegments{
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 1, RPCNo: 1}, // BatchCreateSession
+ }
+ wantStreamingSegments := []*requestIDSegments{
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 1}, // ExecuteStreamingSql (initial attempt)
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 2}, // ExecuteStreamingSql (retry)
+ {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 3}, // ExecuteStreamingSql (retry)
+ }
+
+ if diff := cmp.Diff(interceptorTracker.unaryClientRequestIDSegments, wantUnarySegments); diff != "" {
+ t.Fatalf("RequestID unary segments mismatch: got - want +\n%s", diff)
+ }
+ if diff := cmp.Diff(interceptorTracker.streamClientRequestIDSegments, wantStreamingSegments); diff != "" {
+ t.Fatalf("RequestID streaming segments mismatch: got - want +\n%s", diff)
+ }
+}
+
+func TestRequestIDHeader_SingleUseReadOnly_ExecuteStreamingSql_InvalidArgument(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ interceptorTracker := newInterceptorTracker()
+ clientOpts := []option.ClientOption{
+ option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+ option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+ }
+ clientConfig := ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 2,
+ MaxOpened: 10,
+ WriteSessions: 0.2,
+ incStep: 2,
+ },
+ }
+
+ server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+ t.Cleanup(tearDown)
+ defer sc.Close()
+
+ // Simulate that ExecuteStreamingSql fails with an InvalidArgument error.
+ server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, + testutil.SimulatedExecutionTime{Errors: []error{newInvalidArgumentError()}}) + + iter := sc.Single().Query(ctx, Statement{SQL: testutil.SelectFooFromBar}) + defer iter.Stop() + _, err := iter.Next() + if err == nil { + t.Fatal("missing invalid argument error") + } + if g, w := ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + spannerError, ok := err.(*Error) + if !ok { + t.Fatal("not a Spanner error") + } + if spannerError.RequestID == "" { + t.Fatal("missing RequestID on error") + } +} + +func TestRequestIDHeader_SingleUseReadOnly_ExecuteStreamingSql_ContextDeadlineExceeded(t *testing.T) { + t.Parallel() + + ctx := context.Background() + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // Simulate that ExecuteStreamingSql is slow. + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, + testutil.SimulatedExecutionTime{MinimumExecutionTime: time.Second}) + ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + defer cancel() + iter := sc.Single().Query(ctx, Statement{SQL: testutil.SelectFooFromBar}) + defer iter.Stop() + _, err := iter.Next() + if err == nil { + t.Fatal("missing deadline exceeded error") + } + if g, w := ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + spannerError, ok := err.(*Error) + if !ok { + t.Fatal("not a Spanner error") + } + if spannerError.RequestID == "" { + t.Fatal("missing RequestID on error") + } +} + +func TestRequestIDHeader_Commit_ContextDeadlineExceeded(t *testing.T) { + t.Parallel() + + ctx := context.Background() + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + + // Simulate that Commit is slow. 
+ server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction,
+ testutil.SimulatedExecutionTime{MinimumExecutionTime: time.Second})
+ ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
+ defer cancel()
+ _, err := sc.Apply(ctx, []*Mutation{})
+ if err == nil {
+ t.Fatal("missing deadline exceeded error")
+ }
+ if g, w := ErrCode(err), codes.DeadlineExceeded; g != w {
+ t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
+ }
+ spannerError, ok := err.(*Error)
+ if !ok {
+ t.Fatal("not a Spanner error")
+ }
+ if spannerError.RequestID == "" {
+ t.Fatal("missing RequestID on error")
+ }
+}
+
+func TestRequestIDHeader_VerifyChannelNumber(t *testing.T) {
+ t.Parallel()
+
+ ctx := context.Background()
+ interceptorTracker := newInterceptorTracker()
+ clientOpts := []option.ClientOption{
+ option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)),
+ option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)),
+ }
+ clientConfig := ClientConfig{
+ SessionPoolConfig: SessionPoolConfig{
+ MinOpened: 100,
+ MaxOpened: 400,
+ incStep: 25,
+ },
+ NumChannels: 4,
+ }
+
+ _, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts)
+ t.Cleanup(tearDown)
+ defer sc.Close()
+ // Wait for the session pool to be initialized.
+ sp := sc.idleSessions
+ waitFor(t, func() error {
+ sp.mu.Lock()
+ defer sp.mu.Unlock()
+ if uint64(sp.idleList.Len()) != clientConfig.MinOpened {
+ return fmt.Errorf("num open sessions mismatch\nWant: %d\nGot: %d", sp.MinOpened, sp.idleList.Len())
+ }
+ return nil
+ })
+ // Verify that we've seen request IDs for each channel number.
+ for channel := uint32(1); channel <= uint32(clientConfig.NumChannels); channel++ {
+ if !slices.ContainsFunc(interceptorTracker.unaryClientRequestIDSegments, func(segments *requestIDSegments) bool {
+ return segments.ChannelID == channel
+ }) {
+ t.Fatalf("missing channel %d in unary requests", channel)
+ }
+ }
+
+ // Execute MinOpened + 1 queries without closing the iterators.
+ // This will check out MinOpened + 1 sessions, which also triggers
+ // one more BatchCreateSessions call.
+ iterators := make([]*RowIterator, 0, clientConfig.MinOpened+1)
+ for i := 0; i < int(clientConfig.MinOpened)+1; i++ {
+ iter := sc.Single().Query(ctx, Statement{SQL: testutil.SelectFooFromBar})
+ iterators = append(iterators, iter)
+ _, err := iter.Next()
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ // Verify that we've seen request IDs for each channel number.
+ for channel := uint32(1); channel <= uint32(clientConfig.NumChannels); channel++ {
+ if !slices.ContainsFunc(interceptorTracker.streamClientRequestIDSegments, func(segments *requestIDSegments) bool {
+ return segments.ChannelID == channel
+ }) {
+ t.Fatalf("missing channel %d in stream requests", channel)
+ }
+ }
+ // Verify that we've only seen channel numbers in the range [1, config.NumChannels].
+ for _, segmentsSlice := range [][]*requestIDSegments{interceptorTracker.streamClientRequestIDSegments, interceptorTracker.unaryClientRequestIDSegments} { + if slices.ContainsFunc(segmentsSlice, func(segments *requestIDSegments) bool { + return segments.ChannelID < 1 || segments.ChannelID > uint32(clientConfig.NumChannels) + }) { + t.Fatalf("invalid channel in requests: %v", segmentsSlice) + } + } + + if g, w := interceptorTracker.unaryCallCount(), uint64(5); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + + if g, w := interceptorTracker.streamCallCount(), uint64(101); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } +} + +func TestRequestIDInError(t *testing.T) { + cases := []struct { + name string + err *Error + want string + }{ + {"nil error", nil, "spanner: OK"}, + {"only requestID", &Error{RequestID: "req-id"}, `spanner: code = "OK", desc = "", requestID = "req-id"`}, + { + "with an error", + &Error{RequestID: "req-id", Code: codes.Internal, Desc: "An error"}, + `spanner: code = "Internal", desc = "An error", requestID = "req-id"`, + }, + { + "with additional details", + &Error{additionalInformation: "additional", RequestID: "req-id"}, + `spanner: code = "OK", desc = "", additional information = additional, requestID = "req-id"`, + }, + } + + for _, tt := range cases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got := tt.err.Error() + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Fatalf("Error string mismatch: got - want +\n%s", diff) + } + }) + } +} + +func TestRequestIDHeader_SingleUseReadOnly_ExecuteStreamingSql_UnavailableDuringStream(t *testing.T) { + t.Parallel() + ctx := context.Background() + interceptorTracker := newInterceptorTracker() + clientOpts := []option.ClientOption{ + option.WithGRPCDialOption(grpc.WithUnaryInterceptor(interceptorTracker.unaryClientInterceptor)), + option.WithGRPCDialOption(grpc.WithStreamInterceptor(interceptorTracker.streamClientInterceptor)), + } + clientConfig := ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 2, + MaxOpened: 10, + WriteSessions: 0.2, + incStep: 2, + }, + } + server, sc, tearDown := setupMockedTestServerWithConfigAndClientOptions(t, clientConfig, clientOpts) + t.Cleanup(tearDown) + defer sc.Close() + // A stream of PartialResultSets can break halfway and be retried from that point. 
+ server.TestSpanner.AddPartialResultSetError( + testutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums, + testutil.PartialResultSetExecutionTime{ + ResumeToken: testutil.EncodeResumeToken(2), + Err: status.Errorf(codes.Internal, "stream terminated by RST_STREAM"), + }, + ) + server.TestSpanner.AddPartialResultSetError( + testutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums, + testutil.PartialResultSetExecutionTime{ + ResumeToken: testutil.EncodeResumeToken(3), + Err: status.Errorf(codes.Unavailable, "server is unavailable"), + }, + ) + iter := sc.Single().Query(ctx, Statement{SQL: testutil.SelectSingerIDAlbumIDAlbumTitleFromAlbums}) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatal(err) + } + } + if _, err := shouldHaveReceived(server.TestSpanner, []any{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.ExecuteSqlRequest{}, + }); err != nil { + t.Fatal(err) + } + if g, w := interceptorTracker.unaryCallCount(), uint64(1); g != w { + t.Errorf("unaryClientCall is incorrect; got=%d want=%d", g, w) + } + if g, w := interceptorTracker.streamCallCount(), uint64(3); g != w { + t.Errorf("streamClientCall is incorrect; got=%d want=%d", g, w) + } + if err := interceptorTracker.validateRequestIDsMonotonicity(); err != nil { + t.Fatal(err) + } + clientID := uint32(sc.sc.nthClient) + procID := randIDForProcess + version := xSpannerRequestIDVersion + wantUnarySegments := []*requestIDSegments{ + {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 1, RPCNo: 1}, // BatchCreateSession + } + wantStreamingSegments := []*requestIDSegments{ + {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 1}, // ExecuteStreamingSql (initial attempt) + {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 2}, // ExecuteStreamingSql (retry) + {Version: version, ProcessID: procID, ClientID: clientID, ChannelID: 1, RequestNo: 2, RPCNo: 3}, // ExecuteStreamingSql (retry) + } + if diff := cmp.Diff(interceptorTracker.unaryClientRequestIDSegments, wantUnarySegments); diff != "" { + t.Fatalf("RequestID unary segments mismatch: got - want +\n%s", diff) + } + if diff := cmp.Diff(interceptorTracker.streamClientRequestIDSegments, wantStreamingSegments); diff != "" { + t.Fatalf("RequestID streaming segments mismatch: got - want +\n%s", diff) + } +} diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go index 7468f21bc722..42f16f6cfa54 100644 --- a/spanner/sessionclient.go +++ b/spanner/sessionclient.go @@ -22,6 +22,7 @@ import ( "log" "reflect" "sync" + "sync/atomic" "time" "cloud.google.com/go/internal/trace" @@ -48,7 +49,7 @@ func newClientIDGenerator() *clientIDGenerator { return &clientIDGenerator{ids: make(map[string]int)} } -func (cg *clientIDGenerator) nextID(database string) string { +func (cg *clientIDGenerator) nextClientIDAndOrdinal(database string) (clientID string, nthClient int) { cg.mu.Lock() defer cg.mu.Unlock() var id int @@ -58,7 +59,12 @@ func (cg *clientIDGenerator) nextID(database string) string { id = 1 } cg.ids[database] = id - return fmt.Sprintf("client-%d", id) + return fmt.Sprintf("client-%d", id), id +} + +func (cg *clientIDGenerator) nextID(database string) string { + clientStrID, _ := cg.nextClientIDAndOrdinal(database) + return clientStrID } // sessionConsumer is passed to the batchCreateSessions method and will receive @@ -101,15 +107,22 @@ type sessionClient struct { 
callOptions *vkit.CallOptions otConfig *openTelemetryConfig metricsTracerFactory *builtinMetricsTracerFactory + channelIDMap map[*grpc.ClientConn]uint64 + + // These fields are for request-id propagation. + nthClient int + // nthRequest shall always be incremented on every fresh request. + nthRequest *atomic.Uint32 } // newSessionClient creates a session client to use for a database. func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, sessionLabels map[string]string, databaseRole string, disableRouteToLeader bool, md metadata.MD, batchTimeout time.Duration, logger *log.Logger, callOptions *vkit.CallOptions) *sessionClient { + clientID, nthClient := cidGen.nextClientIDAndOrdinal(database) return &sessionClient{ connPool: connPool, database: database, userAgent: userAgent, - id: cidGen.nextID(database), + id: clientID, sessionLabels: sessionLabels, databaseRole: databaseRole, disableRouteToLeader: disableRouteToLeader, @@ -117,6 +130,9 @@ func newSessionClient(connPool gtransport.ConnPool, database, userAgent string, batchTimeout: batchTimeout, logger: logger, callOptions: callOptions, + + nthClient: nthClient, + nthRequest: new(atomic.Uint32), } } @@ -396,14 +412,30 @@ func (sc *sessionClient) sessionWithID(id string) (*session, error) { // optimal usage of server side caches. func (sc *sessionClient) nextClient() (spannerClient, error) { var clientOpt option.ClientOption + var channelID uint64 if _, ok := sc.connPool.(*gmeWrapper); ok { // Pass GCPMultiEndpoint as a pool. clientOpt = gtransport.WithConnPool(sc.connPool) } else { // Pick a grpc.ClientConn from a regular pool. - clientOpt = option.WithGRPCConn(sc.connPool.Conn()) + conn := sc.connPool.Conn() + + // Retrieve the channelID for each spannerClient. + // It is assumed that this method is invoked + // under a lock already. + var ok bool + channelID, ok = sc.channelIDMap[conn] + if !ok { + if sc.channelIDMap == nil { + sc.channelIDMap = make(map[*grpc.ClientConn]uint64) + } + channelID = uint64(len(sc.channelIDMap)) + 1 + sc.channelIDMap[conn] = channelID + } + + clientOpt = option.WithGRPCConn(conn) } - client, err := newGRPCSpannerClient(context.Background(), sc, clientOpt) + client, err := newGRPCSpannerClient(context.Background(), sc, channelID, clientOpt) if err != nil { return nil, err } diff --git a/spanner/transaction.go b/spanner/transaction.go index a33d03628685..f70b29050bb9 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -312,7 +312,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.session.logger, t.sp.sc.metricsTracerFactory, - func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { if t.sh != nil { t.sh.updateLastUseTime() } @@ -331,7 +331,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key DirectedReadOptions: directedReadOptions, OrderBy: orderBy, LockHint: lockHint, - }) + }, opts...) 
if err != nil { if _, ok := t.getTransactionSelector().GetSelector().(*sppb.TransactionSelector_Begin); ok { t.setTransactionID(nil) @@ -357,6 +357,7 @@ func (t *txReadOnly) ReadWithOptions(ctx context.Context, table string, keys Key }, t.setTimestamp, t.release, + client.(*grpcSpannerClient), ) } @@ -612,13 +613,13 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.session.logger, t.sp.sc.metricsTracerFactory, - func(ctx context.Context, resumeToken []byte) (streamingReceiver, error) { + func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { req.ResumeToken = resumeToken req.Session = t.sh.getID() req.Transaction = t.getTransactionSelector() t.sh.updateLastUseTime() - client, err := client.ExecuteStreamingSql(ctx, req) + client, err := client.ExecuteStreamingSql(ctx, req, opts...) if err != nil { if _, ok := req.Transaction.GetSelector().(*sppb.TransactionSelector_Begin); ok { t.setTransactionID(nil) @@ -643,7 +644,8 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que return t.updateTxState(err) }, t.setTimestamp, - t.release) + t.release, + client.(*grpcSpannerClient)) } func (t *txReadOnly) prepareExecuteSQL(ctx context.Context, stmt Statement, options QueryOptions) (*sppb.ExecuteSqlRequest, *sessionHandle, error) { @@ -1370,7 +1372,7 @@ func (t *ReadWriteTransaction) batchUpdateWithOptions(ctx context.Context, stmts return counts, errInlineBeginTransactionFailed() } if resp.Status != nil && resp.Status.Code != 0 { - return counts, t.txReadOnly.updateTxState(spannerErrorf(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)) + return counts, t.txReadOnly.updateTxState(spannerError(codes.Code(uint32(resp.Status.Code)), resp.Status.Message)) } return counts, nil }
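A minimal, illustrative sketch of the six request-id segments asserted throughout the tests above (Version, ProcessID, ClientID, ChannelID, RequestNo, RPCNo), where RequestNo increases per logical request on a client and RPCNo increases per retry attempt of that request. It assumes a dot-separated rendering of the segments; the names requestID and parseRequestID are hypothetical helpers for illustration, not the client library's own identifiers.

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// requestID mirrors the segment fields asserted via requestIDSegments in the tests.
type requestID struct {
	Version   uint32
	ProcessID string
	ClientID  uint32
	ChannelID uint32
	RequestNo uint32
	RPCNo     uint32
}

// String renders the value in an assumed dot-separated layout, e.g. "1.procid.1.1.3.2".
func (r requestID) String() string {
	return fmt.Sprintf("%d.%s.%d.%d.%d.%d", r.Version, r.ProcessID, r.ClientID, r.ChannelID, r.RequestNo, r.RPCNo)
}

// parseRequestID is the inverse, of the kind a test interceptor could use to validate headers.
func parseRequestID(s string) (requestID, error) {
	parts := strings.Split(s, ".")
	if len(parts) != 6 {
		return requestID{}, fmt.Errorf("expected 6 dot-separated segments, got %d", len(parts))
	}
	nums := make([]uint32, 6)
	for i, p := range parts {
		if i == 1 {
			continue // ProcessID is treated as an opaque string.
		}
		n, err := strconv.ParseUint(p, 10, 32)
		if err != nil {
			return requestID{}, fmt.Errorf("segment %d (%q) is not numeric: %v", i, p, err)
		}
		nums[i] = uint32(n)
	}
	return requestID{
		Version:   nums[0],
		ProcessID: parts[1],
		ClientID:  nums[2],
		ChannelID: nums[3],
		RequestNo: nums[4],
		RPCNo:     nums[5],
	}, nil
}

func main() {
	id := requestID{Version: 1, ProcessID: "procid", ClientID: 1, ChannelID: 1, RequestNo: 3, RPCNo: 2}
	fmt.Println(id) // 1.procid.1.1.3.2

	parsed, err := parseRequestID(id.String())
	fmt.Println(parsed, err)
}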