diff --git a/cmd/clairctl/admin.go b/cmd/clairctl/admin.go index 20caf84671..53c20b8389 100644 --- a/cmd/clairctl/admin.go +++ b/cmd/clairctl/admin.go @@ -297,9 +297,12 @@ func adminPre473(c *cli.Context) error { func updateGoPackages(c *cli.Context) error { const ( - getPackageNames = "SELECT DISTINCT package_name FROM vuln WHERE updater = 'osv/go'" - getPackages = "SELECT id, version FROM package WHERE name = $1 and norm_version IS NULL" - updatePackages = "UPDATE package SET norm_version=$1::int[], norm_kind=$2 WHERE id = $3 and norm_version IS NULL" + // TODO (crozzy): This could describe something more interesting like >6 or 6-10 but + // at the moment that seems like overkill. + compatibleMigrationVersion = 7 + getPackageNames = "SELECT DISTINCT package_name FROM vuln WHERE updater = 'osv/go'" + getPackages = "SELECT id, version FROM package WHERE name = $1 and norm_version IS NULL" + updatePackages = "UPDATE package SET norm_version=$1::int[], norm_kind=$2 WHERE id = $3 and norm_version IS NULL" ) ctx := c.Context @@ -344,6 +347,10 @@ func updateGoPackages(c *cli.Context) error { return fmt.Errorf("error creating indexer pool: %w", err) } defer indexerPool.Close() + err = checkMigrationVersion(ctx, indexerPool, "libindex_migrations", []int{compatibleMigrationVersion}) + if err != nil { + return fmt.Errorf("error checking migration version: %w", err) + } for _, p := range packageNames { err := indexerPool.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { @@ -403,6 +410,36 @@ func updateGoPackages(c *cli.Context) error { return nil } +type ErrNonCompatibleMigrationVersion struct { + version int + acceptableVersions []int +} + +func NewErrNonCompatibleMigrationVersion(version int, acceptableVersions []int) ErrNonCompatibleMigrationVersion { + return ErrNonCompatibleMigrationVersion{version: version, acceptableVersions: acceptableVersions} +} + +func (e ErrNonCompatibleMigrationVersion) Error() string { + return fmt.Sprintf("non-compatible migration version %d (acceptable versions: %v)", e.version, e.acceptableVersions) +} + +func checkMigrationVersion(ctx context.Context, pool *pgxpool.Pool, migrationTable string, acceptableVersions []int) error { + checkMigrationVersionQuery := fmt.Sprintf("SELECT MAX(version) FROM %s", migrationTable) + return pool.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { + var version int + err := conn.QueryRow(ctx, checkMigrationVersionQuery).Scan(&version) + if err != nil { + return err + } + for _, v := range acceptableVersions { + if v == version { + return nil + } + } + return NewErrNonCompatibleMigrationVersion(version, acceptableVersions) + }) +} + func createConnPool(ctx context.Context, dsn string, maxConns int32) (*pgxpool.Pool, error) { pgcfg, err := pgxpool.ParseConfig(dsn) if err != nil {