Skip to content

Commit

Permalink
incorporate suggestions and feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
nitram509 committed Oct 1, 2024
1 parent 374c61e commit 4abc7d4
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 35 deletions.
3 changes: 1 addition & 2 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package keystore

import (
"crypto/rand"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -58,6 +57,6 @@ func TestPasswordBytes(t *testing.T) {

for _, tt := range table {
output := passwordBytes(tt.input)
assert.Truef(t, reflect.DeepEqual(output, tt.output), "convert password bytes '%v', '%v'", output, tt.output)
assert.Equal(t, tt.output, output, "convert password bytes")
}
}
4 changes: 2 additions & 2 deletions decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (d decoder) readPrivateKeyEntry(version uint32) (PrivateKeyEntry, error) {
chain = append(chain, cert)
}

creationDateTime := time.UnixMilli(int64(creationTimeStamp)) //nolint:all
creationDateTime := time.UnixMilli(int64(creationTimeStamp)) //nolint:gosec
privateKeyEntry := PrivateKeyEntry{
PrivateKey: encryptedPrivateKey,
CreationTime: creationDateTime,
Expand All @@ -149,7 +149,7 @@ func (d decoder) readTrustedCertificateEntry(version uint32) (TrustedCertificate
return TrustedCertificateEntry{}, fmt.Errorf("read certificate: %w", err)
}

creationDateTime := time.UnixMilli(int64(creationTimeStamp)) //nolint:all
creationDateTime := time.UnixMilli(int64(creationTimeStamp)) //nolint:gosec
trustedCertificateEntry := TrustedCertificateEntry{
CreationTime: creationDateTime,
Certificate: certificate,
Expand Down
29 changes: 14 additions & 15 deletions decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"io"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -72,14 +71,14 @@ func TestReadUint16(t *testing.T) {
}

number, err := d.readUint16()
assert.Truef(t, reflect.DeepEqual(err, tt.err), "invalid error '%v' '%v'", err, tt.err)
assert.Equal(t, tt.err, err)

if err == nil {
assert.Equal(t, tt.number, number)
}

hash := d.h.Sum(nil)
assert.Truef(t, reflect.DeepEqual(hash, tt.hash[:]), "invalid hash '%v' '%v'", hash, tt.hash)
assert.Equal(t, tt.hash[:], hash)
}
}

Expand Down Expand Up @@ -140,14 +139,14 @@ func TestReadUint32(t *testing.T) {
}

number, err := d.readUint32()
assert.Truef(t, reflect.DeepEqual(err, tt.err), "invalid error '%v' '%v'", err, tt.err)
assert.Equal(t, tt.err, err)

if err == nil {
assert.Equal(t, tt.number, number)
}

hash := d.h.Sum(nil)
assert.Truef(t, reflect.DeepEqual(hash, tt.hash[:]), "invalid hash '%v' '%v'", hash, tt.hash)
assert.Equal(t, tt.hash[:], hash)
}
}

Expand Down Expand Up @@ -212,14 +211,14 @@ func TestReadUint64(t *testing.T) {
}

number, err := d.readUint64()
assert.Truef(t, reflect.DeepEqual(err, tt.err), "invalid error '%v' '%v'", err, tt.err)
assert.Equal(t, tt.err, err)

if err == nil {
assert.Equal(t, tt.number, number)
}

hash := d.h.Sum(nil)
assert.Truef(t, reflect.DeepEqual(hash, tt.hash[:]), "invalid hash '%v' '%v'", hash, tt.hash)
assert.Equal(t, tt.hash[:], hash)
}
}

Expand Down Expand Up @@ -278,10 +277,10 @@ func TestReadBytes(t *testing.T) {
bts, err := d.readBytes(tt.readLen)
require.NoError(t, err)

assert.Truef(t, reflect.DeepEqual(bts, tt.bytes), "invalid bytes '%v' '%v'", bts, tt.bytes)
assert.Equal(t, tt.bytes, bts)

hash := d.h.Sum(nil)
assert.Truef(t, reflect.DeepEqual(hash, tt.hash[:]), "invalid hash '%v' '%v'", hash, tt.hash)
assert.Equal(t, tt.hash[:], hash)
}
}

Expand Down Expand Up @@ -321,7 +320,7 @@ func TestReadString(t *testing.T) {
})
str := "some string to read"
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(len(str))) //nolint:all
binary.BigEndian.PutUint16(buf, uint16(len(str))) //nolint:gosec
buf = append(buf, []byte(str)...)
table = append(table, item{
input: buf,
Expand All @@ -340,11 +339,11 @@ func TestReadString(t *testing.T) {
}

str, err := d.readString()
assert.Truef(t, reflect.DeepEqual(err, tt.err), "invalid error '%v' '%v'", err, tt.err)
assert.Equal(t, tt.err, err)
assert.Equal(t, tt.string, str)

hash := d.h.Sum(nil)
assert.Truef(t, reflect.DeepEqual(hash, tt.hash[:]), "invalid hash '%v' '%v'", hash, tt.hash)
assert.Equal(t, tt.hash[:], hash)
}
}

Expand Down Expand Up @@ -439,10 +438,10 @@ func TestReadCertificate(t *testing.T) {
}

