diff --git a/stack.go b/stack.go index 43c5383..2d5ef5a 100644 --- a/stack.go +++ b/stack.go @@ -28,14 +28,6 @@ type StackOptions struct { InterfaceFinder control.InterfaceFinder } -func (o *StackOptions) BufferSize() uint32 { - if o.TunOptions.GSO { - return o.TunOptions.GSOMaxSize - } else { - return o.TunOptions.MTU - } -} - func NewStack( stack string, options StackOptions, diff --git a/stack_mixed.go b/stack_mixed.go index 7a651d1..db397fd 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -48,7 +48,7 @@ func (m *Mixed) Start() error { if err != nil { return err } - endpoint := channel.New(1024, m.mtu, "") + endpoint := channel.New(1024, uint32(m.mtu), "") ipStack, err := newGVisorStack(endpoint) if err != nil { return err @@ -95,8 +95,16 @@ func (m *Mixed) tunLoop() { m.wintunLoop(winTun) return } + + if batchTUN, isBatchTUN := m.tun.(BatchTUN); isBatchTUN { + batchSize := batchTUN.BatchSize() + if batchSize > 1 { + m.batchLoop(batchTUN, batchSize) + return + } + } frontHeadroom := m.tun.FrontHeadroom() - packetBuffer := make([]byte, m.bufferSize+frontHeadroom+PacketOffset) + packetBuffer := make([]byte, m.mtu+frontHeadroom+PacketOffset) for { n, err := m.tun.Read(packetBuffer[frontHeadroom:]) if err != nil { @@ -110,17 +118,7 @@ func (m *Mixed) tunLoop() { } rawPacket := packetBuffer[:frontHeadroom+n] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = m.processIPv4(rawPacket, packet) - case 6: - err = m.processIPv6(rawPacket, packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } - if err != nil { - m.logger.Trace(err) - } + m.processPacket(rawPacket, packet) } } @@ -134,18 +132,53 @@ func (m *Mixed) wintunLoop(winTun WinTun) { release() continue } - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = m.processIPv4(packet, packet) - case 6: - err = m.processIPv6(packet, packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } + m.processPacket(packet, packet) + release() + } +} + +func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { + frontHeadroom := m.tun.FrontHeadroom() + packetBuffers := make([][]byte, batchSize) + for i := range packetBuffers { + packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset) + } + packetSizes := make([]int, batchSize) + for { + n, err := linuxTUN.BatchRead(packetBuffers, packetSizes) if err != nil { - m.logger.Trace(err) + if E.IsClosed(err) { + return + } + m.logger.Error(E.Cause(err, "batch read packet")) } - release() + if n == 0 { + continue + } + for i := 0; i < n; i++ { + packetBuffer := packetBuffers[i][:packetSizes[i]] + if n < clashtcpip.IPv4PacketMinLength { + continue + } + rawPacket := packetBuffer[:frontHeadroom+n] + packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] + m.processPacket(rawPacket, packet) + } + } +} + +func (m *Mixed) processPacket(rawPacket []byte, packet []byte) { + var err error + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + err = m.processIPv4(rawPacket, packet) + case 6: + err = m.processIPv6(rawPacket, packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + m.logger.Trace(err) } } diff --git a/stack_system.go b/stack_system.go index ce18934..608b63e 100644 --- a/stack_system.go +++ b/stack_system.go @@ -22,8 +22,7 @@ type System struct { ctx context.Context tun Tun tunName string - mtu uint32 - bufferSize int + mtu int handler Handler logger logger.Logger inet4Prefixes []netip.Prefix @@ -57,8 +56,7 @@ func NewSystem(options StackOptions) (Stack, error) { ctx: options.Context, tun: options.Tun, tunName: options.TunOptions.Name, - mtu: options.TunOptions.MTU, - bufferSize: int(options.BufferSize()), + mtu: int(options.TunOptions.MTU), udpTimeout: options.UDPTimeout, handler: options.Handler, logger: options.Logger, @@ -147,8 +145,15 @@ func (s *System) tunLoop() { s.wintunLoop(winTun) return } + if batchTUN, isBatchTUN := s.tun.(BatchTUN); isBatchTUN { + batchSize := batchTUN.BatchSize() + if batchSize > 1 { + s.batchLoop(batchTUN, batchSize) + return + } + } frontHeadroom := s.tun.FrontHeadroom() - packetBuffer := make([]byte, s.bufferSize+frontHeadroom+PacketOffset) + packetBuffer := make([]byte, s.mtu+frontHeadroom+PacketOffset) for { n, err := s.tun.Read(packetBuffer[frontHeadroom:]) if err != nil { @@ -162,17 +167,7 @@ func (s *System) tunLoop() { } rawPacket := packetBuffer[:frontHeadroom+n] packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = s.processIPv4(rawPacket, packet) - case 6: - err = s.processIPv6(rawPacket, packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } - if err != nil { - s.logger.Trace(err) - } + s.processPacket(rawPacket, packet) } } @@ -186,18 +181,53 @@ func (s *System) wintunLoop(winTun WinTun) { release() continue } - switch ipVersion := packet[0] >> 4; ipVersion { - case 4: - err = s.processIPv4(packet, packet) - case 6: - err = s.processIPv6(packet, packet) - default: - err = E.New("ip: unknown version: ", ipVersion) - } + s.processPacket(packet, packet) + release() + } +} + +func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { + frontHeadroom := s.tun.FrontHeadroom() + packetBuffers := make([][]byte, batchSize) + for i := range packetBuffers { + packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset) + } + packetSizes := make([]int, batchSize) + for { + n, err := linuxTUN.BatchRead(packetBuffers, packetSizes) if err != nil { - s.logger.Trace(err) + if E.IsClosed(err) { + return + } + s.logger.Error(E.Cause(err, "batch read packet")) } - release() + if n == 0 { + continue + } + for i := 0; i < n; i++ { + packetBuffer := packetBuffers[i][:packetSizes[i]] + if n < clashtcpip.IPv4PacketMinLength { + continue + } + rawPacket := packetBuffer[:frontHeadroom+n] + packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n] + s.processPacket(rawPacket, packet) + } + } +} + +func (s *System) processPacket(rawPacket []byte, packet []byte) { + var err error + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + err = s.processIPv4(rawPacket, packet) + case 6: + err = s.processIPv6(rawPacket, packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + s.logger.Trace(err) } } @@ -354,7 +384,7 @@ func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter4{s.tun, s.tun.FrontHeadroom(), headerCopy, source} + return &systemUDPPacketWriter4{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source} }) return nil } @@ -380,7 +410,7 @@ func (s *System) processIPv6UDP(rawPacket []byte, packet clashtcpip.IPv6Packet, headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter6{s.tun, s.tun.FrontHeadroom(), headerCopy, source} + return &systemUDPPacketWriter6{s.tun, s.tun.FrontHeadroom() + PacketOffset, headerCopy, source} }) return nil } @@ -421,8 +451,7 @@ type systemUDPPacketWriter4 struct { func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) defer newPacket.Release() - newPacket.WriteZeroN(w.frontHeadroom) - newPacket.Advance(w.frontHeadroom) + newPacket.Resize(w.frontHeadroom, 0) newPacket.Write(w.header) newPacket.Write(buffer.Bytes()) ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) @@ -435,7 +464,11 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize)) udpHdr.ResetChecksum(ipHdr.PseudoSum()) ipHdr.ResetChecksum() - newPacket.Advance(-w.frontHeadroom) + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET + } else { + newPacket.Advance(-w.frontHeadroom) + } return common.Error(w.tun.Write(newPacket.Bytes())) } @@ -449,8 +482,7 @@ type systemUDPPacketWriter6 struct { func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) defer newPacket.Release() - newPacket.WriteZeroN(w.frontHeadroom) - newPacket.Advance(w.frontHeadroom) + newPacket.Resize(w.frontHeadroom, 0) newPacket.Write(w.header) newPacket.Write(buffer.Bytes()) ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) @@ -463,6 +495,10 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(udpLen) udpHdr.ResetChecksum(ipHdr.PseudoSum()) - newPacket.Advance(-w.frontHeadroom) + if PacketOffset > 0 { + newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6 + } else { + newPacket.Advance(-w.frontHeadroom) + } return common.Error(w.tun.Write(newPacket.Bytes())) } diff --git a/tun.go b/tun.go index f897c28..e6de7f1 100644 --- a/tun.go +++ b/tun.go @@ -32,6 +32,11 @@ type WinTun interface { ReadPacket() ([]byte, func(), error) } +type BatchTUN interface { + BatchSize() int + BatchRead(buffers [][]byte, readN []int) (n int, err error) +} + type Options struct { Name string Inet4Address []netip.Prefix diff --git a/tun_linux.go b/tun_linux.go index 62c3992..25801cf 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -21,6 +21,8 @@ import ( "golang.org/x/sys/unix" ) +var _ BatchTUN = (*NativeTun)(nil) + type NativeTun struct { tunFd int tunFile *os.File @@ -119,6 +121,29 @@ func (t *NativeTun) Write(p []byte) (n int, err error) { return t.tunFile.Write(p) } +func (t *NativeTun) BatchSize() int { + if !t.gsoEnabled { + return 1 + } + return idealBatchSize +} + +func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) { + if t.gsoEnabled { + n, err = t.tunFile.Read(t.gsoBuffer) + if err != nil { + return + } + n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0) + if err != nil { + return + } + return + } else { + return 0, os.ErrInvalid + } +} + var controlPath string func init() { diff --git a/tun_linux_offload.go b/tun_linux_offload.go index ae76b19..8d3e4fb 100644 --- a/tun_linux_offload.go +++ b/tun_linux_offload.go @@ -23,7 +23,7 @@ import ( const ( tcpFlagsOffset = 13 - idealBatchSize = 1 + idealBatchSize = 128 ) const (