Skip to content

Commit

Permalink
Merge pull request #25 from jaksi/update
Browse files Browse the repository at this point in the history
Update
  • Loading branch information
jaksi authored Aug 3, 2024
2 parents 7ef7bd5 + 113f03b commit 8e506f7
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 84 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/pr-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: golangci/golangci-lint-action@v3
- uses: golangci/golangci-lint-action@v6
test:
name: Test
strategy:
Expand All @@ -21,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
- uses: actions/setup-go@v5
with:
go-version: ^1.19
go-version: ^1.22
- run: go test -race -timeout 1m
44 changes: 31 additions & 13 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,36 @@ package sshutils

import (
"encoding/hex"
"errors"
"fmt"
"net"
"strconv"

"golang.org/x/crypto/ssh"
)

var (
ErrEstablishSSH = errors.New("failed to establish SSH connection")
ErrSendRequest = errors.New("failed to send request")
ErrChannelOpen = errors.New("failed to open channel")
)
type EstablishError struct {
err error
}

func (e EstablishError) Error() string {
return fmt.Sprintf("failed to establish SSH connection: %v", e.err)
}

type SendRequestError struct {
err error
}

func (e SendRequestError) Error() string {
return fmt.Sprintf("failed to send request: %v", e.err)
}

type ChannelOpenError struct {
err error
}

func (e ChannelOpenError) Error() string {
return fmt.Sprintf("failed to open channel: %v", e.err)
}

