From 862579ab8a91a4bf9ee487ba4c872fdf7a7bce07 Mon Sep 17 00:00:00 2001 From: AlexStocks Date: Thu, 23 Apr 2020 14:04:59 +0800 Subject: [PATCH] Add: listen on random local port --- go.mod | 2 +- go.sum | 2 ++ server.go | 53 +++++++++++++++++++++++++++---------------- server_test.go | 61 +++++++++++++++++++++----------------------------- session.go | 16 ++++++------- 5 files changed, 71 insertions(+), 63 deletions(-) diff --git a/go.mod b/go.mod index ffeecdba..be1c736e 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/dubbogo/getty require ( - github.com/dubbogo/gost v1.5.2 + github.com/dubbogo/gost v1.9.0 github.com/golang/snappy v0.0.1 github.com/gorilla/websocket v1.4.0 github.com/juju/errors v0.0.0-20190930114154-d42613fe1ab9 diff --git a/go.sum b/go.sum index d0e532e6..c6fe58ce 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dubbogo/gost v1.5.2 h1:ri/03971hdpnn3QeCU+4UZgnRNGDXLDGDucR/iozZm8= github.com/dubbogo/gost v1.5.2/go.mod h1:pPTjVyoJan3aPxBPNUX0ADkXjPibLo+/Ib0/fADXSG8= +github.com/dubbogo/gost v1.9.0 h1:UT+dWwvLyJiDotxJERO75jB3Yxgsdy10KztR5ycxRAk= +github.com/dubbogo/gost v1.9.0/go.mod h1:pPTjVyoJan3aPxBPNUX0ADkXjPibLo+/Ib0/fADXSG8= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q= diff --git a/server.go b/server.go index cb0238a4..866f9133 100644 --- a/server.go +++ b/server.go @@ -17,6 +17,7 @@ import ( "io/ioutil" "net" "net/http" + "strings" "sync" "sync/atomic" "time" @@ -67,9 +68,9 @@ func newServer(t EndPointType, opts ...ServerOption) *server { s.init(opts...) - if s.addr == "" { - panic(fmt.Sprintf("@addr:%s", s.addr)) - } + //if len(s.addr) == 0 { + // panic(fmt.Sprintf("@addr:%s", s.addr)) + //} return s } @@ -163,9 +164,16 @@ func (s *server) listenTCP() error { streamListener net.Listener ) - streamListener, err = net.Listen("tcp", s.addr) - if err != nil { - return perrors.Wrapf(err, "net.Listen(tcp, addr:%s))", s.addr) + if len(s.addr) == 0 || !strings.Contains(s.addr, ":") { + streamListener, err = gxnet.ListenOnTCPRandomPort(s.addr) + if err != nil { + return perrors.Wrapf(err, "gxnet.ListenOnTCPRandomPort(addr:%s)", s.addr) + } + } else { + streamListener, err = net.Listen("tcp", s.addr) + if err != nil { + return perrors.Wrapf(err, "net.Listen(tcp, addr:%s)", s.addr) + } } s.streamListener = streamListener @@ -180,13 +188,20 @@ func (s *server) listenUDP() error { pktListener *net.UDPConn ) - localAddr, err = net.ResolveUDPAddr("udp", s.addr) - if err != nil { - return perrors.Wrapf(err, "net.ResolveUDPAddr(udp, addr:%s)", s.addr) - } - pktListener, err = net.ListenUDP("udp", localAddr) - if err != nil { - return perrors.Wrapf(err, "net.ListenUDP((udp, localAddr:%#v)", localAddr) + if len(s.addr) == 0 || !strings.Contains(s.addr, ":") { + pktListener, err = gxnet.ListenOnUDPRandomPort(s.addr) + if err != nil { + return perrors.Wrapf(err, "gxnet.ListenOnUDPRandomPort(addr:%s)", s.addr) + } + } else { + localAddr, err = net.ResolveUDPAddr("udp", s.addr) + if err != nil { + return perrors.Wrapf(err, "net.ResolveUDPAddr(udp, addr:%s)", s.addr) + } + pktListener, err = net.ListenUDP("udp", localAddr) + if err != nil { + return perrors.Wrapf(err, "net.ListenUDP((udp, localAddr:%#v)", localAddr) + } } s.pktListener = pktListener @@ -256,7 +271,7 @@ func (s *server) runTcpEventLoop(newSession NewSessionCallback) { } continue } - log.Warnf("server{%s}.Accept() = err {%+v}", s.addr, err) + log.Warnf("server{%s}.Accept() = err {%+v}", s.addr, perrors.WithStack(err)) continue } delay = 0 @@ -357,7 +372,7 @@ func (s *server) runWSEventLoop(newSession NewSessionCallback) { s.lock.Unlock() err = server.Serve(s.streamListener) if err != nil { - log.Errorf("http.server.Serve(addr{%s}) = err{%+v}", s.addr, err) + log.Errorf("http.server.Serve(addr{%s}) = err{%+v}", s.addr, perrors.WithStack(err)) // panic(err) } }() @@ -381,7 +396,7 @@ func (s *server) runWSSEventLoop(newSession NewSessionCallback) { if certificate, err = tls.LoadX509KeyPair(s.cert, s.privateKey); err != nil { panic(fmt.Sprintf("tls.LoadX509KeyPair(cert{%s}, privateKey{%s}) = err{%+v}", - s.cert, s.privateKey, err)) + s.cert, s.privateKey, perrors.WithStack(err))) return } config = &tls.Config{ @@ -394,7 +409,7 @@ func (s *server) runWSSEventLoop(newSession NewSessionCallback) { if s.caCert != "" { certPem, err = ioutil.ReadFile(s.caCert) if err != nil { - panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err{%+v}", s.caCert, err)) + panic(fmt.Errorf("ioutil.ReadFile(certFile{%s}) = err{%+v}", s.caCert, perrors.WithStack(err))) } certPool = x509.NewCertPool() if ok := certPool.AppendCertsFromPEM(certPem); !ok { @@ -419,7 +434,7 @@ func (s *server) runWSSEventLoop(newSession NewSessionCallback) { s.lock.Unlock() err = server.Serve(tls.NewListener(s.streamListener, config)) if err != nil { - log.Errorf("http.server.Serve(addr{%s}) = err{%+v}", s.addr, err) + log.Errorf("http.server.Serve(addr{%s}) = err{%+v}", s.addr, perrors.WithStack(err)) panic(err) } }() @@ -429,7 +444,7 @@ func (s *server) runWSSEventLoop(newSession NewSessionCallback) { // @newSession: new connection callback func (s *server) RunEventLoop(newSession NewSessionCallback) { if err := s.listen(); err != nil { - panic(fmt.Errorf("server.listen() = error:%+v", err)) + panic(fmt.Errorf("server.listen() = error:%+v", perrors.WithStack(err))) } switch s.endPointType { diff --git a/server_test.go b/server_test.go index 8d9a3bbb..5643380e 100644 --- a/server_test.go +++ b/server_test.go @@ -9,16 +9,16 @@ import ( "github.com/stretchr/testify/assert" ) -func TestTCPServer(t *testing.T) { +func testTCPServer(t *testing.T, address string) { var ( server *server serverMsgHandler MessageHandler ) - addr := "127.0.0.1:0" + func() { server = newServer( TCP_SERVER, - WithLocalAddress(addr), + WithLocalAddress(address), ) newServerSession := func(session Session) error { return newSessionCallback(session, &serverMsgHandler) @@ -26,11 +26,12 @@ func TestTCPServer(t *testing.T) { server.RunEventLoop(newServerSession) assert.True(t, server.ID() > 0) assert.True(t, server.EndPointType() == TCP_SERVER) + assert.NotNil(t, server.streamListener) }() time.Sleep(500e6) - addr = server.streamListener.Addr().String() - t.Logf("server addr: %v", addr) + addr := server.streamListener.Addr().String() + t.Logf("@address:%s, tcp server addr: %v", address, addr) clt := newClient(TCP_CLIENT, WithServerAddress(addr), WithReconnectInterval(5e8), @@ -58,16 +59,15 @@ func TestTCPServer(t *testing.T) { assert.True(t, server.IsClosed()) } -func TestUDPServer(t *testing.T) { +func testUDPServer(t *testing.T, address string) { var ( server *server serverMsgHandler MessageHandler ) - addr := "127.0.0.1:0" func() { server = newServer( UDP_ENDPOINT, - WithLocalAddress(addr), + WithLocalAddress(address), ) newServerSession := func(session Session) error { return newSessionCallback(session, &serverMsgHandler) @@ -75,34 +75,25 @@ func TestUDPServer(t *testing.T) { server.RunEventLoop(newServerSession) assert.True(t, server.ID() > 0) assert.True(t, server.EndPointType() == UDP_ENDPOINT) + assert.NotNil(t, server.pktListener) }() time.Sleep(500e6) - //addr = server.streamListener.Addr().String() - //t.Logf("server addr: %v", addr) - //clt := newClient(TCP_CLIENT, - // WithServerAddress(addr), - // WithReconnectInterval(5e8), - // WithConnectionNumber(1), - //) - //assert.NotNil(t, clt) - //assert.True(t, clt.ID() > 0) - //assert.Equal(t, clt.endPointType, TCP_CLIENT) - // - //var ( - // msgHandler MessageHandler - //) - //cb := func(session Session) error { - // return newSessionCallback(session, &msgHandler) - //} - // - //clt.RunEventLoop(cb) - //time.Sleep(1e9) - // - //assert.Equal(t, 1, msgHandler.SessionNumber()) - //clt.Close() - //assert.True(t, clt.IsClosed()) - // - //server.Close() - //assert.True(t, server.IsClosed()) + addr := server.pktListener.LocalAddr().String() + t.Logf("@address:%s, udp server addr: %v", address, addr) +} + +func TestServer(t *testing.T) { + var addr string + + testTCPServer(t, addr) + testUDPServer(t, addr) + + addr = "127.0.0.1:0" + testTCPServer(t, addr) + testUDPServer(t, addr) + + addr = "127.0.0.1" + testTCPServer(t, addr) + testUDPServer(t, addr) } diff --git a/session.go b/session.go index 5e9dc3ff..0f9ec78d 100644 --- a/session.go +++ b/session.go @@ -703,12 +703,12 @@ func (s *session) handleTCPPackage() error { break } if perrors.Cause(err) == io.EOF { - log.Infof("%s, [session.conn.read] = error:%+v", s.sessionToken(), err) + log.Infof("%s, [session.conn.read] = error:%+v", s.sessionToken(), perrors.WithStack(err)) err = nil exit = true break } - log.Errorf("%s, [session.conn.read] = error:%+v", s.sessionToken(), err) + log.Errorf("%s, [session.conn.read] = error:%+v", s.sessionToken(), perrors.WithStack(err)) exit = true } break @@ -784,7 +784,7 @@ func (s *session) handleUDPPackage() error { } bufLen, addr, err = conn.recv(buf) - log.Debugf("conn.read() = bufLen:%d, addr:%#v, err:%+v", bufLen, addr, err) + log.Debugf("conn.read() = bufLen:%d, addr:%#v, err:%+v", bufLen, addr, perrors.WithStack(err)) if netError, ok = perrors.Cause(err).(net.Error); ok && netError.Timeout() { continue } @@ -796,7 +796,7 @@ func (s *session) handleUDPPackage() error { } if bufLen == 0 { - log.Errorf("conn.read() = bufLen:%d, addr:%s, err:%+v", bufLen, addr, err) + log.Errorf("conn.read() = bufLen:%d, addr:%s, err:%+v", bufLen, addr, perrors.WithStack(err)) continue } @@ -806,17 +806,17 @@ func (s *session) handleUDPPackage() error { } pkg, pkgLen, err = s.reader.Read(s, buf[:bufLen]) - log.Debugf("s.reader.Read() = pkg:%#v, pkgLen:%d, err:%+v", pkg, pkgLen, err) + log.Debugf("s.reader.Read() = pkg:%#v, pkgLen:%d, err:%+v", pkg, pkgLen, perrors.WithStack(err)) if err == nil && s.maxMsgLen > 0 && bufLen > int(s.maxMsgLen) { err = perrors.Errorf("Message Too Long, bufLen %d, session max message len %d", bufLen, s.maxMsgLen) } if err != nil { log.Warnf("%s, [session.handleUDPPackage] = len{%d}, error:%+v", - s.sessionToken(), pkgLen, err) + s.sessionToken(), pkgLen, perrors.WithStack(err)) continue } if pkgLen == 0 { - log.Errorf("s.reader.Read() = pkg:%#v, pkgLen:%d, err:%+v", pkg, pkgLen, err) + log.Errorf("s.reader.Read() = pkg:%#v, pkgLen:%d, err:%+v", pkg, pkgLen, perrors.WithStack(err)) continue } @@ -861,7 +861,7 @@ func (s *session) handleWSPackage() error { } if err != nil { log.Warnf("%s, [session.handleWSPackage] = len{%d}, error:%+v", - s.sessionToken(), length, err) + s.sessionToken(), length, perrors.WithStack(err)) continue }