Skip to content

Commit

Permalink
Move SSH related code to client package
Browse files Browse the repository at this point in the history
  • Loading branch information
arunvelsriram committed May 14, 2021
1 parent 177004b commit 6dbdc5a
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 183 deletions.
22 changes: 3 additions & 19 deletions pkg/client/sftp_client.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
package client

import (
"fmt"

"github.com/arunvelsriram/sftp-exporter/pkg/config"
"github.com/arunvelsriram/sftp-exporter/pkg/utils"
"github.com/kr/fs"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
Expand All @@ -22,7 +18,6 @@ type (
sftpClient struct {
*sftp.Client
sshClient *ssh.Client
config config.Config
}
)

Expand All @@ -41,18 +36,7 @@ func (s *sftpClient) Close() error {
}

func (s *sftpClient) Connect() (err error) {
addr := fmt.Sprintf("%s:%d", s.config.GetSFTPHost(), s.config.GetSFTPPort())
auth, err := utils.SSHAuthMethods(s.config.GetSFTPPass(), s.config.GetSFTPKey(), s.config.GetSFTPKeyPassphrase())
if err != nil {
log.Error("unable to get SSH auth methods")
return err
}
clientConfig := &ssh.ClientConfig{
User: s.config.GetSFTPUser(),
Auth: auth,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
s.sshClient, err = ssh.Dial("tcp", addr, clientConfig)
s.sshClient, err = NewSSHClient()
if err != nil {
return err
}
Expand All @@ -69,6 +53,6 @@ func (s *sftpClient) Connect() (err error) {
return nil
}

func NewSFTPClient(cfg config.Config) SFTPClient {
return &sftpClient{config: cfg}
func NewSFTPClient() SFTPClient {
return &sftpClient{}
}
85 changes: 85 additions & 0 deletions pkg/client/ssh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package client

import (
"encoding/base64"
"fmt"

"github.com/arunvelsriram/sftp-exporter/pkg/constants/viperkeys"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
"golang.org/x/crypto/ssh"
)

func parsePrivateKey(key, keyPassphrase []byte) (parsedKey ssh.Signer, err error) {
if len(keyPassphrase) > 0 {
log.Debug("key has passphrase")
parsedKey, err = ssh.ParsePrivateKeyWithPassphrase(key, keyPassphrase)
if err != nil {
log.Error("failed to parse key with passphrase")
return nil, err
}
return parsedKey, err
}

log.Debug("key has no passphrase")
parsedKey, err = ssh.ParsePrivateKey(key)
if err != nil {
log.Error("failed to parse key")
return nil, err
}
return parsedKey, err
}

func sshAuthMethods() ([]ssh.AuthMethod, error) {
password := viper.GetString(viperkeys.SFTPPassword)
encodedKey := viper.GetString(viperkeys.SFTPKey)
key, err := base64.StdEncoding.DecodeString(encodedKey)
if err != nil {
return nil, err
}
keyPassphrase := []byte(viper.GetString(viperkeys.SFTPKeyPassphrase))

if len(password) > 0 && len(key) > 0 {
log.Debug("key and password are provided")
parsedKey, err := parsePrivateKey(key, keyPassphrase)
if err != nil {
return nil, err
}
return []ssh.AuthMethod{
ssh.PublicKeys(parsedKey),
ssh.Password(password),
}, nil

} else if len(password) > 0 {
log.Debug("password is provided")
return []ssh.AuthMethod{
ssh.Password(password),
}, nil
} else if len(key) > 0 {
log.Debug("key is provided")
parsedKey, err := parsePrivateKey(key, keyPassphrase)
if err != nil {
return nil, err
}
return []ssh.AuthMethod{
ssh.PublicKeys(parsedKey),
}, nil
}

log.Debug("both password and key are not provided")
return nil, fmt.Errorf("failed to determine the SSH authentication methods to use")
}

func NewSSHClient() (*ssh.Client, error) {
addr := fmt.Sprintf("%s:%d", viper.GetString(viperkeys.SFTPHost), viper.GetInt(viperkeys.SFTPPort))
auth, err := sshAuthMethods()
if err != nil {
return nil, err
}
clientConfig := &ssh.ClientConfig{
User: viper.GetString(viperkeys.SFTPUser),
Auth: auth,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
return ssh.Dial("tcp", addr, clientConfig)
}
127 changes: 127 additions & 0 deletions pkg/client/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package client

import (
"encoding/base64"
"fmt"
"reflect"
"runtime"
"testing"

"github.com/arunvelsriram/sftp-exporter/pkg/constants/viperkeys"
"github.com/arunvelsriram/sftp-exporter/pkg/internal/mocks"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
)

func TestSSHAuthMethods(t *testing.T) {
tests := []struct {
desc string
password string
key string
keyPassphrase string
authMethods []ssh.AuthMethod
err error
}{
{
desc: "should return error when key with invalid encoding is provided",
password: "",
key: "key-invalid-encoded",
keyPassphrase: "",
authMethods: nil,
err: base64.CorruptInputError(3),
},
{
desc: "should return auth methods when password and key are given",
password: "password",
key: mocks.EncodedSSHKeyWithoutPassphrase(),
keyPassphrase: "",
authMethods: []ssh.AuthMethod{ssh.PublicKeys(), ssh.Password("password")},
err: nil,
},
{
desc: "should get auth method when password is given",
password: "password",
key: "",
keyPassphrase: "",
authMethods: []ssh.AuthMethod{ssh.Password("password")},
err: nil,
},
{
desc: "should return auth method when key is given",
password: "",
key: mocks.EncodedSSHKeyWithoutPassphrase(),
keyPassphrase: "",
authMethods: []ssh.AuthMethod{ssh.PublicKeys()},
err: nil,
},
{
desc: "should return error when both password and key are empty",
password: "",
key: "",
keyPassphrase: "",
authMethods: nil,
err: fmt.Errorf("failed to determine the SSH authentication methods to use"),
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
viper.Set(viperkeys.SFTPPassword, test.password)
viper.Set(viperkeys.SFTPKey, test.key)
viper.Set(viperkeys.SFTPKeyPassphrase, test.keyPassphrase)

authMethods, err := sshAuthMethods()

assert.Len(t, authMethods, len(test.authMethods))
for i, expectedAuthMethod := range test.authMethods {
expected := runtime.FuncForPC(reflect.ValueOf(expectedAuthMethod).Pointer()).Name()
actual := runtime.FuncForPC(reflect.ValueOf(authMethods[i]).Pointer()).Name()
assert.Equal(t, expected, actual)
}
assert.Equal(t, test.err, err)
})
}
}

func TestParsePrivateKey(t *testing.T) {
tests := []struct {
desc string
key []byte
keyPassphrase []byte
err error
}{
{
desc: "should parse key",
key: mocks.SSHKeyWithoutPassphrase(),
keyPassphrase: []byte{},
err: nil,
},
{
desc: "should parse encrypted key",
key: mocks.SSHKeyWithPassphrase(),
keyPassphrase: []byte(mocks.KeyPassphrase),
err: nil,
},
{
desc: "should return when invalid key is given",
key: []byte("invalid-key"),
keyPassphrase: []byte(""),
err: fmt.Errorf("ssh: no key found"),
},
{
desc: "should return error when wrong passphrase is given",
key: mocks.SSHKeyWithPassphrase(),
keyPassphrase: []byte("invalid-passphrase"),
err: fmt.Errorf("x509: decryption password incorrect"),
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
_, err := parsePrivateKey(test.key, test.keyPassphrase)

assert.Equal(t, test.err, err)
})
}
}
17 changes: 17 additions & 0 deletions pkg/internal/mocks/ssh_keys.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package mocks

import "encoding/base64"

func SSHKeyWithoutPassphrase() []byte {
return []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIJKQIBAAKCAgEA2riYo+9+vCXacILouh5uL7/chHheDFiFnFx1BOpiwuR5b1b/
Expand Down Expand Up @@ -54,6 +56,8 @@ aYQ5p5FzVtbes3BO1lu/nyShlWywlCRBpVCYoqcc3lD6X/2CM6doqNoxDom2mm3z
-----END RSA PRIVATE KEY-----`)
}

var KeyPassphrase = "password"

func SSHKeyWithPassphrase() []byte {
return []byte(`-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
Expand Down Expand Up @@ -115,3 +119,16 @@ func SSHKeyPassphrase() []byte {
return []byte("password")
}

func EncodedSSHKeyWithoutPassphrase() string {
key := SSHKeyWithoutPassphrase()
return base64.StdEncoding.EncodeToString(key)
}

func EncodedSSHKeyWithPassphrase() string {
key := SSHKeyWithPassphrase()
return base64.StdEncoding.EncodeToString(key)
}

func EncodedInvalidSSHKey() string {
return base64.StdEncoding.EncodeToString([]byte("invalid-ssh-key"))
}
2 changes: 1 addition & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
)

func Start() error {
sftpClient := client.NewSFTPClient(config.NewConfig())
sftpClient := client.NewSFTPClient()
sftpService := service.NewSFTPService(config.NewConfig(), sftpClient)
sftpCollector := collector.NewSFTPCollector(sftpService)
prometheus.MustRegister(sftpCollector)
Expand Down
53 changes: 0 additions & 53 deletions pkg/utils/utils.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package utils

import (
"fmt"
"strings"

log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

func IsEmpty(s string) bool {
Expand All @@ -21,52 +17,3 @@ func PanicIfErr(err error) {
panic(err.Error())
}
}

func SSHAuthMethods(pass string, key, keyPassphrase []byte) ([]ssh.AuthMethod, error) {
if len(key) > 0 && IsNotEmpty(pass) {
log.Debug("will be authenticating using key and password")
parsedKey, err := parsePrivateKey(key, keyPassphrase)
if err != nil {
return nil, err
}
return []ssh.AuthMethod{
ssh.PublicKeys(parsedKey),
ssh.Password(pass),
}, nil
} else if len(key) > 0 {
log.Debug("will be authenticating using key")
parsedKey, err := parsePrivateKey(key, keyPassphrase)
if err != nil {
return nil, err
}
return []ssh.AuthMethod{
ssh.PublicKeys(parsedKey),
}, nil
} else if IsNotEmpty(pass) {
log.Debug("will be authenticating using password")
return []ssh.AuthMethod{
ssh.Password(pass),
}, nil
}
return nil, fmt.Errorf("either one of password or key is required")
}

func parsePrivateKey(key, keyPassphrase []byte) (parsedKey ssh.Signer, err error) {
if len(keyPassphrase) > 0 {
log.Debug("key has passphrase")
parsedKey, err = ssh.ParsePrivateKeyWithPassphrase(key, keyPassphrase)
if err != nil {
log.Error("failed to parse key with passphrase")
return nil, err
}
return parsedKey, err
}

log.Debug("key has no passphrase")
parsedKey, err = ssh.ParsePrivateKey(key)
if err != nil {
log.Error("failed to parse key")
return nil, err
}
return parsedKey, err
}
Loading

0 comments on commit 6dbdc5a

Please sign in to comment.