type Listener struct {
net.Listener
Expand All @@ -28,7 +46,7 @@ func (listener *Listener) Accept() (*Conn, error) {
sshConn, sshNewChannels, sshRequests, err := ssh.NewServerConn(conn, &listener.config)
if err != nil {
conn.Close()
return nil, fmt.Errorf("%w: %v", ErrEstablishSSH, err)
return nil, EstablishError{err}
}
return handleConn(sshConn, sshNewChannels, sshRequests), nil
}
Expand All @@ -51,7 +69,7 @@ type Conn struct {
func (conn *Conn) RawChannel(name string, payload []byte) (*Channel, error) {
sshChannel, sshRequests, err := conn.Conn.OpenChannel(name, payload)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrChannelOpen, err)
return nil, ChannelOpenError{err}
}
return handleChannel(sshChannel, sshRequests, conn, name), nil
}
Expand All @@ -67,7 +85,7 @@ func (conn *Conn) Channel(name string, payload Payload) (*Channel, error) {
func (conn *Conn) RawRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
accepted, reply, err := conn.SendRequest(name, wantReply, payload)
if err != nil {
return false, nil, fmt.Errorf("%w: %v", ErrSendRequest, err)
return false, nil, SendRequestError{err}
}
return accepted, reply, nil
}
Expand All @@ -92,7 +110,7 @@ func Dial(address string, config *ssh.ClientConfig) (*Conn, error) {
sshConn, sshNewChannels, sshRequests, err := ssh.NewClientConn(conn, address, config)
if err != nil {
conn.Close()
return nil, fmt.Errorf("%w: %v", ErrEstablishSSH, err)
return nil, EstablishError{err}
}
return handleConn(sshConn, sshNewChannels, sshRequests), nil
}
Expand Down Expand Up @@ -137,7 +155,7 @@ type NewChannel struct {
func (newChannel *NewChannel) AcceptChannel() (*Channel, error) {
sshChannel, sshRequests, err := newChannel.Accept()
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrChannelOpen, err)
return nil, ChannelOpenError{err}
}
return handleChannel(sshChannel, sshRequests, newChannel.conn, newChannel.ChannelType()), nil
}
Expand Down Expand Up @@ -177,7 +195,7 @@ func (channel *Channel) ConnMetadata() ssh.ConnMetadata {
func (channel *Channel) RawRequest(name string, wantReply bool, payload []byte) (bool, error) {
accepted, err := channel.SendRequest(name, wantReply, payload)
if err != nil {
return false, fmt.Errorf("%w: %v", ErrSendRequest, err)
return false, SendRequestError{err}
}
return accepted, nil
}
Expand All @@ -196,7 +214,7 @@ func (channel *Channel) String() string {

func handleChannel(sshChannel ssh.Channel, sshRequests <-chan *ssh.Request, conn *Conn, name string) *Channel {
requests := make(chan *ChannelRequest)
channel := &Channel{sshChannel, requests, fmt.Sprint(conn.nextChannelID), name, conn}
channel := &Channel{sshChannel, requests, strconv.FormatInt(int64(conn.nextChannelID), 10), name, conn}
go func() {
for request := range sshRequests {
requests <- &ChannelRequest{request, channel}
Expand Down
1 change: 1 addition & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"
"testing"

//nolint:depguard
"github.com/jaksi/sshutils"
"golang.org/x/crypto/ssh"
)
Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module github.com/jaksi/sshutils

go 1.19
go 1.22

require golang.org/x/crypto v0.5.0
require golang.org/x/crypto v0.25.0

require golang.org/x/sys v0.4.0 // indirect
require golang.org/x/sys v0.22.0 // indirect
11 changes: 6 additions & 5 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.4.0 h1:O7UWfv5+A2qiuulQk30kVinPoMtoIPeVaKLEgLpVkvg=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
45 changes: 31 additions & 14 deletions host_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"os"
Expand All @@ -20,11 +19,29 @@ const (
hostKeyFilePerms = 0o600
)

var (
ErrInvalidKey = errors.New("invalid key")
ErrInvalidKeyFile = errors.New("invalid key file")
ErrUnsupportedKeyType = errors.New("unsupported key type")
)
type InvalidKeyError struct {
err error
}

func (e InvalidKeyError) Error() string {
return fmt.Sprintf("invalid key: %v", e.err)
}

type InvalidKeyFileError struct {
err error
}

func (e InvalidKeyFileError) Error() string {
return fmt.Sprintf("invalid key file: %v", e.err)
}

type UnsupportedKeyTypeError struct {
t KeyType
}

func (e UnsupportedKeyTypeError) Error() string {
return fmt.Sprintf("unsupported key type: %v", e.t)
}

type KeyType int

Expand Down Expand Up @@ -59,7 +76,7 @@ func (key *HostKey) String() string {
func hostKeyFromKey(key interface{}) (*HostKey, error) {
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidKey, err)
return nil, InvalidKeyError{err}
}
return &HostKey{
Signer: signer,
Expand All @@ -78,42 +95,42 @@ func GenerateHostKey(rand io.Reader, t KeyType) (*HostKey, error) {
case Ed25519:
_, key, err = ed25519.GenerateKey(rand)
default:
return nil, fmt.Errorf("%w: %v", ErrUnsupportedKeyType, t)
return nil, UnsupportedKeyTypeError{t}
}
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidKey, err)
return nil, InvalidKeyError{err}
}
return hostKeyFromKey(key)
}

func LoadHostKey(fileName string) (*HostKey, error) {
keyBytes, err := os.ReadFile(fileName)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidKeyFile, err)
return nil, InvalidKeyFileError{err}
}
key, err := ssh.ParseRawPrivateKey(keyBytes)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidKeyFile, err)
return nil, InvalidKeyFileError{err}
}
return hostKeyFromKey(key)
}

func (key *HostKey) Save(fileName string) error {
file, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_EXCL, hostKeyFilePerms) //nolint:nosnakecase
if err != nil {
return fmt.Errorf("%w: %v", ErrInvalidKeyFile, err)
return InvalidKeyFileError{err}
}
defer file.Close()
keyBytes, err := x509.MarshalPKCS8PrivateKey(key.key)
if err != nil {
return fmt.Errorf("%w: %v", ErrInvalidKey, err)
return InvalidKeyError{err}
}
if _, err = file.Write(pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Headers: nil,
Bytes: keyBytes,
})); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidKeyFile, err)
return InvalidKeyFileError{err}
}
return nil
}
2 changes: 1 addition & 1 deletion host_keys_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"
"testing"

//nolint:depguard
"github.com/jaksi/sshutils"
"golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -50,7 +51,6 @@ func TestGenerateHostKey(t *testing.T) {
"unsupported key type: unknown type (42)",
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if tt.keyType.String() != tt.keyTypeString {
Expand Down
Loading

0 comments on commit 8e506f7

Please sign in to comment.