Skip to content

Commit

Permalink
fix: define common network dialer, http transport and client for use …
Browse files Browse the repository at this point in the history
…in proxy requests
  • Loading branch information
anfragment committed Feb 22, 2024
1 parent c286d2e commit eacaf25
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 23 deletions.
60 changes: 40 additions & 20 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ type filter interface {

// Proxy is a forward HTTP/HTTPS proxy that can filter requests.
type Proxy struct {
port int
filter filter
certGenerator certGenerator
server *http.Server
ignoredHosts []string
ignoredHostsMu sync.RWMutex
filter filter
certGenerator certGenerator
port int
server *http.Server
requestTransport http.RoundTripper
requestClient *http.Client
netDialer *net.Dialer
ignoredHosts []string
ignoredHostsMu sync.RWMutex
}

func NewProxy(filter filter, certGenerator certGenerator, port int) (*Proxy, error) {
Expand All @@ -43,11 +46,31 @@ func NewProxy(filter filter, certGenerator certGenerator, port int) (*Proxy, err
return nil, errors.New("certGenerator is nil")
}

return &Proxy{
p := &Proxy{
filter: filter,
certGenerator: certGenerator,
port: port,
}, nil
}

p.netDialer = &net.Dialer{
// Such high values are set to avoid timeouts on slow connections.
Timeout: 60 * time.Second,
KeepAlive: 30 * time.Second,
}
p.requestTransport = &http.Transport{
Dial: p.netDialer.Dial,
TLSHandshakeTimeout: 20 * time.Second,
}
p.requestClient = &http.Client{
Timeout: 60 * time.Second,
Transport: p.requestTransport,
// Let the client handle any redirects.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}

return p, nil
}

// Start starts the proxy on the given address.
Expand Down Expand Up @@ -81,10 +104,14 @@ func (p *Proxy) Start() error {
func (p *Proxy) initExclusionList() {
var wg sync.WaitGroup
wg.Add(len(exclusionListURLs))
client := &http.Client{
Timeout: 20 * time.Second,
}

for _, url := range exclusionListURLs {
go func(url string) {
defer wg.Done()
resp, err := http.Get(url)
resp, err := client.Get(url)
if err != nil {
log.Printf("failed to get exclusion list: %v", err)
return
Expand Down Expand Up @@ -154,22 +181,15 @@ func (p *Proxy) proxyHTTP(w http.ResponseWriter, r *http.Request) {
return
}

client := &http.Client{
// let the client handle any redirects
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}

r.RequestURI = ""

removeConnectionHeaders(r.Header)
removeHopHeaders(r.Header)

resp, err := client.Do(r)
resp, err := p.requestClient.Do(r)
if err != nil {
log.Printf("client.Do: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
log.Printf("error making request: %v", err)
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
Expand Down Expand Up @@ -273,7 +293,7 @@ func (p *Proxy) proxyConnect(w http.ResponseWriter, r *http.Request) {
break
}

resp, err := http.DefaultTransport.RoundTrip(req)
resp, err := p.requestTransport.RoundTrip(req)
if err != nil {
if strings.Contains(err.Error(), "tls: ") {
log.Printf("adding %s to ignored hosts", host)
Expand Down
6 changes: 3 additions & 3 deletions proxy/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import (
"crypto/tls"
"io"
"log"
"net"
"net/http"
"strings"
)

func (p *Proxy) proxyWebsocketTLS(w http.ResponseWriter, req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn) {
targetConn, err := tls.Dial("tcp", req.URL.Host, tlsConfig)
dialer := &tls.Dialer{NetDialer: p.netDialer, Config: tlsConfig}
targetConn, err := dialer.Dial("tcp", req.URL.Host)
if err != nil {
clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
log.Printf("dialing websocket backend(%s): %v", req.URL.Host, err)
Expand All @@ -27,7 +27,7 @@ func (p *Proxy) proxyWebsocketTLS(w http.ResponseWriter, req *http.Request, tlsC
}

func (p *Proxy) proxyWebsocket(w http.ResponseWriter, req *http.Request) {
targetConn, err := net.Dial("tcp", req.URL.Host)
targetConn, err := p.netDialer.Dial("tcp", req.URL.Host)
if err != nil {
w.WriteHeader(http.StatusBadGateway)
log.Printf("dialing websocket backend(%s): %v", req.URL.Host, err)
Expand Down

0 comments on commit eacaf25

Please sign in to comment.