Skip to content

Commit

Permalink
[relay] Code cleaning (#3074)
Browse files Browse the repository at this point in the history
- Keep message byte processing in message.go file
- Add new unit tests
  • Loading branch information
pappz authored Jan 15, 2025
1 parent b34887a commit 6a6b527
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 66 deletions.
13 changes: 8 additions & 5 deletions relay/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("validate version: %w", err)
}

msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
return err
Expand All @@ -317,7 +317,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("unexpected message type")
}

addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil {
return err
}
Expand Down Expand Up @@ -348,24 +348,27 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Debugf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
c.bufPool.Put(bufPtr)
break
}

_, err := messages.ValidateVersion(buf[:n])
buf = buf[:n]

_, err := messages.ValidateVersion(buf)
if err != nil {
c.log.Errorf("failed to validate protocol version: %s", err)
c.bufPool.Put(bufPtr)
continue
}

msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf)
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr)
continue
}

if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
break
}
}
Expand Down
99 changes: 50 additions & 49 deletions relay/messages/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,26 @@ const (
MsgTypeAuth = 6
MsgTypeAuthResponse = 7

SizeOfVersionByte = 1
SizeOfMsgType = 1

SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType

sizeOfMagicByte = 4

headerSizeTransport = IDSize

// base size of the message
sizeOfVersionByte = 1
sizeOfMsgType = 1
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType

// auth message
sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + IDSize
offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth

// hello message
headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0

headerSizeAuth = sizeOfMagicByte + IDSize
headerSizeAuthResp = 0
// transport
headerSizeTransport = IDSize
offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
)

var (
Expand Down Expand Up @@ -73,7 +79,7 @@ func (m MsgType) String() string {

// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
version := int(msg[0])
Expand All @@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) {

// DetermineClientMessageType determines the message type from the first the message
func DetermineClientMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}

msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHello,
Expand All @@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {

// DetermineServerMessageType determines the message type from the first the message
func DetermineServerMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}

msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHelloResponse,
Expand All @@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}

msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)

copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)

msg = append(msg, peerID...)
msg = append(msg, additions...)
Expand All @@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}

return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
}

// Deprecated: Use MarshalAuthResponse instead.
Expand All @@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
Expand All @@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
return nil, ErrInvalidMessageLength
}
return msg, nil
Expand All @@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}

msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)

copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
copy(msg[sizeOfProtoHeader:], magicHeader)

msg = append(msg, peerID...)
msg = append(msg, authPayload...)
Expand All @@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {

// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth {
if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}

return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
}

// MarshalAuthResponse creates a response message to the auth.
Expand All @@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
Expand All @@ -243,69 +249,64 @@ func MarshalAuthResponse(address string) ([]byte, error) {

// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 {
if len(msg) < sizeOfProtoHeader+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
return string(msg[sizeOfProtoHeader:]), nil
}

// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.
func MarshalCloseMsg() []byte {
msg := make([]byte, SizeOfProtoHeader)

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeClose)

return msg
return []byte{
byte(CurrentProtocolVersion),
byte(MsgTypeClose),
}
}

// MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID.
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}

msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))

msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport)

copy(msg[SizeOfProtoHeader:], peerID)

copy(msg[sizeOfProtoHeader:], peerID)
msg = append(msg, payload...)

return msg, nil
}

// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
if len(buf) < headerSizeTransport {
if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength
}

return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
}

// UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < headerSizeTransport {
if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], nil
return buf[offsetTransportID:headerTotalSizeTransport], nil
}

// UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < len(peerID) {
if len(msg) < offsetTransportID+len(peerID) {
return ErrInvalidMessageLength
}
copy(msg, peerID)
copy(msg[offsetTransportID:], peerID)
return nil
}

Expand Down
Loading

0 comments on commit 6a6b527

Please sign in to comment.