diff --git a/proxy/proxy.go b/proxy/proxy.go index 3339286a..1c9de3e3 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -27,12 +27,15 @@ type filter interface { // Proxy is a forward HTTP/HTTPS proxy that can filter requests. type Proxy struct { - port int - filter filter - certGenerator certGenerator - server *http.Server - ignoredHosts []string - ignoredHostsMu sync.RWMutex + filter filter + certGenerator certGenerator + port int + server *http.Server + requestTransport http.RoundTripper + requestClient *http.Client + netDialer *net.Dialer + ignoredHosts []string + ignoredHostsMu sync.RWMutex } func NewProxy(filter filter, certGenerator certGenerator, port int) (*Proxy, error) { @@ -43,11 +46,31 @@ func NewProxy(filter filter, certGenerator certGenerator, port int) (*Proxy, err return nil, errors.New("certGenerator is nil") } - return &Proxy{ + p := &Proxy{ filter: filter, certGenerator: certGenerator, port: port, - }, nil + } + + p.netDialer = &net.Dialer{ + // Such high values are set to avoid timeouts on slow connections. + Timeout: 60 * time.Second, + KeepAlive: 30 * time.Second, + } + p.requestTransport = &http.Transport{ + Dial: p.netDialer.Dial, + TLSHandshakeTimeout: 20 * time.Second, + } + p.requestClient = &http.Client{ + Timeout: 60 * time.Second, + Transport: p.requestTransport, + // Let the client handle any redirects. + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + return p, nil } // Start starts the proxy on the given address. @@ -81,10 +104,14 @@ func (p *Proxy) Start() error { func (p *Proxy) initExclusionList() { var wg sync.WaitGroup wg.Add(len(exclusionListURLs)) + client := &http.Client{ + Timeout: 20 * time.Second, + } + for _, url := range exclusionListURLs { go func(url string) { defer wg.Done() - resp, err := http.Get(url) + resp, err := client.Get(url) if err != nil { log.Printf("failed to get exclusion list: %v", err) return @@ -154,22 +181,15 @@ func (p *Proxy) proxyHTTP(w http.ResponseWriter, r *http.Request) { return } - client := &http.Client{ - // let the client handle any redirects - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - } - r.RequestURI = "" removeConnectionHeaders(r.Header) removeHopHeaders(r.Header) - resp, err := client.Do(r) + resp, err := p.requestClient.Do(r) if err != nil { - log.Printf("client.Do: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + log.Printf("error making request: %v", err) + http.Error(w, err.Error(), http.StatusBadGateway) return } defer resp.Body.Close() @@ -273,7 +293,7 @@ func (p *Proxy) proxyConnect(w http.ResponseWriter, r *http.Request) { break } - resp, err := http.DefaultTransport.RoundTrip(req) + resp, err := p.requestTransport.RoundTrip(req) if err != nil { if strings.Contains(err.Error(), "tls: ") { log.Printf("adding %s to ignored hosts", host) diff --git a/proxy/websocket.go b/proxy/websocket.go index 4324bb70..40e683f8 100644 --- a/proxy/websocket.go +++ b/proxy/websocket.go @@ -5,13 +5,13 @@ import ( "crypto/tls" "io" "log" - "net" "net/http" "strings" ) func (p *Proxy) proxyWebsocketTLS(w http.ResponseWriter, req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn) { - targetConn, err := tls.Dial("tcp", req.URL.Host, tlsConfig) + dialer := &tls.Dialer{NetDialer: p.netDialer, Config: tlsConfig} + targetConn, err := dialer.Dial("tcp", req.URL.Host) if err != nil { clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) log.Printf("dialing websocket backend(%s): %v", req.URL.Host, err) @@ -27,7 +27,7 @@ func (p *Proxy) proxyWebsocketTLS(w http.ResponseWriter, req *http.Request, tlsC } func (p *Proxy) proxyWebsocket(w http.ResponseWriter, req *http.Request) { - targetConn, err := net.Dial("tcp", req.URL.Host) + targetConn, err := p.netDialer.Dial("tcp", req.URL.Host) if err != nil { w.WriteHeader(http.StatusBadGateway) log.Printf("dialing websocket backend(%s): %v", req.URL.Host, err)