Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Add QUIC support #2962

Merged
merged 28 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add quick listener
  • Loading branch information
pappz committed Nov 15, 2024
commit b23169de63a74c825bf8e79389e2da4ff6bd84c3
8 changes: 3 additions & 5 deletions relay/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,7 @@ func (c *Client) Connect() error {
return nil
}

err := c.connect()
if err != nil {
if err := c.connect(); err != nil {
return err
}

Expand Down Expand Up @@ -266,8 +265,7 @@ func (c *Client) connect() error {
}
c.relayConn = conn

err = c.handShake()
if err != nil {
if err = c.handShake(); err != nil {
cErr := conn.Close()
if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr)
Expand Down Expand Up @@ -341,7 +339,7 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Debugf("failed to read message from relay server: %s", errExit)
c.log.Errorf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
break
Expand Down
11 changes: 5 additions & 6 deletions relay/client/dialer/quic/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package quic

import (
"context"
"fmt"
"net"
"time"

"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)

type QuicAddr struct {
Expand Down Expand Up @@ -36,22 +36,21 @@ func NewConn(session quic.Connection, serverAddress string) net.Conn {
}

func (c *Conn) Read(b []byte) (n int, err error) {
// Use the QUIC stream's Read method directly
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, fmt.Errorf("failed to read from QUIC stream: %v", err)
log.Errorf("failed to read from QUIC session: %v", err)
return 0, err
}

// Copy data to b, ensuring we don’t exceed the size of b
n = copy(b, dgram)
return n, nil
}

func (c *Conn) Write(b []byte) (int, error) {
// Use the QUIC stream's Write method directly
err := c.session.SendDatagram(b)
if err != nil {
return 0, fmt.Errorf("failed to write to QUIC stream: %v", err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
}
return len(b), nil
}
Expand Down
6 changes: 5 additions & 1 deletion relay/client/dialer/quic/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"time"

"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
)

const (
Expand All @@ -35,9 +36,12 @@ func Dial(address string) (net.Conn, error) {
EnableDatagrams: true,
}

// todo add support for custom dialer

session, err := quic.DialAddr(ctx, quicURL, tlsConf, quicConfig)
if err != nil {
return nil, fmt.Errorf("failed to dial QUIC server '%s': %v", quicURL, err)
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
return nil, err
}

conn := NewConn(session, address)
Expand Down
2 changes: 0 additions & 2 deletions relay/client/dialer/ws/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ func Dial(address string) (net.Conn, error) {
}
parsedURL.Path = ws.URLPath

log.Infof("------ Dialing to Relay server: %s", wsURL)

wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
if err != nil {
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
Expand Down
68 changes: 0 additions & 68 deletions relay/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,10 @@ package cmd

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"math/big"
"net"
"net/http"
"os"
"os/signal"
Expand Down Expand Up @@ -148,13 +141,6 @@ func execute(cmd *cobra.Command, args []string) error {
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)

tlsSupport = true
srvListenerCfg.TLSConfig, err = generateTestTLSConfig()
if err != nil {
log.Debugf("failed to generate test TLS config: %s", err)
return fmt.Errorf("failed to generate test TLS config: %s", err)
}

srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
if err != nil {
log.Debugf("failed to create relay server: %v", err)
Expand Down Expand Up @@ -227,57 +213,3 @@ func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string
}
return certManager.TLSConfig(), nil
}

// GenerateTestTLSConfig creates a self-signed certificate for testing
func generateTestTLSConfig() (*tls.Config, error) {
// Generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}

// Create certificate template
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Organization"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 180), // Valid for 180 days
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}

// Create certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, err
}

// Encode certificate and private key to PEM format
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})

privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})

// Create TLS certificate
tlsCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)
if err != nil {
return nil, err
}

return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"netbird-relay"}, // Your application protocol
}, nil
}
18 changes: 7 additions & 11 deletions relay/server/listener/quic/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"

Expand Down Expand Up @@ -37,32 +38,27 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
for {
session, err := listener.Accept(context.Background())
if err != nil {
// Check if the listener was closed intentionally
if err.Error() == "server closed" {
if errors.Is(err, quic.ErrServerClosed) {
return nil
}

log.Errorf("Failed to accept QUIC session: %v", err)
continue
}

// Handle each session in a separate goroutine
go l.handleSession(session)
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session)
l.acceptFn(conn)
}
}

func (l *Listener) handleSession(session quic.Connection) {
conn := NewConn(session)
l.acceptFn(conn)
}

func (l *Listener) Shutdown(ctx context.Context) error {
if l.listener == nil {
return nil
}

log.Infof("stopping QUIC listener")
err := l.listener.Close()
if err != nil {
if err := l.listener.Close(); err != nil {
return fmt.Errorf("listener shutdown failed: %v", err)
}
log.Infof("QUIC listener stopped")
Expand Down
3 changes: 3 additions & 0 deletions relay/server/listener/ws/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Listener struct {

server *http.Server
acceptFn func(conn net.Conn)
log *log.Entry
}

func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
Expand Down Expand Up @@ -88,6 +89,8 @@ func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
return
}

log.Infof("WS client connected from: %s", rAddr)

conn := NewConn(wsConn, lAddr, rAddr)
l.acceptFn(conn)
}
Expand Down
Loading