diff --git a/client/pkg/transport/transport.go b/client/pkg/transport/transport.go index 67170d7436d0..561cae8d253c 100644 --- a/client/pkg/transport/transport.go +++ b/client/pkg/transport/transport.go @@ -18,6 +18,8 @@ import ( "context" "net" "net/http" + "net/url" + "os" "strings" "time" ) @@ -39,7 +41,17 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er } t := &http.Transport{ - Proxy: http.ProxyFromEnvironment, + Proxy: func(req *http.Request) (*url.URL, error) { + // according to the comment of http.ProxyFromEnvironment: if the + // proxy URL is "localhost" (with or without a port number), + // then a nil URL and nil error will be returned. + // Thus, we need to work around this by manually setting an + // ENV named FORWARD_PROXY and parsing the URL (which is a localhost in our case) + if forwardProxy, exists := os.LookupEnv("FORWARD_PROXY"); exists { + return url.Parse(forwardProxy) + } + return http.ProxyFromEnvironment(req) + }, DialContext: (&net.Dialer{ Timeout: dialtimeoutd, LocalAddr: ipAddr, diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index 6d7931b4e33a..6f0737c27975 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -15,6 +15,8 @@ package proxy import ( + "bufio" + "bytes" "context" "fmt" "io" @@ -130,18 +132,21 @@ type Server interface { // ServerConfig defines proxy server configuration. 
type ServerConfig struct { - Logger *zap.Logger - From url.URL - To url.URL - TLSInfo transport.TLSInfo - DialTimeout time.Duration - BufferSize int - RetryInterval time.Duration + Logger *zap.Logger + From url.URL + To url.URL + TLSInfo transport.TLSInfo + DialTimeout time.Duration + BufferSize int + RetryInterval time.Duration + IsForwardProxy bool } type server struct { lg *zap.Logger + isForwardProxy bool + from url.URL fromPort int to url.URL @@ -194,6 +199,8 @@ func NewServer(cfg ServerConfig) Server { s := &server{ lg: cfg.Logger, + isForwardProxy: cfg.IsForwardProxy, + from: cfg.From, to: cfg.To, @@ -216,10 +223,12 @@ func NewServer(cfg ServerConfig) Server { if err == nil { s.fromPort, _ = strconv.Atoi(fromPort) } - var toPort string - _, toPort, err = net.SplitHostPort(cfg.To.Host) - if err == nil { - s.toPort, _ = strconv.Atoi(toPort) + if !s.isForwardProxy { + var toPort string + _, toPort, err = net.SplitHostPort(cfg.To.Host) + if err == nil { + s.toPort, _ = strconv.Atoi(toPort) + } } if s.dialTimeout == 0 { @@ -239,8 +248,10 @@ func NewServer(cfg ServerConfig) Server { if strings.HasPrefix(s.from.Scheme, "http") { s.from.Scheme = "tcp" } - if strings.HasPrefix(s.to.Scheme, "http") { - s.to.Scheme = "tcp" + if !s.isForwardProxy { + if strings.HasPrefix(s.to.Scheme, "http") { + s.to.Scheme = "tcp" + } } addr := fmt.Sprintf(":%d", s.fromPort) @@ -273,7 +284,10 @@ func (s *server) From() string { } func (s *server) To() string { - return fmt.Sprintf("%s://%s", s.to.Scheme, s.to.Host) + if !s.isForwardProxy { + return fmt.Sprintf("%s://%s", s.to.Scheme, s.to.Host) + } + return "" } // TODO: implement packet reordering from multiple TCP connections @@ -353,6 +367,44 @@ func (s *server) listenAndServe() { continue } + parseHeaderForDestination := func() *string { + // the first request should always contain a CONNECT header field + // since we set the transport to forward the traffic to the proxy + buf := make([]byte, s.bufferSize) + var data []byte + var 
nr1 int + if nr1, err = in.Read(buf); err != nil { + if err == io.EOF { + return nil + // why?? + // panic("No data available for forward proxy to work on") + } + panic(err) + } else { + data = buf[:nr1] + } + + // attempt to parse for the HOST from the CONNECT request + var req *http.Request + if req, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(data))); err != nil { + panic("Failed to parse header in forward proxy") + } + + if req.Method == http.MethodConnect { + // make sure a reply is sent back to the client + connectResponse := &http.Response{ + StatusCode: 200, + ProtoMajor: 1, + ProtoMinor: 1, + } + connectResponse.Write(in) + + return &req.URL.Host + } + + panic("Wrong header type to start the connection") + } + var out net.Conn if !s.tlsInfo.Empty() { var tp *http.Transport @@ -370,9 +422,25 @@ func (s *server) listenAndServe() { } continue } - out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host) + if s.isForwardProxy { + if dest := parseHeaderForDestination(); dest == nil { + continue + } else { + out, err = tp.DialContext(ctx, "tcp", *dest) + } + } else { + out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host) + } } else { - out, err = net.Dial(s.to.Scheme, s.to.Host) + if s.isForwardProxy { + if dest := parseHeaderForDestination(); dest == nil { + continue + } else { + out, err = net.Dial("tcp", *dest) + } + } else { + out, err = net.Dial(s.to.Scheme, s.to.Host) + } } if err != nil { select { diff --git a/tests/e2e/blackhole_test.go b/tests/e2e/blackhole_test.go index 68150cb8608a..ee3f41c14453 100644 --- a/tests/e2e/blackhole_test.go +++ b/tests/e2e/blackhole_test.go @@ -51,17 +51,20 @@ func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLea require.NoError(t, epc.Close(), "failed to close etcd cluster") }() - leaderId := epc.WaitLeader(t) - mockPartitionNodeIndex := leaderId + leaderID := epc.WaitLeader(t) + mockPartitionNodeIndex := leaderID if !partitionLeader { - mockPartitionNodeIndex = (leaderId + 1) % 
(clusterSize) + mockPartitionNodeIndex = (leaderID + 1) % (clusterSize) } partitionedMember := epc.Procs[mockPartitionNodeIndex] // Mock partition - proxy := partitionedMember.PeerProxy() + forwardProxy := partitionedMember.PeerForwardProxy() + reverseProxy := partitionedMember.PeerReverseProxy() t.Logf("Blackholing traffic from and to member %q", partitionedMember.Config().Name) - proxy.BlackholeTx() - proxy.BlackholeRx() + forwardProxy.BlackholeTx() + forwardProxy.BlackholeRx() + reverseProxy.BlackholeTx() + reverseProxy.BlackholeRx() t.Logf("Wait 5s for any open connections to expire") time.Sleep(5 * time.Second) @@ -79,8 +82,10 @@ func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLea // Wait for some time to restore the network time.Sleep(1 * time.Second) t.Logf("Unblackholing traffic from and to member %q", partitionedMember.Config().Name) - proxy.UnblackholeTx() - proxy.UnblackholeRx() + forwardProxy.UnblackholeTx() + forwardProxy.UnblackholeRx() + reverseProxy.UnblackholeTx() + reverseProxy.UnblackholeRx() leaderEPC = epc.Procs[epc.WaitLeader(t)] time.Sleep(5 * time.Second) diff --git a/tests/e2e/http_health_check_test.go b/tests/e2e/http_health_check_test.go index 8aa2694344f5..74593f0784de 100644 --- a/tests/e2e/http_health_check_test.go +++ b/tests/e2e/http_health_check_test.go @@ -384,10 +384,13 @@ func triggerSlowApply(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCl func blackhole(_ context.Context, t *testing.T, clus *e2e.EtcdProcessCluster, _ time.Duration) { member := clus.Procs[0] - proxy := member.PeerProxy() + forwardProxy := member.PeerForwardProxy() + reverseProxy := member.PeerReverseProxy() t.Logf("Blackholing traffic from and to member %q", member.Config().Name) - proxy.BlackholeTx() - proxy.BlackholeRx() + forwardProxy.BlackholeTx() + forwardProxy.BlackholeRx() + reverseProxy.BlackholeTx() + reverseProxy.BlackholeRx() } func triggerRaftLoopDeadLock(ctx context.Context, t *testing.T, clus 
*e2e.EtcdProcessCluster, duration time.Duration) { diff --git a/tests/framework/e2e/cluster.go b/tests/framework/e2e/cluster.go index cb8b35d7fd85..a18335b40369 100644 --- a/tests/framework/e2e/cluster.go +++ b/tests/framework/e2e/cluster.go @@ -502,12 +502,13 @@ func (cfg *EtcdProcessClusterConfig) SetInitialOrDiscovery(serverCfg *EtcdServer func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i int) *EtcdServerProcessConfig { var curls []string var curl string - port := cfg.BasePort + 5*i + port := cfg.BasePort + 6*i clientPort := port - peerPort := port + 1 + peerPort := port + 1 // the port that the peer actually listens on metricsPort := port + 2 - peer2Port := port + 3 + reverseProxyPort := port + 3 // the port that the peer advertises clientHTTPPort := port + 4 + forwardProxyPort := port + 5 if cfg.Client.ConnectionType == ClientTLSAndNonTLS { curl = clientURL(cfg.ClientScheme(), clientPort, ClientNonTLS) @@ -519,17 +520,33 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in peerListenURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} peerAdvertiseURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", peerPort)} - var proxyCfg *proxy.ServerConfig + var forwardProxyCfg *proxy.ServerConfig + var reverseProxyCfg *proxy.ServerConfig if cfg.PeerProxy { if !cfg.IsPeerTLS { panic("Can't use peer proxy without peer TLS as it can result in malformed packets") } - peerAdvertiseURL.Host = fmt.Sprintf("localhost:%d", peer2Port) - proxyCfg = &proxy.ServerConfig{ + + // setup reverse proxy + peerAdvertiseURL.Host = fmt.Sprintf("localhost:%d", reverseProxyPort) + reverseProxyCfg = &proxy.ServerConfig{ Logger: zap.NewNop(), To: peerListenURL, From: peerAdvertiseURL, } + + // setup forward proxy + forwardProxyURL := url.URL{Scheme: cfg.PeerScheme(), Host: fmt.Sprintf("localhost:%d", forwardProxyPort)} + forwardProxyCfg = &proxy.ServerConfig{ + Logger: zap.NewNop(), + 
From: forwardProxyURL, + IsForwardProxy: true, + } + + if cfg.EnvVars == nil { + cfg.EnvVars = make(map[string]string) + } + cfg.EnvVars["FORWARD_PROXY"] = fmt.Sprintf("http://127.0.0.1:%d", forwardProxyPort) } name := fmt.Sprintf("%s-test-%d", testNameCleanRegex.ReplaceAllString(tb.Name(), ""), i) @@ -651,7 +668,8 @@ func (cfg *EtcdProcessClusterConfig) EtcdServerProcessConfig(tb testing.TB, i in InitialToken: cfg.ServerConfig.InitialClusterToken, GoFailPort: gofailPort, GoFailClientTimeout: cfg.GoFailClientTimeout, - Proxy: proxyCfg, + ReverseProxy: reverseProxyCfg, + ForwardProxy: forwardProxyCfg, LazyFSEnabled: cfg.LazyFSEnabled, } } diff --git a/tests/framework/e2e/etcd_process.go b/tests/framework/e2e/etcd_process.go index f46fa6d8661e..52bf4d4b178a 100644 --- a/tests/framework/e2e/etcd_process.go +++ b/tests/framework/e2e/etcd_process.go @@ -55,7 +55,8 @@ type EtcdProcess interface { Stop() error Close() error Config() *EtcdServerProcessConfig - PeerProxy() proxy.Server + PeerReverseProxy() proxy.Server + PeerForwardProxy() proxy.Server Failpoints() *BinaryFailpoints LazyFS() *LazyFS Logs() LogsExpect @@ -69,12 +70,13 @@ type LogsExpect interface { } type EtcdServerProcess struct { - cfg *EtcdServerProcessConfig - proc *expect.ExpectProcess - proxy proxy.Server - lazyfs *LazyFS - failpoints *BinaryFailpoints - donec chan struct{} // closed when Interact() terminates + cfg *EtcdServerProcessConfig + proc *expect.ExpectProcess + forwardProxy proxy.Server + reverseProxy proxy.Server + lazyfs *LazyFS + failpoints *BinaryFailpoints + donec chan struct{} // closed when Interact() terminates } type EtcdServerProcessConfig struct { @@ -101,7 +103,8 @@ type EtcdServerProcessConfig struct { GoFailClientTimeout time.Duration LazyFSEnabled bool - Proxy *proxy.ServerConfig + ReverseProxy *proxy.ServerConfig + ForwardProxy *proxy.ServerConfig } func NewEtcdServerProcess(t testing.TB, cfg *EtcdServerProcessConfig) (*EtcdServerProcess, error) { @@ -151,12 +154,26 @@ func 
(ep *EtcdServerProcess) Start(ctx context.Context) error { if ep.proc != nil { panic("already started") } - if ep.cfg.Proxy != nil && ep.proxy == nil { - ep.cfg.lg.Info("starting proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.Proxy.From.String()), zap.String("to", ep.cfg.Proxy.To.String())) - ep.proxy = proxy.NewServer(*ep.cfg.Proxy) + + if !((ep.cfg.ReverseProxy != nil && ep.cfg.ForwardProxy != nil) || (ep.cfg.ReverseProxy == nil && ep.cfg.ForwardProxy == nil)) { + panic("both forward and reverse proxy configuration files must exist or not exist at the same time") + } + + if ep.cfg.ReverseProxy != nil && ep.reverseProxy == nil { + ep.cfg.lg.Info("starting reverse proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.ReverseProxy.From.String()), zap.String("to", ep.cfg.ReverseProxy.To.String())) + ep.reverseProxy = proxy.NewServer(*ep.cfg.ReverseProxy) + select { + case <-ep.reverseProxy.Ready(): + case err := <-ep.reverseProxy.Error(): + return err + } + } + if ep.cfg.ForwardProxy != nil && ep.forwardProxy == nil { + ep.cfg.lg.Info("starting forward proxy...", zap.String("name", ep.cfg.Name), zap.String("from", ep.cfg.ForwardProxy.From.String()), zap.String("to", ep.cfg.ForwardProxy.To.String())) + ep.forwardProxy = proxy.NewServer(*ep.cfg.ForwardProxy) select { - case <-ep.proxy.Ready(): - case err := <-ep.proxy.Error(): + case <-ep.forwardProxy.Ready(): + case err := <-ep.forwardProxy.Error(): return err } } @@ -221,10 +238,18 @@ func (ep *EtcdServerProcess) Stop() (err error) { } } ep.cfg.lg.Info("stopped server.", zap.String("name", ep.cfg.Name)) - if ep.proxy != nil { - ep.cfg.lg.Info("stopping proxy...", zap.String("name", ep.cfg.Name)) - err = ep.proxy.Close() - ep.proxy = nil + if ep.forwardProxy != nil { + ep.cfg.lg.Info("stopping forward proxy...", zap.String("name", ep.cfg.Name)) + err = ep.forwardProxy.Close() + ep.forwardProxy = nil + if err != nil { + return err + } + } + if ep.reverseProxy != nil { + 
ep.cfg.lg.Info("stopping reverse proxy...", zap.String("name", ep.cfg.Name)) + err = ep.reverseProxy.Close() + ep.reverseProxy = nil if err != nil { return err } @@ -326,8 +351,12 @@ func AssertProcessLogs(t *testing.T, ep EtcdProcess, expectLog string) { } } -func (ep *EtcdServerProcess) PeerProxy() proxy.Server { - return ep.proxy +func (ep *EtcdServerProcess) PeerReverseProxy() proxy.Server { + return ep.reverseProxy +} + +func (ep *EtcdServerProcess) PeerForwardProxy() proxy.Server { + return ep.forwardProxy } func (ep *EtcdServerProcess) LazyFS() *LazyFS { diff --git a/tests/robustness/failpoint/network.go b/tests/robustness/failpoint/network.go index b355b5182bc4..589dead30437 100644 --- a/tests/robustness/failpoint/network.go +++ b/tests/robustness/failpoint/network.go @@ -62,23 +62,26 @@ func (tb triggerBlackhole) Available(config e2e.EtcdProcessClusterConfig, proces if tb.waitTillSnapshot && (entriesToGuaranteeSnapshot(config) > 200 || !e2e.CouldSetSnapshotCatchupEntries(process.Config().ExecPath)) { return false } - return config.ClusterSize > 1 && process.PeerProxy() != nil + return config.ClusterSize > 1 && process.PeerForwardProxy() != nil && process.PeerReverseProxy() != nil } func Blackhole(ctx context.Context, t *testing.T, member e2e.EtcdProcess, clus *e2e.EtcdProcessCluster, shouldWaitTillSnapshot bool) error { - proxy := member.PeerProxy() + reverseProxy := member.PeerReverseProxy() + forwardProxy := member.PeerForwardProxy() - // Blackholing will cause peers to not be able to use streamWriters registered with member - // but peer traffic is still possible because member has 'pipeline' with peers - // TODO: find a way to stop all traffic t.Logf("Blackholing traffic from and to member %q", member.Config().Name) - proxy.BlackholeTx() - proxy.BlackholeRx() + reverseProxy.BlackholeTx() + reverseProxy.BlackholeRx() + forwardProxy.BlackholeTx() + forwardProxy.BlackholeRx() defer func() { t.Logf("Traffic restored from and to member %q", 
member.Config().Name) - proxy.UnblackholeTx() - proxy.UnblackholeRx() + reverseProxy.UnblackholeTx() + reverseProxy.UnblackholeRx() + forwardProxy.UnblackholeTx() + forwardProxy.UnblackholeRx() }() + if shouldWaitTillSnapshot { return waitTillSnapshot(ctx, t, clus, member) } @@ -163,15 +166,20 @@ type delayPeerNetworkFailpoint struct { func (f delayPeerNetworkFailpoint) Inject(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { member := clus.Procs[rand.Int()%len(clus.Procs)] - proxy := member.PeerProxy() + reverseProxy := member.PeerReverseProxy() + forwardProxy := member.PeerForwardProxy() - proxy.DelayRx(f.baseLatency, f.randomizedLatency) - proxy.DelayTx(f.baseLatency, f.randomizedLatency) + reverseProxy.DelayRx(f.baseLatency, f.randomizedLatency) + reverseProxy.DelayTx(f.baseLatency, f.randomizedLatency) + forwardProxy.DelayRx(f.baseLatency, f.randomizedLatency) + forwardProxy.DelayTx(f.baseLatency, f.randomizedLatency) lg.Info("Delaying traffic from and to member", zap.String("member", member.Config().Name), zap.Duration("baseLatency", f.baseLatency), zap.Duration("randomizedLatency", f.randomizedLatency)) time.Sleep(f.duration) lg.Info("Traffic delay removed", zap.String("member", member.Config().Name)) - proxy.UndelayRx() - proxy.UndelayTx() + reverseProxy.UndelayRx() + reverseProxy.UndelayTx() + forwardProxy.UndelayRx() + forwardProxy.UndelayTx() return nil, nil } @@ -180,7 +188,7 @@ func (f delayPeerNetworkFailpoint) Name() string { } func (f delayPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess) bool { - return config.ClusterSize > 1 && clus.PeerProxy() != nil + return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil && clus.PeerReverseProxy() != nil } type dropPeerNetworkFailpoint struct { @@ -190,15 +198,20 @@ type dropPeerNetworkFailpoint struct { func (f dropPeerNetworkFailpoint) Inject(ctx 
context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, baseTime time.Time, ids identity.Provider) ([]report.ClientReport, error) { member := clus.Procs[rand.Int()%len(clus.Procs)] - proxy := member.PeerProxy() + reverseProxy := member.PeerReverseProxy() + forwardProxy := member.PeerForwardProxy() - proxy.ModifyRx(f.modifyPacket) - proxy.ModifyTx(f.modifyPacket) + reverseProxy.ModifyRx(f.modifyPacket) + reverseProxy.ModifyTx(f.modifyPacket) + forwardProxy.ModifyRx(f.modifyPacket) + forwardProxy.ModifyTx(f.modifyPacket) lg.Info("Dropping traffic from and to member", zap.String("member", member.Config().Name), zap.Int("probability", f.dropProbabilityPercent)) time.Sleep(f.duration) lg.Info("Traffic drop removed", zap.String("member", member.Config().Name)) - proxy.UnmodifyRx() - proxy.UnmodifyTx() + reverseProxy.UnmodifyRx() + reverseProxy.UnmodifyTx() + forwardProxy.UnmodifyRx() + forwardProxy.UnmodifyTx() return nil, nil } @@ -214,5 +227,5 @@ func (f dropPeerNetworkFailpoint) Name() string { } func (f dropPeerNetworkFailpoint) Available(config e2e.EtcdProcessClusterConfig, clus e2e.EtcdProcess) bool { - return config.ClusterSize > 1 && clus.PeerProxy() != nil + return config.ClusterSize > 1 && clus.PeerForwardProxy() != nil && clus.PeerReverseProxy() != nil }