diff --git a/connect_tls_117.go b/connect_tls_117.go new file mode 100644 index 00000000..3b73db74 --- /dev/null +++ b/connect_tls_117.go @@ -0,0 +1,44 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.17 +// +build go1.17 + +package alloydbconn + +import ( + "context" + "crypto/tls" + "net" + + "cloud.google.com/go/alloydbconn/errtype" + "cloud.google.com/go/alloydbconn/internal/alloydb" +) + +// connectTLS returns a new TLS client side connection +// using conn as the underlying transport. +// +// The returned connection has already completed its TLS handshake. +func connectTLS(ctx context.Context, conn net.Conn, c *tls.Config, i *alloydb.Instance) (net.Conn, error) { + tlsConn := tls.Client(conn, c) + // HandshakeContext was introduced in Go 1.17, hence + // this file is conditionally compiled on only Go versions >= 1.17. + if err := tlsConn.HandshakeContext(ctx); err != nil { + // refresh the instance info in case it caused the handshake failure + i.ForceRefresh() + _ = tlsConn.Close() // best effort close attempt + return nil, errtype.NewDialError("handshake failed", i.String(), err) + } + return tlsConn, nil +} diff --git a/connect_tls_other.go b/connect_tls_other.go new file mode 100644 index 00000000..a9625589 --- /dev/null +++ b/connect_tls_other.go @@ -0,0 +1,42 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !go1.17 +// +build !go1.17 + +package alloydbconn + +import ( + "context" + "crypto/tls" + "net" + + "cloud.google.com/go/alloydbconn/errtype" + "cloud.google.com/go/alloydbconn/internal/alloydb" +) + +// connectTLS returns a new TLS client side connection +// using conn as the underlying transport. +// +// The returned connection has already completed its TLS handshake. +func connectTLS(_ context.Context, conn net.Conn, c *tls.Config, i *alloydb.Instance) (net.Conn, error) { + tlsConn := tls.Client(conn, c) + if err := tlsConn.Handshake(); err != nil { + // refresh the instance info in case it caused the handshake failure + i.ForceRefresh() + _ = tlsConn.Close() // best effort close attempt + return nil, errtype.NewDialError("handshake failed", i.String(), err) + } + return tlsConn, nil +} diff --git a/dialer.go b/dialer.go index 95160068..007418dd 100644 --- a/dialer.go +++ b/dialer.go @@ -18,7 +18,6 @@ import ( "context" "crypto/rand" "crypto/rsa" - "crypto/tls" _ "embed" "fmt" "net" @@ -194,12 +193,9 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption) return nil, errtype.NewDialError("failed to set keep-alive period", i.String(), err) } } - tlsConn := tls.Client(conn, tlsCfg) - if err := tlsConn.Handshake(); err != nil { - // refresh the instance info in case it caused the handshake failure - i.ForceRefresh() - _ = tlsConn.Close() // best effort close attempt - return nil, errtype.NewDialError("handshake failed", i.String(), err) + tlsConn, err := connectTLS(ctx, conn, tlsCfg, i) + if err != nil { + return nil, err } latency := time.Since(startTime).Milliseconds() go func() {