diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index 06369696091..bfde82a6de7 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -55,8 +55,9 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, file := filepath.Join(dataDir, storeStr) db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ - Logger: logger.Default.LogMode(logger.Silent), - PrepareStmt: true, + Logger: logger.Default.LogMode(logger.Silent), + CreateBatchSize: 400, + PrepareStmt: true, }) if err != nil { return nil, err @@ -196,7 +197,8 @@ func (s *SqliteStore) SaveAccount(account *Account) error { result = tx. Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.OnConflict{UpdateAll: true}).Create(account) + Clauses(clause.OnConflict{UpdateAll: true}). + Create(account) if result.Error != nil { return result.Error } diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go index 31f9b8a5b32..8a1bcd10aeb 100644 --- a/management/server/sqlite_store_test.go +++ b/management/server/sqlite_store_test.go @@ -2,6 +2,9 @@ package server import ( "fmt" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "math/rand" "net" "net/netip" "path/filepath" @@ -33,6 +36,151 @@ func TestSqlite_NewStore(t *testing.T) { } } +func TestSqlite_SaveAccount_Large(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStore(t) + + account := newAccountWithId("account_id", "testuser", "") + groupALL, err := account.GetGroupAll() + if err != nil { + t.Fatal(err) + } + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + const numPerAccount = 2000 + for n := 0; n < numPerAccount; n++ { + netIP := randomIPv4() + peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + + peer := &nbpeer.Peer{ + ID: peerID, + Key: peerID, + SetupKey: "", + IP: netIP, + Name: peerID, + DNSLabel: peerID, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + account.Peers[peerID] = peer + group, _ := account.GetGroupAll() + group.Peers = append(group.Peers, peerID) + user := &User{ + Id: fmt.Sprintf("%s-user-%d", account.Id, n), + AccountID: account.Id, + } + account.Users[user.Id] = user + route := &route2.Route{ + ID: fmt.Sprintf("network-id-%d", n), + Description: "base route", + NetID: fmt.Sprintf("network-id-%d", n), + Network: netip.MustParsePrefix(netIP.String() + "/24"), + NetworkType: route2.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + Groups: []string{groupALL.ID}, + } + account.Routes[route.ID] = route + + group = &nbgroup.Group{ + ID: fmt.Sprintf("group-id-%d", n), + AccountID: account.Id, + Name: fmt.Sprintf("group-id-%d", n), + Issued: "api", + Peers: nil, + } + account.Groups[group.ID] = group + + nameserver := &nbdns.NameServerGroup{ + ID: fmt.Sprintf("nameserver-id-%d", n), + AccountID: account.Id, + Name: fmt.Sprintf("nameserver-id-%d", n), + Description: "", + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, + Groups: []string{group.ID}, + Primary: false, + Domains: nil, + Enabled: false, + SearchDomainsEnabled: false, + } + account.NameServerGroups[nameserver.ID] = nameserver + + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + } + + err = store.SaveAccount(account) + require.NoError(t, err) + + if len(store.GetAllAccounts()) != 1 { + t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") + } + + a, err := store.GetAccount(account.Id) + if a == nil { + t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) + } + + if a != nil && len(a.Policies) != 1 { + t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) + } + + if a != nil && len(a.Policies[0].Rules) != 1 { + t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) + return + } + + if a != nil && len(a.Peers) != numPerAccount { + t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d", + numPerAccount, len(a.Peers)) + return + } + + if a != nil && len(a.Users) != numPerAccount+1 { + t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d", + numPerAccount+1, len(a.Users)) + return + } + + if a != nil && len(a.Routes) != numPerAccount { + t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d", + numPerAccount, len(a.Routes)) + return + } + + if a != nil && len(a.NameServerGroups) != numPerAccount { + t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d", + numPerAccount, len(a.NameServerGroups)) + return + } + + if a != nil && len(a.NameServerGroups) != numPerAccount { + t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d", + numPerAccount, len(a.NameServerGroups)) + return + } + + if a != nil && len(a.SetupKeys) != numPerAccount+1 { + t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d", + numPerAccount+1, len(a.SetupKeys)) + return + } +} + +func randomIPv4() net.IP { + rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]byte, 4) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return net.IP(b) +} + func TestSqlite_SaveAccount(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet")