diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..3dbd287 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,19 @@ +name: ci + +on: + push: + branches: + - master + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + + - name: Run tests + run: go test -v ./... diff --git a/client/tunnel.go b/client/tunnel.go index 74c4805..a411ae6 100644 --- a/client/tunnel.go +++ b/client/tunnel.go @@ -16,6 +16,14 @@ import ( gNet "gitlab.com/gartnera/golib/net" ) +type TunnelOpt func(t *Tunnel) + +func WithControlTLSConfig(tlsConfig *tls.Config) TunnelOpt { + return func(t *Tunnel) { + t.controlTTLSconfig = tlsConfig + } +} + type Tunnel struct { token string server string @@ -24,11 +32,14 @@ type Tunnel struct { tlsSkipVerify bool target string httpTargetHostHeader bool - connectLock sync.Mutex + + controlTTLSconfig *tls.Config + issuedAddr string + connectLock sync.Mutex } -func NewTunnel(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string) *Tunnel { - return &Tunnel{ +func New(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string, opts ...TunnelOpt) *Tunnel { + t := &Tunnel{ server: server, hostname: hostname, token: token, @@ -36,7 +47,16 @@ func NewTunnel(server, hostname, token string, useTLS, tlsSkipVerify, httpTarget tlsSkipVerify: tlsSkipVerify, target: target, httpTargetHostHeader: httpTargetHostHeader, + controlTTLSconfig: &tls.Config{}, + } + + for _, opt := range opts { + opt(t) } + serverName, _, _ := net.SplitHostPort(t.server) + t.controlTTLSconfig.ServerName = serverName + + return t } func (t *Tunnel) Start() error { @@ -53,7 +73,7 @@ func (t *Tunnel) Start() error { } func (t *Tunnel) Shutdown() { - conn, err := tls.Dial("tcp", t.server, t.getControlTlsConfig()) + conn, err := tls.Dial("tcp", t.server, t.controlTTLSconfig) if err != nil { panic(err) } @@ -64,20 +84,27 @@ func (t *Tunnel) Shutdown() { } } -func (t *Tunnel) getControlTlsConfig() *tls.Config { - serverName, _, _ := net.SplitHostPort(t.server) - return &tls.Config{ - ServerName: serverName, +// IssuedAddr gets the address issued by the server +func (t *Tunnel) IssuedAddr() string { + return t.issuedAddr +} + +// IssuedAddrHTTPS gets the address issued by the server with https prefix +func (t *Tunnel) IssuedAddrHTTPS() string { + addr := t.issuedAddr + if strings.HasSuffix(addr, ":443") { + addr = strings.TrimSuffix(addr, ":443") } + return fmt.Sprintf("https://%s", addr) } -func (t *Tunnel) stage1(print bool) (net.Conn, error) { +func (t *Tunnel) stage1(first bool) (net.Conn, error) { var err error var conn net.Conn backoff := time.Second * 10 t.connectLock.Lock() for { - conn, err = tls.Dial("tcp", t.server, t.getControlTlsConfig()) + conn, err = tls.Dial("tcp", t.server, t.controlTTLSconfig) if err == nil { break } @@ -100,13 +127,12 @@ func (t *Tunnel) stage1(print bool) (net.Conn, error) { } res := string(buf[:n]) - if print { + if first { _, port, _ := net.SplitHostPort(t.server) - portPart := "" - if port != "443" { - portPart = fmt.Sprintf(":%s", port) + if port == "" { + port = "443" } - fmt.Printf("URL: https://%s%s\n", res, portPart) + t.issuedAddr = fmt.Sprintf("%s:%s", res, port) } return conn, nil } diff --git a/cmd/tunnel-client/main.go b/cmd/tunnel-client/main.go index b830b90..957e43a 100644 --- a/cmd/tunnel-client/main.go +++ b/cmd/tunnel-client/main.go @@ -95,11 +95,12 @@ var rootCmd = &cobra.Command{ hostnameFqdn = strings.Join([]string{hostname, serverHostOnly}, ".") } - tunnel := client.NewTunnel(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target) + tunnel := client.New(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target) err := tunnel.Start() if err != nil { return fmt.Errorf("start %s: %w", controlName, err) } + fmt.Printf("URL: %s\n", tunnel.IssuedAddrHTTPS()) tunnels = append(tunnels, tunnel) } diff --git a/go.mod b/go.mod index ed8aa77..7ac41fc 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( ) require ( + github.com/benbjohnson/clock v1.1.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect ) @@ -107,7 +108,7 @@ require ( github.com/spf13/cobra v1.1.1 github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.4.0 // indirect - github.com/stretchr/testify v1.7.5 // indirect + github.com/stretchr/testify v1.7.5 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.287 // indirect github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.287 // indirect github.com/transip/gotransip/v6 v6.6.1 // indirect diff --git a/server/server.go b/server/server.go index c84eaac..202ec05 100644 --- a/server/server.go +++ b/server/server.go @@ -18,6 +18,8 @@ type Server struct { controlName string logger *zap.Logger + addr net.Addr + sync.RWMutex hostnameMap map[string]*proxySession secretMap map[string]*proxySession @@ -44,6 +46,9 @@ func (s *Server) Start(laddr string, tlsConfig *tls.Config) error { s.logger.Fatal("could not listen", zap.String("laddr", laddr), zap.Error(err)) } defer ln.Close() + s.Lock() + s.addr = ln.Addr() + s.Unlock() ctx := context.Background() for { @@ -63,6 +68,10 @@ func (s *Server) Start(laddr string, tlsConfig *tls.Config) error { } } +func (s *Server) Addr() net.Addr { + return s.addr +} + // getHostname generates a three word unique subdomain // recursively call self until we get a unique name func (s *Server) getHostname() string { diff --git a/test/e2e_test.go b/test/e2e_test.go new file mode 100644 index 0000000..51df39f --- /dev/null +++ b/test/e2e_test.go @@ -0,0 +1,123 @@ +package test + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "gitlab.com/gartnera/tunnel/client" + "gitlab.com/gartnera/tunnel/server" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" +) + +// generateCertificate generates a CA certificate, client certificate, and returns a tls.Config. +func generateCertificate(cn string) (*tls.Config, error) { + // Generate CA private key + caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + + // Create CA certificate template + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"My CA"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment, + BasicConstraintsValid: true, + IsCA: true, + } + + // Self-sign CA certificate + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, err + } + caCert, err := x509.ParseCertificate(caCertDER) + if err != nil { + return nil, err + } + + // Generate server private key + serverPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + + // Create server certificate template + serverTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{cn}, + } + + // Sign server certificate with CA + serverCertDER, err := x509.CreateCertificate(rand.Reader, serverTemplate, caCert, &serverPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, err + } + + // Create a tls.Config with the server certificate + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{serverCertDER}, + PrivateKey: serverPrivateKey, + }, + }, + RootCAs: x509.NewCertPool(), + } + + tlsConfig.RootCAs.AddCert(caCert) + + return tlsConfig, nil +} + +func TestE2E(t *testing.T) { + r := require.New(t) + logger := zaptest.NewLogger(t) + + server := server.New("localtest.me", logger) + + tlsConfig, err := generateCertificate("*.localtest.me") + r.NoError(err) + + go func() { + err = server.Start(":0", tlsConfig) + logger.Error("server start returned error", zap.Error(err)) + }() + time.Sleep(time.Millisecond * 50) + + _, port, _ := net.SplitHostPort(server.Addr().String()) + controlAddr := fmt.Sprintf("control.localtest.me:%s", port) + client := client.New( + controlAddr, + "", + "", + false, + false, + false, + "localhost:1234", + client.WithControlTLSConfig(tlsConfig), + ) + err = client.Start() + r.NoError(err, "client start") + + r.Contains(client.IssuedAddr(), "localtest.me") +}