Skip to content

Commit

Permalink
Merge pull request #16 from kazeborja/addstomptls
Browse files Browse the repository at this point in the history
Adds TLS support for STOMP connections
  • Loading branch information
djw8605 authored Feb 7, 2023
2 parents 50a212f + a1ab484 commit 4ab70ab
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 9 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ Environment variables:
* SHOVELER_STOMP_PASSWORD
* SHOVELER_STOMP_URL
* SHOVELER_STOMP_TOPIC
* SHOVELER_STOMP_CERT
* SHOVELER_STOMP_CERT_KEY
* SHOVELER_METRICS_PORT
* SHOVELER_METRICS_ENABLE
* SHOVELER_MAP_ALL
Expand Down
10 changes: 10 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type Config struct {
StompPassword string
StompURL *url.URL
StompTopic string
StompCert string
StompCertKey string
}

func (c *Config) ReadConfig() {
Expand Down Expand Up @@ -80,6 +82,14 @@ func (c *Config) ReadConfig() {

c.StompTopic = viper.GetString("stomp.topic")
log.Debugln("STOMP Topic:", c.StompTopic)

// Get the STOMP cert
c.StompCert = viper.GetString("stomp.cert")
log.Debugln("STOMP CERT:", c.StompCert)

// Get the STOMP certkey
c.StompCertKey = viper.GetString("stomp.certkey")
log.Debugln("STOMP CERTKEY:", c.StompCertKey)
} else {
log.Panic("MQ option is not one of the allowed ones (amqp, stomp)")
}
Expand Down
13 changes: 13 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
# Select which protocol to use in order to connect to the MQ
# mq: amqp/stomp

# If using amqp protocol
amqp:
url: amqps://username:password@example.com/vhost
exchange: shoveled-xrd
topic:
token_location: /etc/xrootd-monitoring-shoveler/token

# If using stomp protocol
stomp:
user: username
password: password
url: messagebroker.org:port
topic: mytopic
cert: path/to/cert/file
certkey: path/to/certkey/file

listen:
port: 9993
ip: 0.0.0.0
Expand Down
42 changes: 37 additions & 5 deletions stomp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"crypto/tls"
"net/url"
"strings"
"time"
Expand All @@ -16,13 +17,15 @@ func StartStomp(config *Config, queue *ConfirmationQueue) {
stompPassword := config.StompPassword
stompUrl := config.StompURL
stompTopic := config.StompTopic
stompCert := config.StompCert
stompCertKey := config.StompCertKey

if !strings.HasPrefix(stompTopic, "/topic/") {
stompTopic = "/topic/" + stompTopic
}

stompSession := NewStompConnection(stompUser, stompPassword,
*stompUrl, stompTopic)
stompSession := GetNewStompConnection(stompUser, stompPassword,
*stompUrl, stompTopic, stompCert, stompCertKey)

// Message loop, constantly be dequeing and sending the message
// No fancy stuff needed
Expand All @@ -37,21 +40,39 @@ func StartStomp(config *Config, queue *ConfirmationQueue) {

}

func GetNewStompConnection(username string, password string,
stompUrl url.URL, topic string, stompCert string, stompCertKey string) *StompSession {
if stompCert != "" && stompCertKey != "" {
cert, err := tls.LoadX509KeyPair(stompCert, stompCertKey)
if err != nil {
log.Errorln("Failed to load certificate:", err)
}

return NewStompConnection(username, password,
stompUrl, topic, cert)
} else {
return NewStompConnection(username, password,
stompUrl, topic)
}
}

type StompSession struct {
username string
password string
stompUrl url.URL
topic string
cert []tls.Certificate
conn *stomp.Conn
}

func NewStompConnection(username string, password string,
stompUrl url.URL, topic string) *StompSession {
stompUrl url.URL, topic string, cert ...tls.Certificate) *StompSession {
session := StompSession{
username: username,
password: password,
stompUrl: stompUrl,
topic: topic,
cert: cert,
}

session.handleReconnect()
Expand All @@ -72,8 +93,7 @@ func (session *StompSession) handleReconnect() {
reconnectLoop:
for {
// Start a new session
conn, err := stomp.Dial("tcp", session.stompUrl.String(),
stomp.ConnOpt.Login(session.username, session.password))
conn, err := GetStompConnection(session)
if err == nil {
session.conn = conn
break reconnectLoop
Expand All @@ -84,6 +104,18 @@ reconnectLoop:
}
}

func GetStompConnection(session *StompSession) (*stomp.Conn, error) {
if session.cert != nil {
netConn, err := tls.Dial("tcp", session.stompUrl.String(), &tls.Config{Certificates: session.cert})
if err != nil {
log.Errorln("Failed to connect using TLS:", err.Error())
}
return stomp.Connect(netConn)
}
cfg := stomp.ConnOpt.Login(session.username, session.password)
return stomp.Dial("tcp", session.stompUrl.String(), cfg)
}

// publish will send the message to the stomp message bus
// It will also handle any error in sending by calling handleReconnect
func (session *StompSession) publish(msg []byte) {
Expand Down
7 changes: 3 additions & 4 deletions verify_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package main

import (
"math/rand"
"testing"
"github.com/stretchr/testify/assert"
"bytes"
"encoding/binary"
"github.com/stretchr/testify/assert"
"math/rand"
"testing"
"time"
)

Expand All @@ -26,7 +26,6 @@ func TestGoodVerify(t *testing.T) {

assert.True(t, verifyPacket(buf.Bytes()), "Failed to verify packet")


}

// TestBadVerify tests the validation if the packets are not good (random bits)
Expand Down

0 comments on commit 4ab70ab

Please sign in to comment.