diff --git a/README.md b/README.md index 08ab87f..aa96378 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ A brief overview of the various command-line switches and HTTP endpoints and API ## Features -- Horizontal scaling: zero/minimal local state. Persistence in storage layers. MySQL backend provided in the box. +- Horizontal scaling: zero/minimal local state. Persistence in storage layers. MySQL and PostgreSQL backends provided in the box. - Multiple APNs topics: potentially multi-tenant. - Multi-command targeting: send the same command (or pushes) to multiple enrollments without individually queuing commands. - Migration endpoint: allow migrating MDM enrollments between storage backends or (supported) MDM servers diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 5eae8ff..d058ba2 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -11,8 +11,10 @@ import ( "github.com/micromdm/nanomdm/storage/allmulti" "github.com/micromdm/nanomdm/storage/file" "github.com/micromdm/nanomdm/storage/mysql" + "github.com/micromdm/nanomdm/storage/pgsql" _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" ) type StringAccumulator []string @@ -72,6 +74,12 @@ func (s *Storage) Parse(logger log.Logger) (storage.AllStorage, error) { return nil, err } mdmStorage = append(mdmStorage, mysqlStorage) + case "pgsql": + pgsqlStorage, err := pgsqlStorageConfig(dsn, options, logger) + if err != nil { + return nil, err + } + mdmStorage = append(mdmStorage, pgsqlStorage) default: return nil, fmt.Errorf("unknown storage: %s", storage) } @@ -134,3 +142,27 @@ func splitOptions(s string) map[string]string { } return out } + +func pgsqlStorageConfig(dsn, options string, logger log.Logger) (*pgsql.PgSQLStorage, error) { + logger = logger.With("storage", "pgsql") + opts := []pgsql.Option{ + pgsql.WithDSN(dsn), + pgsql.WithLogger(logger), + } + if options != "" { + for k, v := range splitOptions(options) { + switch k { + case "delete": + if v == "1" { + opts = append(opts, pgsql.WithDeleteCommands()) + logger.Debug("msg", "deleting commands") + } else if v != "0" { + return nil, fmt.Errorf("invalid value for delete option: %q", v) + } + default: + return nil, fmt.Errorf("invalid option: %q", k) + } + } + } + return pgsql.New(opts...) +} diff --git a/docs/operations-guide.md b/docs/operations-guide.md index 1c60c46..4617561 100644 --- a/docs/operations-guide.md +++ b/docs/operations-guide.md @@ -63,6 +63,20 @@ Options are specified as a comma-separated list of "key=value" pairs. The mysql *Example:* `-storage mysql -dsn nanomdm:nanomdm/mymdmdb -storage-options delete=1` +#### pgsql storage backend + +* `-storage pgsql` + +Configures the PostgreSQL storage backend. The `-dsn` flag should be in the [format the SQL driver expects](https://pkg.go.dev/github.com/lib/pq#pkg-overview). Be sure to create your tables with the [schema.sql](../storage/pgsql/schema.sql) file that corresponds to your NanoMDM version. Also make sure you apply any schema changes for each updated version (i.e. execute the numbered schema change files). PostgreSQL 9.5 or later is required. + +*Example:* `-storage pgsql -dsn postgres://postgres:toor@localhost:5432/nanomdm?sslmode=disable` + +Options are specified as a comma-separated list of "key=value" pairs. The pgsql backend supports these options: +* `delete=1`, `delete=0` + * This option turns on or off the command and response deleter. It is disabled by default. When enabled (with `delete=1`) command responses, queued commands, and commands themselves will be deleted from the database after enrollments have responded to a command. + +*Example:* `-storage pgsql -dsn postgres://postgres:toor@localhost/nanomdm -storage-options delete=1` + #### multi-storage backend You can configure multiple storage backends to be used simultaneously. Specifying multiple sets of `-storage`, `-dsn`, & `-storage-options` flags will configure the "multi-storage" adapter. The flags must be specified in sets and are related to each other in the order they're specified: for example the first `-storage` flag corresponds to the first `-dsn` flag and so forth. diff --git a/go.mod b/go.mod index 22b14a8..26c082f 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ require ( github.com/RobotsAndPencils/buford v0.14.0 github.com/go-sql-driver/mysql v1.6.0 github.com/groob/plist v0.0.0-20220217120414-63fa881b19a5 + github.com/lib/pq v1.10.6 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 ) diff --git a/go.sum b/go.sum index 4a232ff..9c62244 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/groob/plist v0.0.0-20220217120414-63fa881b19a5 h1:saaSiB25B1wgaxrshQhurfPKUGJ4It3OxNJUy0rdOjU= github.com/groob/plist v0.0.0-20220217120414-63fa881b19a5/go.mod h1:itkABA+w2cw7x5nYUS/pLRef6ludkZKOigbROmCTaFw= +github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= +github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 h1:CCriYyAfq1Br1aIYettdHZTy8mBTIPo7We18TuO/bak= go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/storage/mysql/mysql.go b/storage/mysql/mysql.go index cec2ced..398f867 100644 --- a/storage/mysql/mysql.go +++ b/storage/mysql/mysql.go @@ -1,4 +1,4 @@ -// Pacakge mysql stores and retrieves MDM data from SQL +// Package mysql stores and retrieves MDM data from MySQL package mysql import ( diff --git a/storage/pgsql/bstoken.go b/storage/pgsql/bstoken.go new file mode 100644 index 0000000..4189d96 --- /dev/null +++ b/storage/pgsql/bstoken.go @@ -0,0 +1,36 @@ +package pgsql + +import ( + "github.com/micromdm/nanomdm/mdm" +) + +func (s *PgSQLStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error { + _, err := s.db.ExecContext( + r.Context, + `UPDATE devices SET bootstrap_token_b64 = $1, bootstrap_token_at = CURRENT_TIMESTAMP WHERE id = $2;`, + nullEmptyString(msg.BootstrapToken.BootstrapToken.String()), + r.ID, + ) + if err != nil { + return err + } + return s.updateLastSeen(r) +} + +func (s *PgSQLStorage) RetrieveBootstrapToken(r *mdm.Request, _ *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { + var tokenB64 string + err := s.db.QueryRowContext( + r.Context, + `SELECT bootstrap_token_b64 FROM devices WHERE id = $1;`, + r.ID, + ).Scan(&tokenB64) + if err != nil { + return nil, err + } + bsToken := new(mdm.BootstrapToken) + err = bsToken.SetTokenString(tokenB64) + if err == nil { + err = s.updateLastSeen(r) + } + return bsToken, err +} diff --git a/storage/pgsql/certauth.go b/storage/pgsql/certauth.go new file mode 100644 index 0000000..1ceec98 --- /dev/null +++ b/storage/pgsql/certauth.go @@ -0,0 +1,52 @@ +package pgsql + +import ( + "context" + "strings" + + "github.com/micromdm/nanomdm/mdm" +) + +// Executes SQL statements that return a single COUNT(*) of rows. +func (s *PgSQLStorage) queryRowContextRowExists(ctx context.Context, query string, args ...interface{}) (bool, error) { + var ct int + err := s.db.QueryRowContext(ctx, query, args...).Scan(&ct) + return ct > 0, err +} + +func (s *PgSQLStorage) EnrollmentHasCertHash(r *mdm.Request, _ string) (bool, error) { + return s.queryRowContextRowExists( + r.Context, + `SELECT COUNT(*) FROM cert_auth_associations WHERE id = $1;`, + r.ID, + ) +} + +func (s *PgSQLStorage) HasCertHash(r *mdm.Request, hash string) (bool, error) { + return s.queryRowContextRowExists( + r.Context, + `SELECT COUNT(*) FROM cert_auth_associations WHERE sha256 = $1;`, + strings.ToLower(hash), + ) +} + +func (s *PgSQLStorage) IsCertHashAssociated(r *mdm.Request, hash string) (bool, error) { + return s.queryRowContextRowExists( + r.Context, + `SELECT COUNT(*) FROM cert_auth_associations WHERE id = $1 AND sha256 = $2;`, + r.ID, strings.ToLower(hash), + ) +} + +// AssociateCertHash "DO NOTHING" on duplicated keys +func (s *PgSQLStorage) AssociateCertHash(r *mdm.Request, hash string) error { + _, err := s.db.ExecContext( + r.Context, ` +INSERT INTO cert_auth_associations (id, sha256) +VALUES ($1, $2) +ON CONFLICT ON CONSTRAINT cert_auth_associations_pkey DO UPDATE SET updated_at=now();`, + r.ID, + strings.ToLower(hash), + ) + return err +} diff --git a/storage/pgsql/migrate.go b/storage/pgsql/migrate.go new file mode 100644 index 0000000..e5fed0f --- /dev/null +++ b/storage/pgsql/migrate.go @@ -0,0 +1,61 @@ +package pgsql + +import ( + "context" + + "github.com/micromdm/nanomdm/mdm" +) + +func (s *PgSQLStorage) RetrieveMigrationCheckins(ctx context.Context, c chan<- interface{}) error { + // TODO: if a TokenUpdate does not include the latest UnlockToken + // then we should synthesize a TokenUpdate to transfer it over. + deviceRows, err := s.db.QueryContext( + ctx, + `SELECT authenticate, token_update FROM devices;`, + ) + if err != nil { + return err + } + defer deviceRows.Close() + for deviceRows.Next() { + var authBytes, tokenBytes []byte + if err := deviceRows.Scan(&authBytes, &tokenBytes); err != nil { + return err + } + for _, msgBytes := range [][]byte{authBytes, tokenBytes} { + msg, err := mdm.DecodeCheckin(msgBytes) + if err != nil { + c <- err + } else { + c <- msg + } + } + } + if err = deviceRows.Err(); err != nil { + return err + } + userRows, err := s.db.QueryContext( + ctx, + `SELECT token_update FROM users;`, + ) + if err != nil { + return err + } + defer userRows.Close() + for userRows.Next() { + var msgBytes []byte + if err := userRows.Scan(&msgBytes); err != nil { + return err + } + msg, err := mdm.DecodeCheckin(msgBytes) + if err != nil { + c <- err + } else { + c <- msg + } + } + if err = userRows.Err(); err != nil { + return err + } + return nil +} diff --git a/storage/pgsql/postgresql.go b/storage/pgsql/postgresql.go new file mode 100644 index 0000000..12e0ae8 --- /dev/null +++ b/storage/pgsql/postgresql.go @@ -0,0 +1,270 @@ +// Package pgsql stores and retrieves MDM data from PostgresSQL +package pgsql + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/micromdm/nanomdm/cryptoutil" + "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" + "github.com/micromdm/nanomdm/mdm" +) + +var ErrNoCert = errors.New("no certificate in MDM Request") + +type PgSQLStorage struct { + logger log.Logger + db *sql.DB + rm bool +} + +type config struct { + driver string + dsn string + db *sql.DB + logger log.Logger + rm bool +} + +type Option func(*config) + +func WithLogger(logger log.Logger) Option { + return func(c *config) { + c.logger = logger + } +} + +func WithDSN(dsn string) Option { + return func(c *config) { + c.dsn = dsn + } +} + +func WithDriver(driver string) Option { + return func(c *config) { + c.driver = driver + } +} + +func WithDB(db *sql.DB) Option { + return func(c *config) { + c.db = db + } +} + +func WithDeleteCommands() Option { + return func(c *config) { + c.rm = true + } +} + +func New(opts ...Option) (*PgSQLStorage, error) { + cfg := &config{logger: log.NopLogger, driver: "postgres"} + for _, opt := range opts { + opt(cfg) + } + var err error + if cfg.db == nil { + cfg.db, err = sql.Open(cfg.driver, cfg.dsn) + if err != nil { + return nil, err + } + } + if err = cfg.db.Ping(); err != nil { + return nil, err + } + return &PgSQLStorage{db: cfg.db, logger: cfg.logger, rm: cfg.rm}, nil +} + +// nullEmptyString returns a NULL string if s is empty. +func nullEmptyString(s string) sql.NullString { + return sql.NullString{ + String: s, + Valid: s != "", + } +} + +func (s *PgSQLStorage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) error { + var pemCert []byte + if r.Certificate != nil { + pemCert = cryptoutil.PEMCertificate(r.Certificate.Raw) + } + _, err := s.db.ExecContext( + r.Context, ` +INSERT INTO devices + (id, identity_cert, serial_number, authenticate, authenticate_at) +VALUES + ($1, $2, $3, $4, CURRENT_TIMESTAMP) +ON CONFLICT ON CONSTRAINT devices_pkey DO +UPDATE SET + identity_cert = EXCLUDED.identity_cert, + serial_number = EXCLUDED.serial_number, + authenticate = EXCLUDED.authenticate, + authenticate_at = CURRENT_TIMESTAMP;`, + r.ID, nullEmptyString(string(pemCert)), nullEmptyString(msg.SerialNumber), msg.Raw, + ) + return err +} + +func (s *PgSQLStorage) storeDeviceTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { + query := `UPDATE devices SET token_update = $1, token_update_at = CURRENT_TIMESTAMP` + where := ` WHERE id = $2;` + args := []interface{}{msg.Raw} + // separately store the Unlock Token per MDM spec + if len(msg.UnlockToken) > 0 { + query += `, unlock_token = $2, unlock_token_at = CURRENT_TIMESTAMP ` + args = append(args, msg.UnlockToken) + where = ` WHERE id = $3;` + } + args = append(args, r.ID) + _, err := s.db.ExecContext(r.Context, query+where, args...) + return err +} + +func (s *PgSQLStorage) storeUserTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { + // there shouldn't be an Unlock Token on the user channel, but + // complain if there is to warn an admin + if len(msg.UnlockToken) > 0 { + ctxlog.Logger(r.Context, s.logger).Info( + "msg", "Unlock Token on user channel not stored", + ) + } + _, err := s.db.ExecContext( + r.Context, ` +INSERT INTO users + (id, device_id, user_short_name, user_long_name, token_update, token_update_at) +VALUES + ($1, $2, $3, $4, $5, CURRENT_TIMESTAMP) +ON CONFLICT ON CONSTRAINT users_pkey DO UPDATE +SET + device_id = EXCLUDED.device_id, + user_short_name = EXCLUDED.user_short_name, + user_long_name = EXCLUDED.user_long_name, + token_update = EXCLUDED.token_update, + token_update_at = CURRENT_TIMESTAMP;`, + r.ID, + r.ParentID, + nullEmptyString(msg.UserShortName), + nullEmptyString(msg.UserLongName), + msg.Raw, + ) + return err +} + +func (s *PgSQLStorage) StoreTokenUpdate(r *mdm.Request, msg *mdm.TokenUpdate) error { + var err error + var deviceId, userId string + resolved := (&msg.Enrollment).Resolved() + if err = resolved.Validate(); err != nil { + return err + } + if resolved.IsUserChannel { + deviceId = r.ParentID + userId = r.ID + err = s.storeUserTokenUpdate(r, msg) + } else { + deviceId = r.ID + err = s.storeDeviceTokenUpdate(r, msg) + } + if err != nil { + return err + } + _, err = s.db.ExecContext( + r.Context, ` +INSERT INTO enrollments + (id, device_id, user_id, type, topic, push_magic, token_hex, last_seen_at, token_update_tally) +VALUES + ($1, $2, $3, $4, $5, $6, $7, CURRENT_TIMESTAMP, 1) +ON CONFLICT ON CONSTRAINT enrollments_pkey DO UPDATE +SET + device_id = EXCLUDED.device_id, + user_id = EXCLUDED.user_id, + type = EXCLUDED.type, + topic = EXCLUDED.topic, + push_magic = EXCLUDED.push_magic, + token_hex = EXCLUDED.token_hex, + enabled = TRUE, + last_seen_at = CURRENT_TIMESTAMP, + token_update_tally = enrollments.token_update_tally + 1;`, + r.ID, + deviceId, + nullEmptyString(userId), + r.Type.String(), + msg.Topic, + msg.PushMagic, + msg.Token.String(), + ) + return err +} + +func (s *PgSQLStorage) RetrieveTokenUpdateTally(ctx context.Context, id string) (int, error) { + var tally int + err := s.db.QueryRowContext( + ctx, + `SELECT token_update_tally FROM enrollments WHERE id = $1;`, + id, + ).Scan(&tally) + return tally, err +} + +func (s *PgSQLStorage) StoreUserAuthenticate(r *mdm.Request, msg *mdm.UserAuthenticate) error { + colName := "user_authenticate" + colAtName := "user_authenticate_at" + // if the DigestResponse is empty then this is the first (of two) + // UserAuthenticate messages depending on our response + if msg.DigestResponse != "" { + colName = "user_authenticate_digest" + colAtName = "user_authenticate_digest_at" + } + _, err := s.db.ExecContext( + r.Context, ` +INSERT INTO users + (id, device_id, user_short_name, user_long_name, `+colName+`, `+colAtName+`) +VALUES + ($1, $2, $3, $4, $5, CURRENT_TIMESTAMP) +ON CONFLICT ON CONSTRAINT users_pkey DO UPDATE +SET + device_id = EXCLUDED.device_id, + user_short_name = EXCLUDED.user_short_name, + user_long_name = EXCLUDED.user_long_name, + `+colName+` = EXCLUDED.`+colName+`, + `+colAtName+` = EXCLUDED.`+colAtName+`;`, + r.ID, + r.ParentID, + nullEmptyString(msg.UserShortName), + nullEmptyString(msg.UserLongName), + msg.Raw, + ) + if err != nil { + return err + } + return s.updateLastSeen(r) +} + +// Disable can be called for an Authenticate or CheckOut message +func (s *PgSQLStorage) Disable(r *mdm.Request) error { + if r.ParentID != "" { + return errors.New("can only disable a device channel") + } + _, err := s.db.ExecContext( + r.Context, + `UPDATE enrollments SET enabled = FALSE, token_update_tally = 0, last_seen_at = CURRENT_TIMESTAMP WHERE device_id = $1 AND enabled = TRUE;`, + r.ID, + ) + return err +} + +func (s *PgSQLStorage) updateLastSeen(r *mdm.Request) (err error) { + _, err = s.db.ExecContext( + r.Context, + `UPDATE enrollments SET last_seen_at = CURRENT_TIMESTAMP WHERE id = $1`, + r.ID, + ) + if err != nil { + err = fmt.Errorf("updating last seen: %w", err) + } + return +} diff --git a/storage/pgsql/push.go b/storage/pgsql/push.go new file mode 100644 index 0000000..0287f77 --- /dev/null +++ b/storage/pgsql/push.go @@ -0,0 +1,57 @@ +package pgsql + +import ( + "context" + "errors" + "strconv" + "strings" + + "github.com/micromdm/nanomdm/mdm" +) + +// RetrievePushInfo retreives push info for identifiers ids. +// +// Note that we may return fewer results than input. The user of this +// method needs to reconcile that with their requested ids. +func (s *PgSQLStorage) RetrievePushInfo(ctx context.Context, ids []string) (map[string]*mdm.Push, error) { + if len(ids) < 1 { + return nil, errors.New("no ids provided") + } + + // previous: `SELECT id, topic, push_magic, token_hex FROM enrollments WHERE id IN (`+qs+`);`, + // refactor all strings concatenations with strings.Builder which is more efficient + var qs strings.Builder + + qs.WriteString(`SELECT id, topic, push_magic, token_hex FROM enrollments WHERE id IN (`) + args := make([]interface{}, len(ids)) + for i, v := range ids { + args[i] = v + if i > 0 { + qs.WriteString(",") + } + // can be a bit faster than fmt.Fprintf(&qs, "$%d", i+1) + qs.WriteString("$") + qs.WriteString(strconv.Itoa(i + 1)) + } + qs.WriteString(`);`) + + rows, err := s.db.QueryContext(ctx, qs.String(), args...) + if err != nil { + return nil, err + } + defer rows.Close() + pushInfos := make(map[string]*mdm.Push) + for rows.Next() { + push := new(mdm.Push) + var id, token string + if err := rows.Scan(&id, &push.Topic, &push.PushMagic, &token); err != nil { + return nil, err + } + // convert from hex + if err := push.SetTokenString(token); err != nil { + return nil, err + } + pushInfos[id] = push + } + return pushInfos, rows.Err() +} diff --git a/storage/pgsql/pushcert.go b/storage/pgsql/pushcert.go new file mode 100644 index 0000000..7077938 --- /dev/null +++ b/storage/pgsql/pushcert.go @@ -0,0 +1,62 @@ +package pgsql + +import ( + "context" + "crypto/tls" + "strconv" + + "github.com/micromdm/nanomdm/cryptoutil" +) + +func (s *PgSQLStorage) RetrievePushCert(ctx context.Context, topic string) (*tls.Certificate, string, error) { + var certPEM, keyPEM []byte + var staleToken int + err := s.db.QueryRowContext( + ctx, + `SELECT cert_pem, key_pem, stale_token FROM push_certs WHERE topic = $1;`, + topic, + ).Scan(&certPEM, &keyPEM, &staleToken) + if err != nil { + return nil, "", err + } + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, "", err + } + return &cert, strconv.Itoa(staleToken), err +} + +func (s *PgSQLStorage) IsPushCertStale(ctx context.Context, topic, staleToken string) (bool, error) { + var staleTokenInt, dbStaleToken int + staleTokenInt, err := strconv.Atoi(staleToken) + if err != nil { + return true, err + } + err = s.db.QueryRowContext( + ctx, + `SELECT stale_token FROM push_certs WHERE topic = $1;`, + topic, + ).Scan(&dbStaleToken) + return dbStaleToken != staleTokenInt, err +} + +func (s *PgSQLStorage) StorePushCert(ctx context.Context, pemCert, pemKey []byte) error { + topic, err := cryptoutil.TopicFromPEMCert(pemCert) + if err != nil { + return err + } + _, err = s.db.ExecContext( + ctx, ` +INSERT INTO push_certs + (topic, cert_pem, key_pem, stale_token) +VALUES + ($1, $2, $3, 0) +ON CONFLICT (topic) DO +UPDATE SET + cert_pem = EXCLUDED.cert_pem, + key_pem = EXCLUDED.key_pem, + stale_token = push_certs.stale_token + 1;`, + topic, pemCert, pemKey, + ) + return err +} diff --git a/storage/pgsql/queue.go b/storage/pgsql/queue.go new file mode 100644 index 0000000..2a020e8 --- /dev/null +++ b/storage/pgsql/queue.go @@ -0,0 +1,199 @@ +package pgsql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/micromdm/nanomdm/mdm" +) + +func enqueue(ctx context.Context, tx *sql.Tx, ids []string, cmd *mdm.Command) error { + if len(ids) < 1 { + return errors.New("no id(s) supplied to queue command to") + } + _, err := tx.ExecContext( + ctx, + `INSERT INTO commands (command_uuid, request_type, command) VALUES ($1, $2, $3);`, + cmd.CommandUUID, cmd.Command.RequestType, cmd.Raw, + ) + if err != nil { + return err + } + + var query strings.Builder + + query.WriteString(`INSERT INTO enrollment_queue (id, command_uuid) VALUES `) + args := make([]interface{}, len(ids)*2) + for i, id := range ids { + if i > 0 { + query.WriteString(",") + } + ind := i * 2 + + //previous: query += fmt.Sprintf("($%d, $%d)", ind+1, ind+2) + query.WriteString("($") + query.WriteString(strconv.Itoa(ind + 1)) + query.WriteString(", $") + query.WriteString(strconv.Itoa(ind + 2)) + query.WriteString(")") + + args[ind] = id + args[ind+1] = cmd.CommandUUID + } + query.WriteString(";") + + _, err = tx.ExecContext(ctx, query.String(), args...) + return err +} + +func (s *PgSQLStorage) EnqueueCommand(ctx context.Context, ids []string, cmd *mdm.Command) (map[string]error, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + if err = enqueue(ctx, tx, ids, cmd); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return nil, fmt.Errorf("rollback error: %w; while trying to handle error: %v", rbErr, err) + } + return nil, err + } + return nil, tx.Commit() +} + +func (s *PgSQLStorage) deleteCommand(ctx context.Context, tx *sql.Tx, id, uuid string) error { + _, err := tx.ExecContext(ctx, ` +DELETE FROM enrollment_queue +WHERE id =$1 AND command_uuid =$2;`, id, uuid) + if err != nil { + return err + } + // delete command result (i.e. NotNows) and this queued command + _, err = tx.ExecContext(ctx, ` +DELETE FROM command_results +WHERE id =$1 AND command_uuid =$2;`, id, uuid) + if err != nil { + return err + } + + // now delete the actual command if no enrollments have it queued + // nor are there any results for it. + _, err = tx.ExecContext( + ctx, ` +DELETE FROM commands +USING + commands AS c + LEFT JOIN enrollment_queue AS q + ON q.command_uuid = c.command_uuid + LEFT JOIN command_results AS r + ON r.command_uuid = c.command_uuid +WHERE + c.command_uuid =$1 AND + q.command_uuid IS NULL AND + r.command_uuid IS NULL AND + commands.command_uuid = c.command_uuid; +`, + uuid, + ) + return err +} + +func (s *PgSQLStorage) deleteCommandTx(r *mdm.Request, result *mdm.CommandResults) error { + tx, err := s.db.BeginTx(r.Context, nil) + if err != nil { + return err + } + if err = s.deleteCommand(r.Context, tx, r.ID, result.CommandUUID); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("rollback error: %w; while trying to handle error: %v", rbErr, err) + } + return err + } + return tx.Commit() +} + +func (s *PgSQLStorage) StoreCommandReport(r *mdm.Request, result *mdm.CommandResults) error { + if err := s.updateLastSeen(r); err != nil { + return err + } + if result.Status == "Idle" { + return nil + } + if s.rm && result.Status != "NotNow" { + return s.deleteCommandTx(r, result) + } + notNowConstants := "NULL, 0" + notNowBumpTallySQL := "" + // note that due to the "ON CONFLICT ON CONSTRAINT command_results_pkey" we don't UPDATE the + // not_now_at field. thus it will only represent the first NotNow. + if result.Status == "NotNow" { + notNowConstants = "CURRENT_TIMESTAMP, 1" + notNowBumpTallySQL = `, not_now_tally = command_results.not_now_tally + 1` + } + _, err := s.db.ExecContext( + r.Context, ` +INSERT INTO command_results + (id, command_uuid, status, result, not_now_at, not_now_tally) +VALUES + ($1, $2, $3, $4, `+notNowConstants+`) +ON CONFLICT ON CONSTRAINT command_results_pkey DO UPDATE +SET + status = EXCLUDED.status, + result = EXCLUDED.result`+notNowBumpTallySQL+`;`, + r.ID, + result.CommandUUID, + result.Status, + result.Raw, + ) + return err +} + +func (s *PgSQLStorage) RetrieveNextCommand(r *mdm.Request, skipNotNow bool) (*mdm.Command, error) { + statusWhere := "status IS NULL" + if !skipNotNow { + statusWhere = `(` + statusWhere + ` OR status = 'NotNow')` + } + command := new(mdm.Command) + err := s.db.QueryRowContext( + r.Context, + `SELECT command_uuid, request_type, command FROM view_queue WHERE id = $1 AND active = TRUE AND `+statusWhere+` LIMIT 1;`, + r.ID, + ).Scan(&command.CommandUUID, &command.Command.RequestType, &command.Raw) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return command, nil +} + +func (s *PgSQLStorage) ClearQueue(r *mdm.Request) error { + if r.ParentID != "" { + return errors.New("can only clear a device channel queue") + } + // PostgreSQL UPDATE differs from MySQL, uses "FROM" specific + // to pgsql extension + _, err := s.db.ExecContext( + r.Context, + ` +UPDATE enrollment_queue +SET active = FALSE +FROM enrollment_queue AS q + INNER JOIN enrollments AS e + ON q.id = e.id + INNER JOIN commands AS c + ON q.command_uuid = c.command_uuid + LEFT JOIN command_results r + ON r.command_uuid = q.command_uuid AND r.id = q.id +WHERE + e.device_id = $1 AND + enrollment_queue.active = TRUE AND + (r.status IS NULL OR r.status = 'NotNow') AND + enrollment_queue.id = q.id;`, + r.ID) + return err +} diff --git a/storage/pgsql/queue_test.go b/storage/pgsql/queue_test.go new file mode 100644 index 0000000..f216d90 --- /dev/null +++ b/storage/pgsql/queue_test.go @@ -0,0 +1,111 @@ +//go:build integration +// +build integration + +package pgsql + +import ( + "context" + "errors" + "flag" + "io/ioutil" + "testing" + + _ "github.com/lib/pq" + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage/internal/test" +) + +var flDSN = flag.String("dsn", "", "DSN of test PostgreSQL instance") + +func loadAuthMsg() (*mdm.Authenticate, error) { + b, err := ioutil.ReadFile("../../mdm/testdata/Authenticate.2.plist") + if err != nil { + return nil, err + } + r, err := mdm.DecodeCheckin(b) + if err != nil { + return nil, err + } + a, ok := r.(*mdm.Authenticate) + if !ok { + return nil, errors.New("not an Authenticate message") + } + return a, nil +} + +func loadTokenMsg() (*mdm.TokenUpdate, error) { + b, err := ioutil.ReadFile("../../mdm/testdata/TokenUpdate.2.plist") + if err != nil { + return nil, err + } + r, err := mdm.DecodeCheckin(b) + if err != nil { + return nil, err + } + a, ok := r.(*mdm.TokenUpdate) + if !ok { + return nil, errors.New("not a TokenUpdate message") + } + return a, nil +} + +const deviceUDID = "66ADE930-5FDF-5EC4-8429-15640684C489" + +func newMdmReq() *mdm.Request { + return &mdm.Request{ + Context: context.Background(), + EnrollID: &mdm.EnrollID{ + Type: mdm.Device, + ID: deviceUDID, + }, + } +} + +func enrollTestDevice(storage *PgSQLStorage) error { + authMsg, err := loadAuthMsg() + if err != nil { + return err + } + err = storage.StoreAuthenticate(newMdmReq(), authMsg) + if err != nil { + return err + } + tokenMsg, err := loadTokenMsg() + if err != nil { + return err + } + err = storage.StoreTokenUpdate(newMdmReq(), tokenMsg) + if err != nil { + return err + } + return nil +} + +func TestQueue(t *testing.T) { + if *flDSN == "" { + t.Fatal("PostgreSQL DSN flag not provided to test") + } + + storage, err := New(WithDSN(*flDSN), WithDeleteCommands()) + if err != nil { + t.Fatal(err) + } + + err = enrollTestDevice(storage) + if err != nil { + t.Fatal(err) + } + + t.Run("WithDeleteCommands()", func(t *testing.T) { + test.TestQueue(t, deviceUDID, storage) + }) + + storage, err = New(WithDSN(*flDSN)) + if err != nil { + t.Fatal(err) + } + + t.Run("normal", func(t *testing.T) { + test.TestQueue(t, deviceUDID, storage) + }) +} diff --git a/storage/pgsql/schema.sql b/storage/pgsql/schema.sql new file mode 100644 index 0000000..e4bc723 --- /dev/null +++ b/storage/pgsql/schema.sql @@ -0,0 +1,317 @@ +/* Requires PostgreSQL 9.5 or later. + * From PostgreSQL documentation: ON CONFLICT clause is only available from PostgreSQL 9.5 + */ + +CREATE TABLE devices +( + id VARCHAR(255) NOT NULL, + + identity_cert TEXT NULL, + + serial_number VARCHAR(127) NULL, + + -- If the (iOS, iPadOS) device sent an UnlockToken in the TokenUpdate + -- TODO: Consider using a TEXT field and encoding the binary + unlock_token BYTEA NULL, + unlock_token_at TIMESTAMP NULL, + + -- The last raw Authenticate for this device + authenticate TEXT NOT NULL, + authenticate_at TIMESTAMP NOT NULL, + -- The last raw TokenUpdate for this device + token_update TEXT NULL, + token_update_at TIMESTAMP NULL, + + bootstrap_token_b64 TEXT NULL, + bootstrap_token_at TIMESTAMP NULL, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- trigger + + PRIMARY KEY (id), + + CHECK (identity_cert IS NULL OR SUBSTRING(identity_cert FROM 1 FOR 27) = '-----BEGIN CERTIFICATE-----'), + CHECK (serial_number IS NULL OR serial_number != ''), + CHECK (unlock_token IS NULL OR LENGTH(unlock_token) > 0), + CHECK (authenticate != ''), + CHECK (token_update IS NULL OR token_update != ''), + CHECK (bootstrap_token_b64 IS NULL OR bootstrap_token_b64 != '') +); +CREATE INDEX serial_number ON devices (serial_number); + +CREATE TABLE users +( + id VARCHAR(255) NOT NULL, + device_id VARCHAR(255) NOT NULL, + + user_short_name VARCHAR(255) NULL, + user_long_name VARCHAR(255) NULL, + + -- The last raw TokenUpdate for this user + token_update TEXT NULL, + token_update_at TIMESTAMP NULL, + + -- The last raw UserAuthenticate (and optional digest) for this user + user_authenticate TEXT NULL, + user_authenticate_at TIMESTAMP NULL, + user_authenticate_digest TEXT NULL, + user_authenticate_digest_at TIMESTAMP NULL, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, -- trigger + + PRIMARY KEY (id, device_id), + UNIQUE (id), + + FOREIGN KEY (device_id) + REFERENCES devices (id) + ON DELETE CASCADE ON UPDATE CASCADE, + + CHECK (user_short_name IS NULL OR user_short_name != ''), + CHECK (user_long_name IS NULL OR user_long_name != ''), + CHECK (token_update IS NULL OR token_update != ''), + CHECK (user_authenticate IS NULL OR user_authenticate != ''), + CHECK (user_authenticate_digest IS NULL OR user_authenticate_digest != '') +); + +/* This table represents enrollments which are an amalgamation of + * both device and user enrollments. + */ +CREATE TABLE enrollments +( + -- The enrollment ID of this enrollment + id VARCHAR(255) NOT NULL, + -- The "device" enrollment ID of this enrollment. This will be + -- the same as the `id` field in the case of a "device" enrollment, + -- or will be the "parent" enrollment for a "user" enrollment. + device_id VARCHAR(255) NOT NULL, + -- The "user" enrollment ID of this enrollment. This will be the + -- same as the `id` field in the case of a "user" enrollment or + -- NULL in the case of a device enrollment. + user_id VARCHAR(255) NULL, + + -- Textual representation of the type of device enrollment. + type VARCHAR(31) NOT NULL, + + -- The MDM APNs push trifecta. + topic VARCHAR(255) NOT NULL, + push_magic VARCHAR(127) NOT NULL, + token_hex VARCHAR(255) NOT NULL, -- TODO: Perhaps just CHAR(64)? + + enabled BOOLEAN NOT NULL DEFAULT TRUE, + token_update_tally INTEGER NOT NULL DEFAULT 1, + + last_seen_at TIMESTAMP NOT NULL, -- TODO: additional tests with real device and integration tests. + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + PRIMARY KEY (id), + CHECK (id != ''), + + FOREIGN KEY (device_id) + REFERENCES devices (id) + ON DELETE CASCADE ON UPDATE CASCADE, + + FOREIGN KEY (user_id) + REFERENCES users (id) + ON DELETE CASCADE ON UPDATE CASCADE, + UNIQUE (user_id), + + CHECK (type != ''), + CHECK (topic != ''), + CHECK (push_magic != ''), + CHECK (token_hex != '') +); +CREATE INDEX idx_type ON enrollments (type); + +/* Commands stand alone. By themselves they aren't associated with + * a device, a result (response), etc. Joining other tables is required + * for more context. + */ +CREATE TABLE commands +( + command_uuid VARCHAR(127) NOT NULL, + request_type VARCHAR(63) NOT NULL, + -- Raw command Plist + command TEXT NOT NULL, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + + PRIMARY KEY (command_uuid), + + CHECK (command_uuid != ''), + CHECK (request_type != ''), + CHECK (SUBSTRING(command FROM 1 FOR 5) = '