Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic testing/CI #1

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: ci

on:
push:
branches:
- master
pull_request:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: "go.mod"

- name: Run tests
run: go test -v ./...
56 changes: 41 additions & 15 deletions client/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ import (
gNet "gitlab.com/gartnera/golib/net"
)

type TunnelOpt func(t *Tunnel)

func WithControlTLSConfig(tlsConfig *tls.Config) TunnelOpt {
return func(t *Tunnel) {
t.controlTTLSconfig = tlsConfig
}
}

type Tunnel struct {
token string
server string
Expand All @@ -24,19 +32,31 @@ type Tunnel struct {
tlsSkipVerify bool
target string
httpTargetHostHeader bool
connectLock sync.Mutex

controlTTLSconfig *tls.Config
issuedAddr string
connectLock sync.Mutex
}

func NewTunnel(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string) *Tunnel {
return &Tunnel{
func New(server, hostname, token string, useTLS, tlsSkipVerify, httpTargetHostHeader bool, target string, opts ...TunnelOpt) *Tunnel {
t := &Tunnel{
server: server,
hostname: hostname,
token: token,
useTLS: useTLS,
tlsSkipVerify: tlsSkipVerify,
target: target,
httpTargetHostHeader: httpTargetHostHeader,
controlTTLSconfig: &tls.Config{},
}

for _, opt := range opts {
opt(t)
}
serverName, _, _ := net.SplitHostPort(t.server)
t.controlTTLSconfig.ServerName = serverName

return t
}

func (t *Tunnel) Start() error {
Expand All @@ -53,7 +73,7 @@ func (t *Tunnel) Start() error {
}

func (t *Tunnel) Shutdown() {
conn, err := tls.Dial("tcp", t.server, t.getControlTlsConfig())
conn, err := tls.Dial("tcp", t.server, t.controlTTLSconfig)
if err != nil {
panic(err)
}
Expand All @@ -64,20 +84,27 @@ func (t *Tunnel) Shutdown() {
}
}

func (t *Tunnel) getControlTlsConfig() *tls.Config {
serverName, _, _ := net.SplitHostPort(t.server)
return &tls.Config{
ServerName: serverName,
// IssuedAddr gets the address issued by the server
func (t *Tunnel) IssuedAddr() string {
return t.issuedAddr
}

// IssuedAddrHTTPS gets the address issued by the server with https prefix
func (t *Tunnel) IssuedAddrHTTPS() string {
addr := t.issuedAddr
if strings.HasSuffix(addr, ":443") {
addr = strings.TrimSuffix(addr, ":443")
}
return fmt.Sprintf("https://%s", addr)
}

func (t *Tunnel) stage1(print bool) (net.Conn, error) {
func (t *Tunnel) stage1(first bool) (net.Conn, error) {
var err error
var conn net.Conn
backoff := time.Second * 10
t.connectLock.Lock()
for {
conn, err = tls.Dial("tcp", t.server, t.getControlTlsConfig())
conn, err = tls.Dial("tcp", t.server, t.controlTTLSconfig)
if err == nil {
break
}
Expand All @@ -100,13 +127,12 @@ func (t *Tunnel) stage1(print bool) (net.Conn, error) {

}
res := string(buf[:n])
if print {
if first {
_, port, _ := net.SplitHostPort(t.server)
portPart := ""
if port != "443" {
portPart = fmt.Sprintf(":%s", port)
if port == "" {
port = "443"
}
fmt.Printf("URL: https://%s%s\n", res, portPart)
t.issuedAddr = fmt.Sprintf("%s:%s", res, port)
}
return conn, nil
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/tunnel-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ var rootCmd = &cobra.Command{
hostnameFqdn = strings.Join([]string{hostname, serverHostOnly}, ".")
}

tunnel := client.NewTunnel(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target)
tunnel := client.New(controlName, hostnameFqdn, token, useTLS, tlsSkipVerify, httpTargetHostHeader, target)
err := tunnel.Start()
if err != nil {
return fmt.Errorf("start %s: %w", controlName, err)
}
fmt.Printf("URL: %s\n", tunnel.IssuedAddrHTTPS())
tunnels = append(tunnels, tunnel)
}

Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
)

require (
github.com/benbjohnson/clock v1.1.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
go.uber.org/multierr v1.6.0 // indirect
)
Expand Down Expand Up @@ -107,7 +108,7 @@ require (
github.com/spf13/cobra v1.1.1
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.4.0 // indirect
github.com/stretchr/testify v1.7.5 // indirect
github.com/stretchr/testify v1.7.5
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.287 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/dnspod v1.0.287 // indirect
github.com/transip/gotransip/v6 v6.6.1 // indirect
Expand Down
9 changes: 9 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type Server struct {
controlName string
logger *zap.Logger

addr net.Addr

sync.RWMutex
hostnameMap map[string]*proxySession
secretMap map[string]*proxySession
Expand All @@ -44,6 +46,9 @@ func (s *Server) Start(laddr string, tlsConfig *tls.Config) error {
s.logger.Fatal("could not listen", zap.String("laddr", laddr), zap.Error(err))
}
defer ln.Close()
s.Lock()
s.addr = ln.Addr()
s.Unlock()

ctx := context.Background()
for {
Expand All @@ -63,6 +68,10 @@ func (s *Server) Start(laddr string, tlsConfig *tls.Config) error {
}
}

func (s *Server) Addr() net.Addr {
return s.addr
}

// getHostname generates a three word unique subdomain
// recursively call self until we get a unique name
func (s *Server) getHostname() string {
Expand Down
123 changes: 123 additions & 0 deletions test/e2e_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"testing"
"time"

"github.com/stretchr/testify/require"
"gitlab.com/gartnera/tunnel/client"
"gitlab.com/gartnera/tunnel/server"
"go.uber.org/zap"
"go.uber.org/zap/zaptest"
)

// generateCertificate generates a CA certificate, client certificate, and returns a tls.Config.
func generateCertificate(cn string) (*tls.Config, error) {
// Generate CA private key
caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}

// Create CA certificate template
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"My CA"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
IsCA: true,
}

// Self-sign CA certificate
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return nil, err
}
caCert, err := x509.ParseCertificate(caCertDER)
if err != nil {
return nil, err
}

// Generate server private key
serverPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}

// Create server certificate template
serverTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: []string{cn},
}

// Sign server certificate with CA
serverCertDER, err := x509.CreateCertificate(rand.Reader, serverTemplate, caCert, &serverPrivateKey.PublicKey, caPrivateKey)
if err != nil {
return nil, err
}

// Create a tls.Config with the server certificate
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{
{
Certificate: [][]byte{serverCertDER},
PrivateKey: serverPrivateKey,
},
},
RootCAs: x509.NewCertPool(),
}

tlsConfig.RootCAs.AddCert(caCert)

return tlsConfig, nil
}

func TestE2E(t *testing.T) {
r := require.New(t)
logger := zaptest.NewLogger(t)

server := server.New("localtest.me", logger)

tlsConfig, err := generateCertificate("*.localtest.me")
r.NoError(err)

go func() {
err = server.Start(":0", tlsConfig)
logger.Error("server start returned error", zap.Error(err))
}()
time.Sleep(time.Millisecond * 50)

_, port, _ := net.SplitHostPort(server.Addr().String())
controlAddr := fmt.Sprintf("control.localtest.me:%s", port)
client := client.New(
controlAddr,
"",
"",
false,
false,
false,
"localhost:1234",
client.WithControlTLSConfig(tlsConfig),
)
err = client.Start()
r.NoError(err, "client start")

r.Contains(client.IssuedAddr(), "localtest.me")
}