Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add armed Listener with TLS handshake timeout #637

Merged
merged 4 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bind/flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ func HTTPLogConfig(fs *pflag.FlagSet, cfg []NamedParam[httplog.Mode]) {
}

func TLSServerConfig(fs *pflag.FlagSet, cfg *forwarder.TLSServerConfig, namePrefix string) {
fs.DurationVar(&cfg.HandshakeTimeout,
namePrefix+"tls-handshake-timeout", cfg.HandshakeTimeout,
"The maximum amount of time to wait for a TLS handshake before closing connection. Zero means no limit.")

fs.Var(anyflag.NewValueWithRedact[string](cfg.CertFile, &cfg.CertFile, func(val string) (string, error) { return val, nil }, RedactBase64),
namePrefix+"tls-cert-file", "<path or base64>"+
"TLS certificate to use if the server protocol is https or h2. "+
Expand Down
37 changes: 20 additions & 17 deletions http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/middleware"
"github.com/saucelabs/forwarder/pac"
"github.com/saucelabs/forwarder/ratelimit"
"github.com/saucelabs/forwarder/ruleset"
)

Expand Down Expand Up @@ -109,6 +108,9 @@ func DefaultHTTPProxyConfig() *HTTPProxyConfig {
Protocol: HTTPScheme,
Addr: ":3128",
ReadHeaderTimeout: 1 * time.Minute,
TLSServerConfig: TLSServerConfig{
HandshakeTimeout: 10 * time.Second,
},
},
Name: "forwarder",
ProxyLocalhost: DenyProxyLocalhost,
Expand Down Expand Up @@ -249,12 +251,14 @@ func (hp *HTTPProxy) configureProxy() error {
return hp.config.MITMDomains.Match(req.URL.Hostname())
}
}
hp.proxy.MITMTLSHandshakeTimeout = hp.config.TLSServerConfig.HandshakeTimeout
}

hp.proxy.AllowHTTP = true
hp.proxy.RequestIDHeader = hp.config.RequestIDHeader
hp.proxy.ConnectRequestModifier = hp.config.ConnectRequestModifier
hp.proxy.ConnectFunc = hp.config.ConnectFunc
hp.proxy.ConnectTimeout = 60 * time.Second
hp.proxy.WithoutWarning = true
hp.proxy.ErrorResponse = hp.errorResponse
hp.proxy.IdleTimeout = hp.config.IdleTimeout
Expand Down Expand Up @@ -599,27 +603,26 @@ func (hp *HTTPProxy) Run(ctx context.Context) error {
}

