Skip to content

Commit

Permalink
Export interface for WireGuard
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 21, 2024
1 parent 8a18f0c commit cacaf7a
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 48 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.6.0-alpha.11
github.com/sagernet/sing v0.6.0-alpha.18
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.6.0-alpha.11 h1:ZcZlA0/vdDeiipAbjK73x9VabGJ/RRcAJgWhOo/OoBk=
github.com/sagernet/sing v0.6.0-alpha.11/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.6.0-alpha.18 h1:ih4CurU8KvbhfagYjSqVrE2LR0oBSXSZTNH2sAGPGiM=
github.com/sagernet/sing v0.6.0-alpha.18/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
Expand Down
136 changes: 136 additions & 0 deletions internal/gtcpip/header/interfaces.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package header

import (
"net/netip"

tcpip "github.com/sagernet/sing-tun/internal/gtcpip"
)

const (
// MaxIPPacketSize is the maximum supported IP packet size, excluding
// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
// m is the minimum IPv6 header size; we leave room for some potential
// IP options.
MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
)

// Transport offers generic methods to query and/or update the fields of the
// header of a transport protocol buffer.
type Transport interface {
// SourcePort returns the value of the "source port" field.
SourcePort() uint16

// Destination returns the value of the "destination port" field.
DestinationPort() uint16

// Checksum returns the value of the "checksum" field.
Checksum() uint16

// SetSourcePort sets the value of the "source port" field.
SetSourcePort(uint16)

// SetDestinationPort sets the value of the "destination port" field.
SetDestinationPort(uint16)

// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)

// Payload returns the data carried in the transport buffer.
Payload() []byte
}

// ChecksummableTransport is a Transport that supports checksumming.
type ChecksummableTransport interface {
Transport

// SetSourcePortWithChecksumUpdate sets the source port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetSourcePortWithChecksumUpdate(port uint16)

// SetDestinationPortWithChecksumUpdate sets the destination port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetDestinationPortWithChecksumUpdate(port uint16)

// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
// updated address in the pseudo header.
//
// If fullChecksum is true, the receiver's checksum field is assumed to hold a
// fully calculated checksum. Otherwise, it is assumed to hold a partially
// calculated checksum which only reflects the pseudo header.
UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
}

// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
// SourceAddress returns the value of the "source address" field.
SourceAddress() tcpip.Address

// DestinationAddress returns the value of the "destination address"
// field.
DestinationAddress() tcpip.Address

DestinationAddr() netip.Addr

// Checksum returns the value of the "checksum" field.
Checksum() uint16

// SetSourceAddress sets the value of the "source address" field.
SetSourceAddress(tcpip.Address)

// SetDestinationAddress sets the value of the "destination address"
// field.
SetDestinationAddress(tcpip.Address)

SetDestinationAddr(addr netip.Addr)

// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)

// TransportProtocol returns the number of the transport protocol
// stored in the payload.
TransportProtocol() tcpip.TransportProtocolNumber

// Payload returns a byte slice containing the payload of the network
// packet.
Payload() []byte

// TOS returns the values of the "type of service" and "flow label" fields.
TOS() (uint8, uint32)

// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}

// ChecksummableNetwork is a Network that supports checksumming.
type ChecksummableNetwork interface {
Network

// SetSourceAddressAndChecksum sets the source address and updates the
// checksum to reflect the new address.
SetSourceAddressWithChecksumUpdate(tcpip.Address)

// SetDestinationAddressAndChecksum sets the destination address and
// updates the checksum to reflect the new address.
SetDestinationAddressWithChecksumUpdate(tcpip.Address)
}
37 changes: 9 additions & 28 deletions stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

const WithGVisor = true

const defaultNIC tcpip.NICID = 1
const DefaultNIC tcpip.NICID = 1

type GVisor struct {
ctx context.Context
Expand Down Expand Up @@ -68,28 +66,11 @@ func (t *GVisor) Start() error {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := newGVisorStack(linkEndpoint)
ipStack, err := NewGVisorStack(linkEndpoint)
if err != nil {
return err
}
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := t.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
return
}
conn := &gLazyConn{
parentCtx: t.ctx,
stack: t.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
Expand Down Expand Up @@ -124,7 +105,7 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
}
}

func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
Expand All @@ -137,19 +118,19 @@ func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
icmp.NewProtocol6,
},
})
err := ipStack.CreateNIC(defaultNIC, ep)
err := ipStack.CreateNIC(DefaultNIC, ep)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv4EmptySubnet, NIC: DefaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: DefaultNIC},
})
err = ipStack.SetSpoofing(defaultNIC, true)
err = ipStack.SetSpoofing(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
err = ipStack.SetPromiscuousMode(defaultNIC, true)
err = ipStack.SetPromiscuousMode(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
Expand Down
51 changes: 51 additions & 0 deletions stack_gvisor_tcp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//go:build with_gvisor

package tun

import (
"context"

"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

type TCPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
forwarder *tcp.Forwarder
}

func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
forwarder := &TCPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
}
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
return forwarder
}

func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
return f.forwarder.HandlePacket(id, pkt)
}

func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
return
}
conn := &gLazyConn{
parentCtx: f.ctx,
stack: f.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil)
}
2 changes: 1 addition & 1 deletion stack_gvisor_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
defer packetBuffer.Release()

route, err := w.stack.FindRoute(
defaultNIC,
DefaultNIC,
AddressFromAddr(destination.Addr),
w.source,
w.sourceNetwork,
Expand Down
8 changes: 4 additions & 4 deletions stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (m *Mixed) Start() error {
return err
}
endpoint := channel.New(1024, uint32(m.mtu), "")
ipStack, err := newGVisorStack(endpoint)
ipStack, err := NewGVisorStack(endpoint)
if err != nil {
return err
}
Expand Down Expand Up @@ -151,10 +151,10 @@ func (m *Mixed) processPacket(packet []byte) bool {
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
writeBack, err = m.processIPv4(packet)
case 6:
case header.IPv6Version:
writeBack, err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
Expand Down
10 changes: 5 additions & 5 deletions stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
Expand Down Expand Up @@ -502,7 +502,7 @@ func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) erro
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
Expand Down Expand Up @@ -684,7 +684,7 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e
}))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
Expand Down Expand Up @@ -724,7 +724,7 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
Expand Down Expand Up @@ -763,7 +763,7 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetChecksum(0)
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
Expand Down
Loading

0 comments on commit cacaf7a

Please sign in to comment.