Skip to content

Commit

Permalink
Simplify MultiStorage methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson committed Jun 26, 2021
1 parent a9d0462 commit 3907f14
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 117 deletions.
45 changes: 23 additions & 22 deletions storage/allmulti/allmulti.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,36 @@ func New(logger log.Logger, stores ...storage.AllStorage) *MultiAllStorage {
return &MultiAllStorage{logger: logger, stores: stores}
}

func (ms *MultiAllStorage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error {
finalErr := ms.stores[0].StoreAuthenticate(r, msg)
type storageErrorer func(storage.AllStorage) error

func (ms *MultiAllStorage) runAndLogOthers(storageCallback storageErrorer) {
for n, storage := range ms.stores[1:] {
if err := storage.StoreAuthenticate(r, msg); err != nil {
ms.logger.Info("method", "StoreAuthenticate", "storage", n+1, "err", err)
continue
if err := storageCallback(storage); err != nil {
ms.logger.Info("msg", n+1, "err", err)
}
}
return finalErr
}

func (ms *MultiAllStorage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error {
err := ms.stores[0].StoreAuthenticate(r, msg)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.StoreAuthenticate(r, msg)
})
return err
}

func (ms *MultiAllStorage) StoreTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error {
finalErr := ms.stores[0].StoreTokenUpdate(r, msg)
for n, storage := range ms.stores[1:] {
if err := storage.StoreTokenUpdate(r, msg); err != nil {
ms.logger.Info("method", "StoreTokenUpdate", "storage", n+1, "err", err)
continue
}
}
return finalErr
err := ms.stores[0].StoreTokenUpdate(r, msg)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.StoreTokenUpdate(r, msg)
})
return err
}

func (ms *MultiAllStorage) Disable(r *mdm.Request) error {
finalErr := ms.stores[0].Disable(r)
for n, storage := range ms.stores[1:] {
if err := storage.Disable(r); err != nil {
ms.logger.Info("method", "Disable", "storage", n+1, "err", err)
continue
}
}
return finalErr
err := ms.stores[0].Disable(r)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.Disable(r)
})
return err
}
24 changes: 10 additions & 14 deletions storage/allmulti/bstoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,22 @@ package allmulti

import (
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)

func (ms *MultiAllStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error {
finalErr := ms.stores[0].StoreBootstrapToken(r, msg)
for n, storage := range ms.stores[1:] {
if err := storage.StoreBootstrapToken(r, msg); err != nil {
ms.logger.Info("method", "StoreBootstrapToken", "storage", n+1, "err", err)
continue
}
}
return finalErr
err := ms.stores[0].StoreBootstrapToken(r, msg)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.StoreBootstrapToken(r, msg)
})
return err
}