func (hp *HTTPProxy) listen() (net.Listener, error) {
listener, err := Listen("tcp", hp.config.Addr)
if err != nil {
return nil, fmt.Errorf("failed to open listener on address %s: %w", hp.config.Addr, err)
switch hp.config.Protocol {
case HTTPScheme, HTTPSScheme, HTTP2Scheme:
default:
return nil, fmt.Errorf("invalid protocol %q", hp.config.Protocol)
}

if rl, wl := int64(hp.config.ReadLimit), int64(hp.config.WriteLimit); rl > 0 || wl > 0 {
// Notice that the ReadLimit stands for the read limit *from* a proxy, and the WriteLimit
// stands for the write limit *to* a proxy, thus the ReadLimit is in fact
// a txBandwidth and the WriteLimit is a rxBandwidth.
listener = ratelimit.NewListener(listener, wl, rl)
l := Listener{
Address: hp.config.Addr,
Log: hp.log,
TLSConfig: hp.tlsConfig,
TLSHandshakeTimeout: hp.config.TLSServerConfig.HandshakeTimeout,
ReadLimit: int64(hp.config.ReadLimit),
WriteLimit: int64(hp.config.WriteLimit),
}

switch hp.config.Protocol {
case HTTPScheme:
return listener, nil
case HTTPSScheme, HTTP2Scheme:
return tls.NewListener(listener, hp.tlsConfig), nil
default:
listener.Close()
return nil, fmt.Errorf("invalid protocol %q", hp.config.Protocol)
if err := l.Listen(); err != nil {
return nil, err
}

return &l, nil
}

// Addr returns the address the server is listening on.
Expand Down
28 changes: 26 additions & 2 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,16 @@ type Proxy struct {
// Implementations can return ErrConnectFallback to indicate that the CONNECT request should be handled by martian.
ConnectFunc ConnectFunc

// ConnectTimeout specifies the maximum amount of time to connect to upstream before cancelling request.
ConnectTimeout time.Duration

// MITMFilter specifies a function to determine whether a CONNECT request should be MITMed.
MITMFilter func(*http.Request) bool

// MITMTLSHandshakeTimeout specifies the maximum amount of time to wait for a TLS handshake for a MITMed connection.
// Zero means no timeout.
MITMTLSHandshakeTimeout time.Duration

// WithoutWarning disables the warning header added to requests and responses when modifier errors occur.
WithoutWarning bool

Expand Down Expand Up @@ -512,7 +519,15 @@ func (p *Proxy) handleMITM(ctx *Context, req *http.Request, session *Session, br
io.MultiReader(bytes.NewReader(buf), conn),
}, p.mitm.TLSForHost(req.Host))

if err := tlsconn.Handshake(); err != nil {
var hctx context.Context
if p.MITMTLSHandshakeTimeout > 0 {
var hcancel context.CancelFunc
hctx, hcancel = context.WithTimeout(req.Context(), p.MITMTLSHandshakeTimeout)
defer hcancel()
} else {
hctx = req.Context()
}
if err = tlsconn.HandshakeContext(hctx); err != nil {
p.mitm.HandshakeErrorCallback(req, err)
if isClosedConnError(err) {
log.Debugf(req.Context(), "mitm: connection closed prematurely: %v", err)
Expand Down Expand Up @@ -953,7 +968,16 @@ func (p *Proxy) connectHTTP(req *http.Request, proxyURL *url.URL) (res *http.Res
d = dialvia.HTTPProxy(p.dial, proxyURL)
}
d.ConnectRequestModifier = p.ConnectRequestModifier
res, conn, err = d.DialContextR(req.Context(), "tcp", req.URL.Host)

var ctx context.Context
if p.ConnectTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(req.Context(), p.ConnectTimeout)
defer cancel()
} else {
ctx = req.Context()
}
res, conn, err = d.DialContextR(ctx, "tcp", req.URL.Host)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have ConnectTimeout field.

Let's open new issue to add context to socks.


if res != nil {
if res.StatusCode/100 == 2 {
Expand Down
70 changes: 70 additions & 0 deletions internal/martian/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,76 @@ func TestHTTPThroughConnectWithMITM(t *testing.T) {
}
}

func TestTLSHandshakeTimeoutWithMITM(t *testing.T) {
t.Parallel()

l := newListener(t)
p := NewProxy()
p.MITMTLSHandshakeTimeout = 200 * time.Millisecond
defer p.Close()

tm := martiantest.NewModifier()
tm.RequestFunc(func(req *http.Request) {
ctx := NewContext(req)
ctx.SkipRoundTrip()

if req.Method != http.MethodGet && req.Method != http.MethodConnect {
t.Errorf("unexpected method on request handler: %v", req.Method)
}
})
p.SetRequestModifier(tm)

ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
if err != nil {
t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
}

mc, err := mitm.NewConfig(ca, priv)
if err != nil {
t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
}
p.SetMITM(mc)

go serve(p, l)

conn, err := l.dial()
if err != nil {
t.Fatalf("net.Dial(): got %v, want no error", err)
}
defer conn.Close()

req, err := http.NewRequest(http.MethodConnect, "//example.com:80", http.NoBody)
if err != nil {
t.Fatalf("http.NewRequest(): got %v, want no error", err)
}

// CONNECT example.com:80 HTTP/1.1
// Host: example.com
if err := req.Write(conn); err != nil {
t.Fatalf("req.Write(): got %v, want no error", err)
}

// Response skipped round trip.
res, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
t.Fatalf("http.ReadResponse(): got %v, want no error", err)
}
res.Body.Close()

if got, want := res.StatusCode, 200; got != want {
t.Errorf("res.StatusCode: got %d, want %d", got, want)
}

if _, err := conn.Write([]byte{22}); err != nil {
t.Fatalf("conn.Write(): got %v, want no error", err)
}

time.Sleep(300 * time.Millisecond)
if _, err := conn.Read(make([]byte, 1)); !isClosedConnError(err) {
t.Fatalf("conn.Read(): got %v, want ClosedConnError", err)
}
}

func TestServerClosesConnection(t *testing.T) {
t.Parallel()

Expand Down
97 changes: 97 additions & 0 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@ package forwarder

import (
"context"
"crypto/tls"
"net"
"syscall"
"time"

"github.com/saucelabs/forwarder/log"
"github.com/saucelabs/forwarder/ratelimit"
)

type DialConfig struct {
Expand Down Expand Up @@ -78,3 +82,96 @@ func Listen(network, address string) (net.Listener, error) {
// I asked about it here: https://groups.google.com/g/golang-nuts/c/Q1I7Viz9AJc
return defaultListenConfig().Listen(context.Background(), network, address)
}

type ListenerCallbacks interface {
// OnAccept is called when a new connection is successfully accepted.
OnAccept(net.Conn)

// OnTLSHandshakeError is called after a TLS handshake errors out.
OnTLSHandshakeError(*tls.Conn, error)
}

type Listener struct {
Address string
Log log.Logger
TLSConfig *tls.Config
TLSHandshakeTimeout time.Duration
ReadLimit int64
WriteLimit int64
Callbacks ListenerCallbacks

listener net.Listener
}

func (l *Listener) Listen() error {
ll, err := Listen("tcp", l.Address)
if err != nil {
return err
}

if rl, wl := l.ReadLimit, l.WriteLimit; rl > 0 || wl > 0 {
// Notice that the ReadLimit stands for the read limit *from* a proxy, and the WriteLimit
// stands for the write limit *to* a proxy, thus the ReadLimit is in fact
// a txBandwidth and the WriteLimit is a rxBandwidth.
ll = ratelimit.NewListener(ll, wl, rl)
}

l.listener = ll
return nil
}

func (l *Listener) Accept() (net.Conn, error) {
for {
c, err := l.listener.Accept()
if err != nil {
return nil, err
}

if l.Callbacks != nil {
l.Callbacks.OnAccept(c)
}

if l.TLSConfig == nil {
return c, nil
}

tc, err := l.withTLS(c)
if err != nil {
l.Log.Errorf("Failed to perform TLS handshake: %v", err)
if cerr := tc.Close(); cerr != nil {
l.Log.Errorf("Failed to close TLS connection: %v", cerr)
}
continue
}

return tc, nil
}
}

func (l *Listener) withTLS(conn net.Conn) (*tls.Conn, error) {
tconn := tls.Server(conn, l.TLSConfig)

var err error
if l.TLSHandshakeTimeout <= 0 {
err = tconn.Handshake()
} else {
ctx, cancel := context.WithTimeout(context.Background(), l.TLSHandshakeTimeout)
err = tconn.HandshakeContext(ctx)
cancel()
}
if err != nil {
if l.Callbacks != nil {
l.Callbacks.OnTLSHandshakeError(tconn, err)
}
}

return tconn, err
}

func (l *Listener) Addr() net.Addr {
return l.listener.Addr()
}

func (l *Listener) Close() error {
return l.listener.Close()
}
72 changes: 72 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2023 Sauce Labs Inc., all rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package forwarder

import (
"context"
"crypto/tls"
"errors"
"net"
"testing"
"time"

"github.com/saucelabs/forwarder/log"
)

func TestListenerTLSHandshakeTimeout(t *testing.T) {
tlsCfg := new(tls.Config)
if err := (&TLSServerConfig{HandshakeTimeout: 100 * time.Millisecond}).ConfigureTLSConfig(tlsCfg); err != nil {
t.Fatal(err)
}

done := make(chan struct{})

l := Listener{
Address: "localhost:0",
Log: log.NopLogger,
TLSConfig: tlsCfg,
TLSHandshakeTimeout: 100 * time.Millisecond,
Callbacks: &mockListenerCallback{
t: t,
done: done,
},
}

err := l.Listen()
if err != nil {
t.Fatal(err)
}
defer l.Close()

go func() {
// Accept won't return.
_, _ = l.Accept()
}()

conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("net.Dial(): got %v, want no error", err)
}
defer conn.Close()

<-done
}

type mockListenerCallback struct {
t *testing.T
done chan struct{}
}

func (m *mockListenerCallback) OnAccept(_ net.Conn) {
}

func (m *mockListenerCallback) OnTLSHandshakeError(_ *tls.Conn, err error) {
if !errors.Is(err, context.DeadlineExceeded) {
m.t.Errorf("tl.OnTLSHandshakeError(): got %v, want %v", err, context.DeadlineExceeded)
}
m.done <- struct{}{}
}
Loading