From 6dbdc5a5fba69e31a2fc26820ece1b045331bada Mon Sep 17 00:00:00 2001 From: Arunvel Sriram Date: Fri, 14 May 2021 19:46:42 +0530 Subject: [PATCH] Move SSH related code to client package --- pkg/client/sftp_client.go | 22 +----- pkg/client/ssh.go | 85 ++++++++++++++++++++++ pkg/client/ssh_test.go | 127 +++++++++++++++++++++++++++++++++ pkg/internal/mocks/ssh_keys.go | 17 +++++ pkg/server/server.go | 2 +- pkg/utils/utils.go | 53 -------------- pkg/utils/utils_test.go | 110 ---------------------------- 7 files changed, 233 insertions(+), 183 deletions(-) create mode 100644 pkg/client/ssh.go create mode 100644 pkg/client/ssh_test.go diff --git a/pkg/client/sftp_client.go b/pkg/client/sftp_client.go index 14c8f01..10b0e41 100644 --- a/pkg/client/sftp_client.go +++ b/pkg/client/sftp_client.go @@ -1,10 +1,6 @@ package client import ( - "fmt" - - "github.com/arunvelsriram/sftp-exporter/pkg/config" - "github.com/arunvelsriram/sftp-exporter/pkg/utils" "github.com/kr/fs" "github.com/pkg/sftp" log "github.com/sirupsen/logrus" @@ -22,7 +18,6 @@ type ( sftpClient struct { *sftp.Client sshClient *ssh.Client - config config.Config } ) @@ -41,18 +36,7 @@ func (s *sftpClient) Close() error { } func (s *sftpClient) Connect() (err error) { - addr := fmt.Sprintf("%s:%d", s.config.GetSFTPHost(), s.config.GetSFTPPort()) - auth, err := utils.SSHAuthMethods(s.config.GetSFTPPass(), s.config.GetSFTPKey(), s.config.GetSFTPKeyPassphrase()) - if err != nil { - log.Error("unable to get SSH auth methods") - return err - } - clientConfig := &ssh.ClientConfig{ - User: s.config.GetSFTPUser(), - Auth: auth, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - s.sshClient, err = ssh.Dial("tcp", addr, clientConfig) + s.sshClient, err = NewSSHClient() if err != nil { return err } @@ -69,6 +53,6 @@ func (s *sftpClient) Connect() (err error) { return nil } -func NewSFTPClient(cfg config.Config) SFTPClient { - return &sftpClient{config: cfg} +func NewSFTPClient() SFTPClient { + return &sftpClient{} } diff --git a/pkg/client/ssh.go b/pkg/client/ssh.go new file mode 100644 index 0000000..701c9ab --- /dev/null +++ b/pkg/client/ssh.go @@ -0,0 +1,85 @@ +package client + +import ( + "encoding/base64" + "fmt" + + "github.com/arunvelsriram/sftp-exporter/pkg/constants/viperkeys" + log "github.com/sirupsen/logrus" + "github.com/spf13/viper" + "golang.org/x/crypto/ssh" +) + +func parsePrivateKey(key, keyPassphrase []byte) (parsedKey ssh.Signer, err error) { + if len(keyPassphrase) > 0 { + log.Debug("key has passphrase") + parsedKey, err = ssh.ParsePrivateKeyWithPassphrase(key, keyPassphrase) + if err != nil { + log.Error("failed to parse key with passphrase") + return nil, err + } + return parsedKey, err + } + + log.Debug("key has no passphrase") + parsedKey, err = ssh.ParsePrivateKey(key) + if err != nil { + log.Error("failed to parse key") + return nil, err + } + return parsedKey, err +} + +func sshAuthMethods() ([]ssh.AuthMethod, error) { + password := viper.GetString(viperkeys.SFTPPassword) + encodedKey := viper.GetString(viperkeys.SFTPKey) + key, err := base64.StdEncoding.DecodeString(encodedKey) + if err != nil { + return nil, err + } + keyPassphrase := []byte(viper.GetString(viperkeys.SFTPKeyPassphrase)) + + if len(password) > 0 && len(key) > 0 { + log.Debug("key and password are provided") + parsedKey, err := parsePrivateKey(key, keyPassphrase) + if err != nil { + return nil, err + } + return []ssh.AuthMethod{ + ssh.PublicKeys(parsedKey), + ssh.Password(password), + }, nil + + } else if len(password) > 0 { + log.Debug("password is provided") + return []ssh.AuthMethod{ + ssh.Password(password), + }, nil + } else if len(key) > 0 { + log.Debug("key is provided") + parsedKey, err := parsePrivateKey(key, keyPassphrase) + if err != nil { + return nil, err + } + return []ssh.AuthMethod{ + ssh.PublicKeys(parsedKey), + }, nil + } + + log.Debug("both password and key are not provided") + return nil, fmt.Errorf("failed to determine the SSH authentication methods to use") +} + +func NewSSHClient() (*ssh.Client, error) { + addr := fmt.Sprintf("%s:%d", viper.GetString(viperkeys.SFTPHost), viper.GetInt(viperkeys.SFTPPort)) + auth, err := sshAuthMethods() + if err != nil { + return nil, err + } + clientConfig := &ssh.ClientConfig{ + User: viper.GetString(viperkeys.SFTPUser), + Auth: auth, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + return ssh.Dial("tcp", addr, clientConfig) +} diff --git a/pkg/client/ssh_test.go b/pkg/client/ssh_test.go new file mode 100644 index 0000000..e433a2e --- /dev/null +++ b/pkg/client/ssh_test.go @@ -0,0 +1,127 @@ +package client + +import ( + "encoding/base64" + "fmt" + "reflect" + "runtime" + "testing" + + "github.com/arunvelsriram/sftp-exporter/pkg/constants/viperkeys" + "github.com/arunvelsriram/sftp-exporter/pkg/internal/mocks" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" +) + +func TestSSHAuthMethods(t *testing.T) { + tests := []struct { + desc string + password string + key string + keyPassphrase string + authMethods []ssh.AuthMethod + err error + }{ + { + desc: "should return error when key with invalid encoding is provided", + password: "", + key: "key-invalid-encoded", + keyPassphrase: "", + authMethods: nil, + err: base64.CorruptInputError(3), + }, + { + desc: "should return auth methods when password and key are given", + password: "password", + key: mocks.EncodedSSHKeyWithoutPassphrase(), + keyPassphrase: "", + authMethods: []ssh.AuthMethod{ssh.PublicKeys(), ssh.Password("password")}, + err: nil, + }, + { + desc: "should get auth method when password is given", + password: "password", + key: "", + keyPassphrase: "", + authMethods: []ssh.AuthMethod{ssh.Password("password")}, + err: nil, + }, + { + desc: "should return auth method when key is given", + password: "", + key: mocks.EncodedSSHKeyWithoutPassphrase(), + keyPassphrase: "", + authMethods: []ssh.AuthMethod{ssh.PublicKeys()}, + err: nil, + }, + { + desc: "should return error when both password and key are empty", + password: "", + key: "", + keyPassphrase: "", + authMethods: nil, + err: fmt.Errorf("failed to determine the SSH authentication methods to use"), + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + viper.Set(viperkeys.SFTPPassword, test.password) + viper.Set(viperkeys.SFTPKey, test.key) + viper.Set(viperkeys.SFTPKeyPassphrase, test.keyPassphrase) + + authMethods, err := sshAuthMethods() + + assert.Len(t, authMethods, len(test.authMethods)) + for i, expectedAuthMethod := range test.authMethods { + expected := runtime.FuncForPC(reflect.ValueOf(expectedAuthMethod).Pointer()).Name() + actual := runtime.FuncForPC(reflect.ValueOf(authMethods[i]).Pointer()).Name() + assert.Equal(t, expected, actual) + } + assert.Equal(t, test.err, err) + }) + } +} + +func TestParsePrivateKey(t *testing.T) { + tests := []struct { + desc string + key []byte + keyPassphrase []byte + err error + }{ + { + desc: "should parse key", + key: mocks.SSHKeyWithoutPassphrase(), + keyPassphrase: []byte{}, + err: nil, + }, + { + desc: "should parse encrypted key", + key: mocks.SSHKeyWithPassphrase(), + keyPassphrase: []byte(mocks.KeyPassphrase), + err: nil, + }, + { + desc: "should return when invalid key is given", + key: []byte("invalid-key"), + keyPassphrase: []byte(""), + err: fmt.Errorf("ssh: no key found"), + }, + { + desc: "should return error when wrong passphrase is given", + key: mocks.SSHKeyWithPassphrase(), + keyPassphrase: []byte("invalid-passphrase"), + err: fmt.Errorf("x509: decryption password incorrect"), + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + _, err := parsePrivateKey(test.key, test.keyPassphrase) + + assert.Equal(t, test.err, err) + }) + } +} diff --git a/pkg/internal/mocks/ssh_keys.go b/pkg/internal/mocks/ssh_keys.go index b95e832..84f79c5 100644 --- a/pkg/internal/mocks/ssh_keys.go +++ b/pkg/internal/mocks/ssh_keys.go @@ -1,5 +1,7 @@ package mocks +import "encoding/base64" + func SSHKeyWithoutPassphrase() []byte { return []byte(`-----BEGIN RSA PRIVATE KEY----- MIIJKQIBAAKCAgEA2riYo+9+vCXacILouh5uL7/chHheDFiFnFx1BOpiwuR5b1b/ @@ -54,6 +56,8 @@ aYQ5p5FzVtbes3BO1lu/nyShlWywlCRBpVCYoqcc3lD6X/2CM6doqNoxDom2mm3z -----END RSA PRIVATE KEY-----`) } +var KeyPassphrase = "password" + func SSHKeyWithPassphrase() []byte { return []byte(`-----BEGIN RSA PRIVATE KEY----- Proc-Type: 4,ENCRYPTED @@ -115,3 +119,16 @@ func SSHKeyPassphrase() []byte { return []byte("password") } +func EncodedSSHKeyWithoutPassphrase() string { + key := SSHKeyWithoutPassphrase() + return base64.StdEncoding.EncodeToString(key) +} + +func EncodedSSHKeyWithPassphrase() string { + key := SSHKeyWithPassphrase() + return base64.StdEncoding.EncodeToString(key) +} + +func EncodedInvalidSSHKey() string { + return base64.StdEncoding.EncodeToString([]byte("invalid-ssh-key")) +} diff --git a/pkg/server/server.go b/pkg/server/server.go index cc69f4f..a11b912 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -18,7 +18,7 @@ import ( ) func Start() error { - sftpClient := client.NewSFTPClient(config.NewConfig()) + sftpClient := client.NewSFTPClient() sftpService := service.NewSFTPService(config.NewConfig(), sftpClient) sftpCollector := collector.NewSFTPCollector(sftpService) prometheus.MustRegister(sftpCollector) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 0f44ec8..09e6923 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,11 +1,7 @@ package utils import ( - "fmt" "strings" - - log "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh" ) func IsEmpty(s string) bool { @@ -21,52 +17,3 @@ func PanicIfErr(err error) { panic(err.Error()) } } - -func SSHAuthMethods(pass string, key, keyPassphrase []byte) ([]ssh.AuthMethod, error) { - if len(key) > 0 && IsNotEmpty(pass) { - log.Debug("will be authenticating using key and password") - parsedKey, err := parsePrivateKey(key, keyPassphrase) - if err != nil { - return nil, err - } - return []ssh.AuthMethod{ - ssh.PublicKeys(parsedKey), - ssh.Password(pass), - }, nil - } else if len(key) > 0 { - log.Debug("will be authenticating using key") - parsedKey, err := parsePrivateKey(key, keyPassphrase) - if err != nil { - return nil, err - } - return []ssh.AuthMethod{ - ssh.PublicKeys(parsedKey), - }, nil - } else if IsNotEmpty(pass) { - log.Debug("will be authenticating using password") - return []ssh.AuthMethod{ - ssh.Password(pass), - }, nil - } - return nil, fmt.Errorf("either one of password or key is required") -} - -func parsePrivateKey(key, keyPassphrase []byte) (parsedKey ssh.Signer, err error) { - if len(keyPassphrase) > 0 { - log.Debug("key has passphrase") - parsedKey, err = ssh.ParsePrivateKeyWithPassphrase(key, keyPassphrase) - if err != nil { - log.Error("failed to parse key with passphrase") - return nil, err - } - return parsedKey, err - } - - log.Debug("key has no passphrase") - parsedKey, err = ssh.ParsePrivateKey(key) - if err != nil { - log.Error("failed to parse key") - return nil, err - } - return parsedKey, err -} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 0cbfb30..6fb89fd 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -2,14 +2,10 @@ package utils_test import ( "fmt" - "reflect" - "runtime" "testing" - "github.com/arunvelsriram/sftp-exporter/pkg/internal/mocks" "github.com/arunvelsriram/sftp-exporter/pkg/utils" "github.com/stretchr/testify/assert" - "golang.org/x/crypto/ssh" ) func TestIsEmpty(t *testing.T) { @@ -75,109 +71,3 @@ func TestPanicIfErrShouldPanicForErr(t *testing.T) { func TestPanicIfErrShouldNotPanicWhenErrIsNil(t *testing.T) { assert.NotPanics(t, func() { utils.PanicIfErr(nil) }) } - -func TestSSHAuthMethods(t *testing.T) { - tests := []struct { - desc string - pass string - key []byte - keyPassphrase []byte - authMethods []ssh.AuthMethod - err error - }{ - { - desc: "should return error when both pass and key are empty", - pass: "", - key: []byte{}, - keyPassphrase: []byte{}, - authMethods: nil, - err: fmt.Errorf("either one of password or key is required"), - }, - { - desc: "should return pass and key auth methods when pass and key are provided", - pass: "password", - key: mocks.SSHKeyWithoutPassphrase(), - keyPassphrase: []byte{}, - authMethods: []ssh.AuthMethod{ssh.PublicKeys(), ssh.Password("password")}, - err: nil, - }, - { - desc: "should return pass and key auth methods when pass and encrypted key are provided", - pass: "password", - key: mocks.SSHKeyWithPassphrase(), - keyPassphrase: mocks.SSHKeyPassphrase(), - authMethods: []ssh.AuthMethod{ssh.PublicKeys(), ssh.Password("password")}, - err: nil, - }, - { - desc: "should return error when pass and invalid key are provided", - pass: "password", - key: []byte("invalid-key"), - keyPassphrase: []byte{}, - authMethods: nil, - err: fmt.Errorf("ssh: no key found"), - }, - { - desc: "should return error when pass and wrong key passphrase are provided", - pass: "password", - key: mocks.SSHKeyWithPassphrase(), - keyPassphrase: []byte("wrong-passphrase"), - authMethods: nil, - err: fmt.Errorf("x509: decryption password incorrect"), - }, - { - desc: "should return key auth method when only key is provided", - pass: "", - key: mocks.SSHKeyWithoutPassphrase(), - keyPassphrase: []byte{}, - authMethods: []ssh.AuthMethod{ssh.PublicKeys()}, - err: nil, - }, - { - desc: "should return key auth method when only encrypted key is provided", - pass: "", - key: mocks.SSHKeyWithPassphrase(), - keyPassphrase: mocks.SSHKeyPassphrase(), - authMethods: []ssh.AuthMethod{ssh.PublicKeys()}, - err: nil, - }, - { - desc: "should return error when key is invalid", - pass: "", - key: []byte("invalid-key"), - keyPassphrase: []byte{}, - authMethods: nil, - err: fmt.Errorf("ssh: no key found"), - }, - { - desc: "should return error when key passphrase is wrong", - pass: "", - key: mocks.SSHKeyWithPassphrase(), - keyPassphrase: []byte("wrong-passphrase"), - authMethods: nil, - err: fmt.Errorf("x509: decryption password incorrect"), - }, - { - desc: "should return pass auth method when only pass is provided", - pass: "password", - key: []byte{}, - keyPassphrase: []byte{}, - authMethods: []ssh.AuthMethod{ssh.Password("password")}, - err: nil, - }, - } - - for _, test := range tests { - t.Run(test.desc, func(t *testing.T) { - authMethods, err := utils.SSHAuthMethods(test.pass, test.key, test.keyPassphrase) - - assert.Len(t, authMethods, len(test.authMethods)) - for i, expectedAuthMethod := range test.authMethods { - expected := runtime.FuncForPC(reflect.ValueOf(expectedAuthMethod).Pointer()).Name() - actual := runtime.FuncForPC(reflect.ValueOf(authMethods[i]).Pointer()).Name() - assert.Equal(t, expected, actual) - } - assert.Equal(t, test.err, err) - }) - } -}