cert, err := d.readCertificate(tt.version)
assert.Truef(t, reflect.DeepEqual(err, tt.err), "invalid error '%v' '%v'", err, tt.err)
assert.Truef(t, reflect.DeepEqual(cert, tt.cert), "invalid certificate '%v' '%v'", cert, tt.cert)
assert.Equal(t, tt.err, err)
assert.Equal(t, tt.cert, cert)

hash := d.h.Sum(nil)
assert.Truef(t, reflect.DeepEqual(hash, tt.hash[:]), "invalid hash '%v' '%v'", hash, tt.hash)
assert.Equal(t, tt.hash[:], hash)
}
}
4 changes: 2 additions & 2 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (e encoder) writePrivateKeyEntry(alias string, pke PrivateKeyEntry) error {
return fmt.Errorf("write alias: %w", err)
}

if err := e.writeUint64(uint64(pke.CreationTime.UnixMilli())); err != nil { //nolint:all
if err := e.writeUint64(uint64(pke.CreationTime.UnixMilli())); err != nil { //nolint:gosec
return fmt.Errorf("write creation timestamp: %w", err)
}

Expand Down Expand Up @@ -140,7 +140,7 @@ func (e encoder) writeTrustedCertificateEntry(alias string, tce TrustedCertifica
return fmt.Errorf("write alias: %w", err)
}

if err := e.writeUint64(uint64(tce.CreationTime.UnixMilli())); err != nil { //nolint:all
if err := e.writeUint64(uint64(tce.CreationTime.UnixMilli())); err != nil { //nolint:gosec
return fmt.Errorf("write creation timestamp: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (ks KeyStore) Store(w io.Writer, password []byte) error {
return fmt.Errorf("write version: %w", err)
}

if err := e.writeUint32(uint32(len(ks.m))); err != nil { //nolint:all
if err := e.writeUint32(uint32(len(ks.m))); err != nil { //nolint:gosec
return fmt.Errorf("write number of entries: %w", err)
}

Expand Down Expand Up @@ -192,7 +192,7 @@ func (ks KeyStore) Load(r io.Reader, password []byte) error {

computedDigest := d.h.Sum(nil)

actualDigest, err := d.readBytes(uint32(d.h.Size())) //nolint:all
actualDigest, err := d.readBytes(uint32(d.h.Size())) //nolint:gosec
if err != nil {
return fmt.Errorf("read digest: %w", err)
}
Expand Down
18 changes: 6 additions & 12 deletions keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package keystore
import (
"encoding/pem"
"os"
"reflect"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -48,17 +47,15 @@ func TestSetGetMethods(t *testing.T) {

pkeGet, err := ks.GetPrivateKeyEntry(pkeAlias, password)
require.NoError(t, err)
assert.Equal(t, pke, pkeGet)

chainGet, err := ks.GetPrivateKeyEntryCertificateChain(pkeAlias)
require.NoError(t, err)
assert.Equal(t, pke.CertificateChain, chainGet)

tceGet, err := ks.GetTrustedCertificateEntry(tceAlias)
require.NoError(t, err)

assert.True(t, reflect.DeepEqual(pke, pkeGet), "private key entries not equal")
assert.True(t, reflect.DeepEqual(pke.CertificateChain, chainGet),
"certificate chains of private key entries are not equal")
assert.True(t, reflect.DeepEqual(tce, tceGet), "private key entries not equal")
assert.Equal(t, tce, tceGet)

_, err = ks.GetPrivateKeyEntry(nonExistentAlias, password)
require.ErrorIs(t, err, ErrEntryNotFound)
Expand Down Expand Up @@ -139,14 +136,12 @@ func TestAliases(t *testing.T) {
require.NoError(t, err)

expectedAliases := []string{pkeAlias, tceAlias}

sort.Strings(expectedAliases)

actualAliases := ks.Aliases()

sort.Strings(actualAliases)

assert.True(t, reflect.DeepEqual(expectedAliases, actualAliases), "aliases must be equal")
assert.Equal(t, expectedAliases, actualAliases)
}

func TestLoad(t *testing.T) {
Expand Down Expand Up @@ -182,7 +177,7 @@ func TestLoad(t *testing.T) {

decodedPK, _ := pem.Decode(pkPEM)

assert.True(t, reflect.DeepEqual(actualPKE.PrivateKey, decodedPK.Bytes), "unexpected private key")
assert.Equal(t, decodedPK.Bytes, actualPKE.PrivateKey, "unexpected private key")
}

func TestLoadKeyPassword(t *testing.T) {
Expand Down Expand Up @@ -222,8 +217,7 @@ func TestLoadKeyPassword(t *testing.T) {

decodedPK, _ := pem.Decode(pkPEM)

assert.Truef(t, reflect.DeepEqual(actualPKE.PrivateKey, decodedPK.Bytes),
"unexpected private key %v \n %v", actualPKE.PrivateKey, decodedPK.Bytes)
assert.Equal(t, decodedPK.Bytes, actualPKE.PrivateKey, "unexpected private key")
}

func readPrivateKey(t *testing.T) []byte {
Expand Down

0 comments on commit 4abc7d4

Please sign in to comment.