func (ms *MultiAllStorage) RetrieveBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) {
finalToken, finalErr := ms.stores[0].RetrieveBootstrapToken(r, msg)
for n, storage := range ms.stores[1:] {
if _, err := storage.RetrieveBootstrapToken(r, msg); err != nil {
ms.logger.Info("method", "RetrieveBootstrapToken", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.RetrieveBootstrapToken(r, msg)
return err
})
return finalToken, finalErr
}
48 changes: 21 additions & 27 deletions storage/allmulti/certauth.go
Original file line number Diff line number Diff line change
@@ -1,47 +1,41 @@
package allmulti

import "github.com/micromdm/nanomdm/mdm"
import (
"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)

func (ms *MultiAllStorage) HasCertHash(r *mdm.Request, hash string) (bool, error) {
hasFinal, finalErr := ms.stores[0].HasCertHash(r, hash)
for n, storage := range ms.stores[1:] {
if _, err := storage.HasCertHash(r, hash); err != nil {
ms.logger.Info("method", "HasCertHash", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.HasCertHash(r, hash)
return err
})
return hasFinal, finalErr
}

func (ms *MultiAllStorage) EnrollmentHasCertHash(r *mdm.Request, hash string) (bool, error) {
hasFinal, finalErr := ms.stores[0].EnrollmentHasCertHash(r, hash)
for n, storage := range ms.stores[1:] {
if _, err := storage.EnrollmentHasCertHash(r, hash); err != nil {
ms.logger.Info("method", "EnrollmentHasCertHash", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.EnrollmentHasCertHash(r, hash)
return err
})
return hasFinal, finalErr
}

func (ms *MultiAllStorage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) {
isAssocFinal, finalErr := ms.stores[0].IsCertHashAssociated(r, hash)
for n, storage := range ms.stores[1:] {
if _, err := storage.IsCertHashAssociated(r, hash); err != nil {
ms.logger.Info("method", "IsCertHashAssociated", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.IsCertHashAssociated(r, hash)
return err
})
return isAssocFinal, finalErr
}

func (ms *MultiAllStorage) AssociateCertHash(r *mdm.Request, hash string) error {
finalErr := ms.stores[0].AssociateCertHash(r, hash)
for n, storage := range ms.stores[1:] {
if err := storage.AssociateCertHash(r, hash); err != nil {
ms.logger.Info("method", "AssociateCertHash", "storage", n+1, "err", err)
continue
}
}
return finalErr
err := ms.stores[0].AssociateCertHash(r, hash)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.AssociateCertHash(r, hash)
})
return err
}
11 changes: 5 additions & 6 deletions storage/allmulti/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import (
"context"

"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)

func (ms *MultiAllStorage) RetrievePushInfo(ctx context.Context, ids []string) (map[string]*mdm.Push, error) {
finalMap, finalErr := ms.stores[0].RetrievePushInfo(ctx, ids)
for n, storage := range ms.stores[1:] {
if _, err := storage.RetrievePushInfo(ctx, ids); err != nil {
ms.logger.Info("method", "RetrievePushInfo", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.RetrievePushInfo(ctx, ids)
return err
})
return finalMap, finalErr
}
36 changes: 16 additions & 20 deletions storage/allmulti/pushcert.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,33 @@ package allmulti
import (
"context"
"crypto/tls"

"github.com/micromdm/nanomdm/storage"
)

func (ms *MultiAllStorage) IsPushCertStale(ctx context.Context, topic string, staleToken string) (bool, error) {
finalStale, finalErr := ms.stores[0].IsPushCertStale(ctx, topic, staleToken)
for n, storage := range ms.stores[1:] {
if _, err := storage.IsPushCertStale(ctx, topic, staleToken); err != nil {
ms.logger.Info("method", "IsPushCertStale", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.IsPushCertStale(ctx, topic, staleToken)
return err
})
return finalStale, finalErr
}

func (ms *MultiAllStorage) RetrievePushCert(ctx context.Context, topic string) (cert *tls.Certificate, staleToken string, err error) {
finalCert, finalToken, finalErr := ms.stores[0].RetrievePushCert(ctx, topic)
for n, storage := range ms.stores[1:] {
if _, _, err := storage.RetrievePushCert(ctx, topic); err != nil {
ms.logger.Info("method", "RetrievePushCert", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, _, err := s.RetrievePushCert(ctx, topic)
return err
})

return finalCert, finalToken, finalErr
}

func (ms *MultiAllStorage) StorePushCert(ctx context.Context, pemCert, pemKey []byte) error {
finalErr := ms.stores[0].StorePushCert(ctx, pemCert, pemKey)
for n, storage := range ms.stores[1:] {
if err := storage.StorePushCert(ctx, pemCert, pemKey); err != nil {
ms.logger.Info("method", "StorePushCert", "storage", n+1, "err", err)
continue
}
}
return finalErr
err := ms.stores[0].StorePushCert(ctx, pemCert, pemKey)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.StorePushCert(ctx, pemCert, pemKey)
})
return err
}
47 changes: 19 additions & 28 deletions storage/allmulti/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,39 @@ import (
"context"

"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)

func (ms *MultiAllStorage) StoreCommandReport(r *mdm.Request, report *mdm.CommandResults) error {
finalErr := ms.stores[0].StoreCommandReport(r, report)
for n, storage := range ms.stores[1:] {
if err := storage.StoreCommandReport(r, report); err != nil {
ms.logger.Info("method", "StoreCommandReport", "storage", n+1, "err", err)
continue
}
}
return finalErr
err := ms.stores[0].StoreCommandReport(r, report)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.StoreCommandReport(r, report)
})
return err
}

func (ms *MultiAllStorage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) {
skipFinal, finalErr := ms.stores[0].RetrieveNextCommand(r, skipNotNow)
for n, storage := range ms.stores[1:] {
if _, err := storage.RetrieveNextCommand(r, skipNotNow); err != nil {
ms.logger.Info("method", "RetrieveNextCommand", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.RetrieveNextCommand(r, skipNotNow)
return err
})
return skipFinal, finalErr
}

func (ms *MultiAllStorage) ClearQueue(r *mdm.Request) error {
finalErr := ms.stores[0].ClearQueue(r)
for n, storage := range ms.stores[1:] {
if err := storage.ClearQueue(r); err != nil {
ms.logger.Info("method", "ClearQueue", "storage", n+1, "err", err)
continue
}
}
return finalErr
err := ms.stores[0].ClearQueue(r)
ms.runAndLogOthers(func(s storage.AllStorage) error {
return s.ClearQueue(r)
})
return err
}

func (ms *MultiAllStorage) EnqueueCommand(ctx context.Context, id []string, cmd *mdm.Command) (map[string]error, error) {
finalMap, finalErr := ms.stores[0].EnqueueCommand(ctx, id, cmd)
for n, storage := range ms.stores[1:] {
if _, err := storage.EnqueueCommand(ctx, id, cmd); err != nil {
ms.logger.Info("method", "EnqueueCommand", "storage", n+1, "err", err)
continue
}
}
ms.runAndLogOthers(func(s storage.AllStorage) error {
_, err := s.EnqueueCommand(ctx, id, cmd)
return err
})
return finalMap, finalErr
}

0 comments on commit 3907f14

Please sign in to comment.