Skip to content

Commit

Permalink
Merge pull request #366 from lesismal/fix1319
Browse files Browse the repository at this point in the history
Fix1319
  • Loading branch information
lesismal authored Nov 16, 2023
2 parents e246bda + 56b0559 commit f02683a
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 77 deletions.
32 changes: 24 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,27 +167,43 @@ package main

import (
"fmt"
"log"
"net/http"

"github.com/lesismal/nbio/nbhttp/websocket"
)

func echo(w http.ResponseWriter, r *http.Request) {
var (
upgrader = newUpgrader()
)

func newUpgrader() *websocket.Upgrader {
u := websocket.NewUpgrader()
u.OnMessage(func(c *websocket.Conn, mt websocket.MessageType, data []byte) {
c.WriteMessage(mt, data)
u.OnOpen(func(c *websocket.Conn) {
// echo
fmt.Println("OnOpen:", c.RemoteAddr().String())
})
u.OnMessage(func(c *websocket.Conn, messageType websocket.MessageType, data []byte) {
// echo
fmt.Println("OnMessage:", messageType, string(data))
c.WriteMessage(messageType, data)
})
_, err := u.Upgrade(w, r, nil)
u.OnClose(func(c *websocket.Conn, err error) {
fmt.Println("OnClose:", c.RemoteAddr().String(), err)
})
return u
}

func onWebsocket(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
return
panic(err)
}
fmt.Println("Upgraded:", conn.RemoteAddr().String())
}

func main() {
mux := &http.ServeMux{}
mux.HandleFunc("/ws", echo)
mux.HandleFunc("/ws", onWebsocket)
server := http.Server{
Addr: "localhost:8080",
Handler: mux,
Expand Down
5 changes: 3 additions & 2 deletions nbhttp/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,11 @@ func (c *ClientConn) Do(req *http.Request, handler func(res *http.Response, conn
isNonblock := true
tlsConn.ResetConn(nbc, isNonblock)

c.conn = tlsConn
nbhttpConn := &Conn{Conn: tlsConn}
c.conn = nbhttpConn
processor := NewClientProcessor(c, c.onResponse)
parser := NewParser(processor, true, engine.ReadLimit, nbc.Execute)
parser.Conn = tlsConn
parser.Conn = nbhttpConn
parser.Engine = engine
parser.OnClose(func(p *Parser, err error) {
c.CloseWithError(err)
Expand Down
120 changes: 68 additions & 52 deletions nbhttp/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,13 @@ func (e *Engine) closeAllConns() {
}
}

func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(net.Conn, *tls.Config, func()), decrease func()) {
type Conn struct {
net.Conn
Parser *Parser
Trasfered bool
}

func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(*Conn, *tls.Config, func()), decrease func()) {
e.WaitGroup.Add(1)
go func() {
defer func() {
Expand All @@ -289,7 +295,7 @@ func (e *Engine) listen(ln net.Listener, tlsConfig *tls.Config, addConn func(net
for !e.shutdown {
conn, err := ln.Accept()
if err == nil && !e.shutdown {
addConn(conn, tlsConfig, decrease)
addConn(&Conn{Conn: conn}, tlsConfig, decrease)
} else {
var ne net.Error
if ok := errors.As(err, &ne); ok && ne.Temporary() {
Expand Down Expand Up @@ -520,31 +526,34 @@ func (e *Engine) TLSDataHandler(c *nbio.Conn, data []byte) {
c.Close()
return
}
if tlsConn, ok := parser.Processor.Conn().(*tls.Conn); ok {
defer tlsConn.ResetOrFreeBuffer()

readed := data
buffer := data
for {
_, nread, err := tlsConn.AppendAndRead(readed, buffer)
readed = nil
if err != nil {
c.CloseWithError(err)
return
}
if nread > 0 {
err := parser.Read(buffer[:nread])
nbhttpConn, ok := parser.Processor.Conn().(*Conn)
if ok {
if tlsConn, ok := nbhttpConn.Conn.(*tls.Conn); ok {
defer tlsConn.ResetOrFreeBuffer()

readed := data
buffer := data
for {
_, nread, err := tlsConn.AppendAndRead(readed, buffer)
readed = nil
if err != nil {
logging.Debug("parser.Read failed: %v", err)
c.CloseWithError(err)
return
}
if nread > 0 {
err := parser.Read(buffer[:nread])
if err != nil {
logging.Debug("parser.Read failed: %v", err)
c.CloseWithError(err)
return
}
}
if nread == 0 {
return
}
}
if nread == 0 {
return
}
// c.SetReadDeadline(time.Now().Add(conf.KeepaliveTime))
}
// c.SetReadDeadline(time.Now().Add(conf.KeepaliveTime))
}
}

Expand Down Expand Up @@ -572,13 +581,14 @@ func (engine *Engine) AddTransferredConn(nbc *nbio.Conn) error {
}

// AddConnNonTLSNonBlocking .
func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(c)
func (engine *Engine) AddConnNonTLSNonBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(conn.Conn)
if err != nil {
c.Close()
conn.Close()
logging.Error("AddConnNonTLSNonBlocking failed: %v", err)
return
}
conn.Conn = nbc
if nbc.Session() != nil {
nbc.Close()
return
Expand All @@ -599,13 +609,14 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config
}
engine.conns[key] = struct{}{}
engine.mux.Unlock()
engine._onOpen(nbc)
processor := NewServerProcessor(nbc, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
engine._onOpen(conn.Conn)
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, nbc.Execute)
if engine.isOneshot {
parser.Execute = SyncExecutor
}
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
nbc.SetSession(parser)
nbc.OnData(engine.DataHandler)
Expand All @@ -614,7 +625,7 @@ func (engine *Engine) AddConnNonTLSNonBlocking(c net.Conn, tlsConfig *tls.Config
}

// AddConnNonTLSBlocking .
func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) {
func (engine *Engine) AddConnNonTLSBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
engine.mux.Lock()
if len(engine.conns) >= engine.MaxLoad {
engine.mux.Unlock()
Expand All @@ -623,7 +634,7 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config
decrease()
return
}
switch vt := conn.(type) {
switch vt := conn.Conn.(type) {
case *net.TCPConn, *net.UnixConn:
key, err := conn2Array(vt)
if err != nil {
Expand All @@ -646,19 +657,21 @@ func (engine *Engine) AddConnNonTLSBlocking(conn net.Conn, tlsConfig *tls.Config
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, SyncExecutor)
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
conn.SetReadDeadline(time.Now().Add(engine.KeepaliveTime))
go engine.readConnBlocking(conn, parser, decrease)
}

// AddConnTLSNonBlocking .
func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(conn)
func (engine *Engine) AddConnTLSNonBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
nbc, err := nbio.NBConn(conn.Conn)
if err != nil {
conn.Close()
logging.Error("AddConnTLSNonBlocking failed: %v", err)
return
}
conn.Conn = nbc
if nbc.Session() != nil {
nbc.Close()
logging.Error("AddConnTLSNonBlocking failed: session should not be nil")
Expand All @@ -681,18 +694,20 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config

engine.conns[key] = struct{}{}
engine.mux.Unlock()
engine._onOpen(nbc)
engine._onOpen(conn.Conn)

isClient := false
isNonBlock := true
tlsConn := tls.NewConn(nbc, tlsConfig, isClient, isNonBlock, engine.TLSAllocator)
processor := NewServerProcessor(tlsConn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
conn = &Conn{Conn: tlsConn}
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, nbc.Execute)
if engine.isOneshot {
parser.Execute = SyncExecutor
}
parser.Conn = tlsConn
parser.Conn = conn
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
nbc.SetSession(parser)

Expand All @@ -702,7 +717,7 @@ func (engine *Engine) AddConnTLSNonBlocking(conn net.Conn, tlsConfig *tls.Config
}

// AddConnTLSBlocking .
func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, decrease func()) {
func (engine *Engine) AddConnTLSBlocking(conn *Conn, tlsConfig *tls.Config, decrease func()) {
engine.mux.Lock()
if len(engine.conns) >= engine.MaxLoad {
engine.mux.Unlock()
Expand All @@ -712,7 +727,8 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d
return
}

switch vt := conn.(type) {
underLayerConn := conn.Conn
switch vt := underLayerConn.(type) {
case *net.TCPConn, *net.UnixConn:
key, err := conn2Array(vt)
if err != nil {
Expand All @@ -735,18 +751,20 @@ func (engine *Engine) AddConnTLSBlocking(conn net.Conn, tlsConfig *tls.Config, d

isClient := false
isNonBlock := true
tlsConn := tls.NewConn(conn, tlsConfig, isClient, isNonBlock, engine.TLSAllocator)
processor := NewServerProcessor(tlsConn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
tlsConn := tls.NewConn(underLayerConn, tlsConfig, isClient, isNonBlock, engine.TLSAllocator)
conn = &Conn{Conn: tlsConn}
processor := NewServerProcessor(conn, engine.Handler, engine.KeepaliveTime, !engine.DisableSendfile)
parser := NewParser(processor, false, engine.ReadLimit, SyncExecutor)
parser.Conn = tlsConn
parser.Conn = conn
parser.Engine = engine
conn.Parser = parser
processor.(*ServerProcessor).parser = parser
conn.SetReadDeadline(time.Now().Add(engine.KeepaliveTime))
tlsConn.SetSession(parser)
go engine.readTLSConnBlocking(conn, tlsConn, parser, decrease)
go engine.readTLSConnBlocking(conn, underLayerConn, tlsConn, parser, decrease)
}

func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease func()) {
func (engine *Engine) readConnBlocking(conn *Conn, parser *Parser, decrease func()) {
var (
n int
err error
Expand All @@ -764,7 +782,7 @@ func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease f
// go func() {
parser.Close(err)
engine.mux.Lock()
switch vt := conn.(type) {
switch vt := conn.Conn.(type) {
case *net.TCPConn, *net.UnixConn:
key, _ := conn2Array(vt)
delete(engine.conns, key)
Expand All @@ -781,13 +799,10 @@ func (engine *Engine) readConnBlocking(conn net.Conn, parser *Parser, decrease f
return
}
parser.Read(buf[:n])
if parser.hijacked {
return
}
}
}

func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, parser *Parser, decrease func()) {
func (engine *Engine) readTLSConnBlocking(conn *Conn, rconn net.Conn, tlsConn *tls.Conn, parser *Parser, decrease func()) {
var (
err error
nread int
Expand All @@ -801,10 +816,13 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars
buffer := readBufferPool.Malloc(engine.BlockingReadBufferSize)
defer func() {
readBufferPool.Free(buffer)
parser.Close(err)
tlsConn.Close()
if !conn.Trasfered {
parser.Close(err)
tlsConn.Close()
}

engine.mux.Lock()
switch vt := conn.(type) {
switch vt := rconn.(type) {
case *net.TCPConn, *net.UnixConn:
key, _ := conn2Array(vt)
delete(engine.conns, key)
Expand All @@ -815,7 +833,7 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars
}()

for {
nread, err = conn.Read(buffer)
nread, err = rconn.Read(buffer)
if err != nil {
return
}
Expand All @@ -833,9 +851,6 @@ func (engine *Engine) readTLSConnBlocking(conn net.Conn, tlsConn *tls.Conn, pars
logging.Debug("parser.Read failed: %v", err)
return
}
// if parser.hijacked {
// return
// }
}
if nread == 0 {
break
Expand Down Expand Up @@ -1011,6 +1026,7 @@ func NewEngine(conf Config) *Engine {
engine.mux.Lock()
key, _ := conn2Array(c)
delete(engine.conns, key)
delete(engine.dialerConns, key)
engine.mux.Unlock()
})
})
Expand Down
11 changes: 5 additions & 6 deletions nbhttp/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ type Parser struct {

cache []byte

state int8
isClient bool
hijacked bool

readLimit int

errClose error
Expand Down Expand Up @@ -66,8 +62,11 @@ type Parser struct {
trailer http.Header
contentLength int
chunkSize int
chunked bool
headerExists bool

state int8
chunked bool
isClient bool
headerExists bool
}

func (p *Parser) nextState(state int8) {
Expand Down
4 changes: 1 addition & 3 deletions nbhttp/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,9 @@ func (p *ServerProcessor) OnComplete(parser *Parser) {
}

func (p *ServerProcessor) flushResponse(res *Response) {
hijacked := res.hijacked
p.parser.hijacked = hijacked
if p.conn != nil {
req := res.request
if !hijacked {
if !res.hijacked {
res.eoncodeHead()
if err := res.flushTrailer(p.conn); err != nil {
p.conn.Close()
Expand Down
Loading

0 comments on commit f02683a

Please sign in to comment.