Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for record not found when searching the store #1686

Merged
merged 3 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 56 additions & 12 deletions management/server/sqlite_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"errors"
"fmt"
"path/filepath"
"runtime"
Expand Down Expand Up @@ -255,7 +256,11 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer

result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
}

peer.Status = &peerStatus
Expand All @@ -267,7 +272,11 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee
var peer nbpeer.Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID)
if result.Error != nil {
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "peer %s not found", peer.ID)
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting peer from store")
}

peer.Location = peerWithLocation.Location
Expand All @@ -291,7 +300,11 @@ func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error)
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.Errorf("error when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

// TODO: rework to not call GetAccount
Expand All @@ -302,7 +315,11 @@ func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting setup key from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting setup key from store")
}

if key.AccountID == "" {
Expand All @@ -316,7 +333,11 @@ func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting token from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}

return token.ID, nil
Expand All @@ -326,7 +347,11 @@ func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, "id = ?", tokenID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting token from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if token.UserID == "" {
Expand Down Expand Up @@ -370,8 +395,11 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
Preload(clause.Associations).
First(&account, "id = ?", accountID)
if result.Error != nil {
log.Errorf("when getting account from the store: %s", result.Error)
return nil, status.Errorf(status.NotFound, "account not found")
log.Errorf("error when getting account from the store: %s", result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found")
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
Expand Down Expand Up @@ -431,7 +459,11 @@ func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, "id = ?", userID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting user from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if user.AccountID == "" {
Expand All @@ -445,7 +477,11 @@ func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if peer.AccountID == "" {
Expand All @@ -460,7 +496,11 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {

result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}

if peer.AccountID == "" {
Expand All @@ -476,7 +516,11 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time

result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
if result.Error != nil {
return status.Errorf(status.NotFound, "user %s not found", userID)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "user %s not found", userID)
}
log.Errorf("error when getting user from the store: %s", result.Error)
return status.Errorf(status.Internal, "issue getting user from store")
}

user.LastLogin = lastLogin
Expand Down
47 changes: 47 additions & 0 deletions management/server/sqlite_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/netbirdio/netbird/management/server/status"

nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/util"
)
Expand Down Expand Up @@ -174,6 +176,26 @@ func TestSqlite_DeleteAccount(t *testing.T) {

}

func TestSqlite_GetAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}

store := newSqliteStoreFromFile(t, "testdata/store.json")

id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"

account, err := store.GetAccount(id)
require.NoError(t, err)
require.Equal(t, id, account.Id, "account id should match")

_, err = store.GetAccount("non-existing-account")
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}

func TestSqlite_SavePeerStatus(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
Expand All @@ -188,6 +210,9 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")

// save new status of existing peer
account.Peers["testpeer"] = &nbpeer.Peer{
Expand Down Expand Up @@ -254,6 +279,13 @@ func TestSqlite_SavePeerLocation(t *testing.T) {

actual := account.Peers[peer.ID].Location
assert.Equal(t, peer.Location, actual)

peer.ID = "non-existing-peer"
err = store.SavePeerLocation(account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}

func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
Expand All @@ -271,6 +303,9 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {

_, err = store.GetAccountByPrivateDomain("missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}

func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
Expand All @@ -286,6 +321,12 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
token, err := store.GetTokenIDByHashedToken(hashed)
require.NoError(t, err)
require.Equal(t, id, token)

_, err = store.GetTokenIDByHashedToken("non-existing-hash")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}

func TestSqlite_GetUserByTokenID(t *testing.T) {
Expand All @@ -300,6 +341,12 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
user, err := store.GetUserByTokenID(id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)

_, err = store.GetUserByTokenID("non-existing-id")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}

func newSqliteStore(t *testing.T) *SqliteStore {
Expand Down
Loading