Skip to content

Commit

Permalink
Add basic testing/CI
Browse files Browse the repository at this point in the history
  • Loading branch information
gartnera committed Nov 27, 2024
1 parent b44556f commit 9732e14
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 17 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: ci

on:
push:
branches:
- master
pull_request:

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v5

- name: Run tests
run: go test -v ./...
56 changes: 41 additions & 15 deletions client/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ import (
gNet "gitlab.com/gartnera/golib/net"
)

type TunnelOpt func(t *Tunnel)

func WithControlTLSConfig(tlsConfig *tls.Config) TunnelOpt {
return func(t *Tunnel) {
t.controlTTLSconfig = tlsConfig
}
}

type Tunnel struct {
token string
server string
Expand All @@ -24,19 +32,31 @@ type Tunnel struct {
tlsSkipVerify bool
target string
httpTargetHostHeader bool
connectLock sync.Mutex

controlTTLSconfig *tls.Config
issuedAddr string
connectLock sync.Mutex
}

func NewTunnel(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string) *Tunnel {
return &Tunnel{
func New(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string, opts ...TunnelOpt) *Tunnel {
t := &Tunnel{
server: server,
hostname: hostname,
token: token,
useTLS: useTLS,
tlsSkipVerify: tlsSkipVerify,
target: target,
httpTargetHostHeader: httpTargetHostHeader,
controlTTLSconfig: &tls.Config{},
}

for _, opt := range opts {
opt(t)
}
serverName, _, _ := net.SplitHostPort(t.server)
t.controlTTLSconfig.ServerName = serverName

return t
}

func (t *Tunnel) Start() error {
Expand All @@ -53,7 +73,7 @@ func (t *Tunnel) Start() error {
}

func (t *Tunnel) Shutdown() {
conn, err := tls.Dial("tcp", t.server, t.getControlTlsConfig())
conn, err := tls.Dial("tcp", t.server, t.controlTTLSconfig)
if err != nil {
panic(err)
}
Expand All @@ -64,20 +84,27 @@ func (t *Tunnel) Shutdown() {
}
}

func (t *Tunnel) getControlTlsConfig() *tls.Config {
serverName, _, _ := net.SplitHostPort(t.server)
return &tls.Config{
ServerName: serverName,
// IssuedAddr gets the address issued by the server
func (t *Tunnel) IssuedAddr() string {
return t.issuedAddr
}

// IssuedAddrHTTPS gets the address issued by the server with https prefix
func (t *Tunnel) IssuedAddrHTTPS() string {
addr := t.issuedAddr
if strings.HasSuffix(addr, ":443") {
addr = strings.TrimSuffix(addr, ":443")
}
return fmt.Sprintf("https://%s", addr)
}

func (t *Tunnel) stage1(print bool) (net.Conn, error) {
func (t *Tunnel) stage1(first bool) (net.Conn, error) {
var err error
var conn net.Conn
backoff := time.Second * 10
t.connectLock.Lock()
for {
conn, err = tls.Dial("tcp", t.server, t.getControlTlsConfig())
conn, err = tls.Dial("tcp", t.server, t.controlTTLSconfig)
if err == nil {
break
}
Expand All @@ -100,13 +127,12 @@ func (t *Tunnel) stage1(print bool) (net.Conn, error) {

}
res := string(buf[:n])
if print {
if first {
_, port, _ := net.SplitHostPort(t.server)
portPart := ""
if port != "443" {
portPart = fmt.Sprintf(":%s", port)
if port == "" {
port = "443"
}
fmt.Printf("URL: https://%s%s\n", res, portPart)
t.issuedAddr = fmt.Sprintf("%s:%s", res, port)
}
return conn, nil
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/tunnel-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ var rootCmd = &cobra.Command{
hostnameFqdn = strings.Join([]string{hostname, serverHostOnly}, ".")
}

tunnel := client.NewTunnel(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target)
tunnel := client.New(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target)
err := tunnel.Start()
if err != nil {
return fmt.Errorf("start %s: %w", controlName, err)
}
fmt.Printf("URL: %s\n", tunnel.IssuedAddrHTTPS())
tunnels = append(tunnels, tunnel)
}

Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
)

require (
github.com/benbjohnson/clock v1.1.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
)
Expand Down Expand Up @@ -107,7 +108,7 @@ require (
github.com/spf13/cobra v1.1.1
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.4.0 // indirect
github.com/stretchr/testify v1.7.5 // indirect
github.com/stretchr/testify v1.7.5
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.287 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.287 // indirect
github.com/transip/gotransip/v6 v6.6.1 // indirect
Expand Down
9 changes: 9 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type Server struct {
controlName string
logger *zap.Logger

addr net.Addr

sync.RWMutex
hostnameMap map[string]*proxySession
secretMap map[string]*proxySession
Expand All @@ -44,6 +46,9 @@ func (s *Server) Start(laddr string, tlsConfig *tls.Config) error {
s.logger.Fatal("could not listen", zap.String("laddr", laddr), zap.Error(err))
}
defer ln.Close()
s.Lock()
s.addr = ln.Addr()
s.Unlock()

ctx := context.Background()
for {
Expand All @@ -63,6 +68,10 @@ func (s *Server) Start(laddr string, tlsConfig *tls.Config) error {
}
}

func (s *Server) Addr() net.Addr {
return s.addr
}

// getHostname generates a three word unique subdomain
// recursively call self until we get a unique name
func (s *Server) getHostname() string {
Expand Down
123 changes: 123 additions & 0 deletions test/e2e_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"testing"
"time"

"github.com/stretchr/testify/require"
"gitlab.com/gartnera/tunnel/client"
"gitlab.com/gartnera/tunnel/server"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)

// generateCertificate generates a CA certificate, client certificate, and returns a tls.Config.
func generateCertificate(cn string) (*tls.Config, error) {
// Generate CA private key
caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}

// Create CA certificate template
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"My CA"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
IsCA: true,
}

// Self-sign CA certificate
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return nil, err
}
caCert, err := x509.ParseCertificate(caCertDER)
if err != nil {
return nil, err
}

// Generate server private key
serverPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}

// Create server certificate template
serverTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{cn},
}

// Sign server certificate with CA
serverCertDER, err := x509.CreateCertificate(rand.Reader, serverTemplate, caCert, &serverPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return nil, err
}

// Create a tls.Config with the server certificate
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{serverCertDER},
PrivateKey: serverPrivateKey,
},
},
RootCAs: x509.NewCertPool(),
}

tlsConfig.RootCAs.AddCert(caCert)

return tlsConfig, nil
}

func TestE2E(t *testing.T) {
r := require.New(t)
logger := zaptest.NewLogger(t)

server := server.New("localtest.me", logger)

tlsConfig, err := generateCertificate("*.localtest.me")
r.NoError(err)

go func() {
err = server.Start(":0", tlsConfig)
logger.Error("server start returned error", zap.Error(err))
}()
time.Sleep(time.Millisecond * 50)

_, port, _ := net.SplitHostPort(server.Addr().String())
controlAddr := fmt.Sprintf("control.localtest.me:%s", port)
client := client.New(
controlAddr,
"",
"",
false,
false,
false,
"localhost:1234",
client.WithControlTLSConfig(tlsConfig),
)
err = client.Start()
r.NoError(err, "client start")

r.Contains(client.IssuedAddr(), "localtest.me")
}

0 comments on commit 9732e14

Please sign in to comment.