From b0a10aa5a58ead71113b1e0e2db186de7cc81450 Mon Sep 17 00:00:00 2001 From: Nathaniel Caza Date: Tue, 26 Mar 2024 10:07:10 -0500 Subject: [PATCH] use pool for pgdump-lite, csv output for COPY (#3770) --- devtools/pgdump-lite/cmd/pgdump-lite/main.go | 38 ++- devtools/pgdump-lite/dumpdata.go | 231 ++++++++----------- devtools/pgdump-lite/dumpschema.go | 18 +- devtools/pgdump-lite/pgd/db.go | 2 +- devtools/pgdump-lite/pgd/models.go | 2 +- devtools/pgdump-lite/pgd/queries.sql.go | 2 +- 6 files changed, 144 insertions(+), 149 deletions(-) diff --git a/devtools/pgdump-lite/cmd/pgdump-lite/main.go b/devtools/pgdump-lite/cmd/pgdump-lite/main.go index 1ea6e3f334..62303567c1 100644 --- a/devtools/pgdump-lite/cmd/pgdump-lite/main.go +++ b/devtools/pgdump-lite/cmd/pgdump-lite/main.go @@ -8,7 +8,9 @@ import ( "strings" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" "github.com/target/goalert/devtools/pgdump-lite" + "github.com/target/goalert/devtools/pgdump-lite/pgd" ) func main() { @@ -17,6 +19,7 @@ func main() { db := flag.String("d", os.Getenv("DBURL"), "DB URL") // use same env var as pg_dump dataOnly := flag.Bool("a", false, "dump only the data, not the schema") schemaOnly := flag.Bool("s", false, "dump only the schema, no data") + parallel := flag.Bool("p", false, "dump data in parallel (note: separate tables will still be dumped in order, but not in the same transaction, so may be inconsistent between tables)") skip := flag.String("T", "", "skip tables") flag.Parse() @@ -31,20 +34,39 @@ func main() { } ctx := context.Background() - cfg, err := pgx.ParseConfig(*db) + cfg, err := pgxpool.ParseConfig(*db) if err != nil { log.Fatalln("ERROR: invalid db url:", err) } - cfg.RuntimeParams["client_encoding"] = "UTF8" + cfg.ConnConfig.RuntimeParams["client_encoding"] = "UTF8" - conn, err := pgx.ConnectConfig(ctx, cfg) + conn, err := pgxpool.NewWithConfig(ctx, cfg) if err != nil { log.Fatalln("ERROR: connect:", err) } - defer conn.Close(ctx) + defer conn.Close() + 
dbtx := pgd.DBTX(conn) + if !*parallel { + tx, err := conn.BeginTx(ctx, pgx.TxOptions{ + IsoLevel: pgx.Serializable, + AccessMode: pgx.ReadOnly, + DeferrableMode: pgx.Deferrable, + }) + if err != nil { + log.Fatalln("ERROR: begin tx:", err) + } + defer func() { + err := tx.Commit(ctx) + if err != nil { + log.Fatalln("ERROR: commit tx:", err) + } + }() + } + + var s *pgdump.Schema if !*dataOnly { - s, err := pgdump.DumpSchema(ctx, conn) + s, err = pgdump.DumpSchema(ctx, dbtx) if err != nil { log.Fatalln("ERROR: dump data:", err) } @@ -59,7 +81,11 @@ func main() { } if !*schemaOnly { - err = pgdump.DumpData(ctx, conn, out, strings.Split(*skip, ",")) + if *parallel { + err = pgdump.DumpDataWithSchemaParallel(ctx, conn, out, strings.Split(*skip, ","), s) + } else { + err = pgdump.DumpDataWithSchema(ctx, dbtx, out, strings.Split(*skip, ","), s) + } if err != nil { log.Fatalln("ERROR: dump data:", err) } diff --git a/devtools/pgdump-lite/dumpdata.go b/devtools/pgdump-lite/dumpdata.go index 0094b4d891..c81fd01492 100644 --- a/devtools/pgdump-lite/dumpdata.go +++ b/devtools/pgdump-lite/dumpdata.go @@ -1,174 +1,135 @@ package pgdump import ( + "bufio" "context" - "errors" "fmt" "io" - "sort" - "strings" + "slices" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgtype" - "github.com/target/goalert/util/sqlutil" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/target/goalert/devtools/pgdump-lite/pgd" ) -func sortColumns(columns []string) { - // alphabetical, but with id first - sort.Slice(columns, func(i, j int) bool { - ci, cj := columns[i], columns[j] - if ci == cj { - return false - } - if ci == "id" { - return true - } - if cj == "id" { - return false - } - - return ci < cj - }) -} - -func quoteNames(names []string) { - for i, n := range names { - names[i] = pgx.Identifier{n}.Sanitize() - } +type TableData struct { + TableName string + Columns []string + Rows [][]string } -func queryStrings(ctx context.Context, tx pgx.Tx, sql string, args ...interface{}) ([]string, 
error) { - rows, err := tx.Query(ctx, sql, args...) - if err != nil { - return nil, err - } - defer rows.Close() - var result []string - for rows.Next() { - var value string - err = rows.Scan(&value) +// DumpDataWithSchema will return all data from all tables (except those in skipTables) in a structured format. +func DumpDataWithSchema(ctx context.Context, conn pgd.DBTX, out io.Writer, skipTables []string, schema *Schema) error { + var err error + if schema == nil { + schema, err = DumpSchema(ctx, conn) if err != nil { - return nil, err + return fmt.Errorf("dump schema: %w", err) } - result = append(result, value) } - return result, nil -} - -type scannable string + for _, t := range schema.Tables { + if slices.Contains(skipTables, t.Name) { + continue + } -func (s *scannable) ScanText(v pgtype.Text) error { - if !v.Valid { - *s = "\\N" - } else { - *s = scannable(strings.ReplaceAll(v.String, "\\", "\\\\")) + err = dumpTableDataWith(ctx, conn, out, t.Name) + if err != nil { + return fmt.Errorf("dump table data: %w", err) + } } return nil } -func contains(s []string, e string) bool { - for _, a := range s { - if a == e { - return true +// DumpDataWithSchemaParallel will return all data from all tables (except those in skipTables) in a structured format; tables are dumped in order, but on separate connections outside a single transaction, so data may be inconsistent between tables. 
+func DumpDataWithSchemaParallel(ctx context.Context, conn *pgxpool.Pool, out io.Writer, skipTables []string, schema *Schema) error { + var err error + if schema == nil { + schema, err = DumpSchema(ctx, conn) + if err != nil { + return fmt.Errorf("dump schema: %w", err) } } - return false -} -func DumpData(ctx context.Context, conn *pgx.Conn, out io.Writer, skip []string) error { - tx, err := conn.BeginTx(ctx, pgx.TxOptions{ - IsoLevel: pgx.Serializable, - DeferrableMode: pgx.Deferrable, - AccessMode: pgx.ReadOnly, - }) - if err != nil { - return fmt.Errorf("begin tx: %w", err) + type streamW struct { + name string + pw *io.PipeWriter } - defer sqlutil.RollbackContext(ctx, "pgdump-lite: dump data", tx) - tables, err := queryStrings(ctx, tx, "select table_name from information_schema.tables where table_schema = 'public'") - if err != nil { - return fmt.Errorf("read tables: %w", err) - } - sort.Strings(tables) - - for _, table := range tables { - if contains(skip, table) { + streams := make(chan io.Reader, len(schema.Tables)) + inputs := make(chan streamW, len(schema.Tables)) + for _, t := range schema.Tables { + if slices.Contains(skipTables, t.Name) { continue } + pr, pw := io.Pipe() + streams <- pr + inputs <- streamW{name: t.Name, pw: pw} + + go func() { + err := conn.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { + s := <-inputs + w := bufio.NewWriterSize(s.pw, 65535) + err := dumpTableData(ctx, conn.Conn(), w, s.name) + if err != nil { + return s.pw.CloseWithError(fmt.Errorf("dump table data: %w", err)) + } + defer s.pw.Close() + return w.Flush() + }) + if err != nil { + panic(err) + } + }() + } + close(streams) - columns, err := queryStrings(ctx, tx, "select column_name from information_schema.columns where table_schema = 'public' and table_name = $1 order by ordinal_position", table) + for r := range streams { + _, err = io.Copy(out, r) if err != nil { - return fmt.Errorf("read columns for '%s': %w", table, err) - } - quoteNames(columns) - - primaryCols, 
err := queryStrings(ctx, tx, ` - select col.column_name - from information_schema.table_constraints tbl - join information_schema.constraint_column_usage col on - col.table_schema = 'public' and - col.constraint_name = tbl.constraint_name - where - tbl.table_schema = 'public' and - tbl.table_name = $1 and - constraint_type = 'PRIMARY KEY' - `, table) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("read primary key for '%s': %w", table, err) + return fmt.Errorf("write table data: %w", err) } - sortColumns(primaryCols) - quoteNames(primaryCols) + } - colNames := strings.Join(columns, ", ") - orderBy := strings.Join(primaryCols, ",") - if orderBy == "" { - orderBy = colNames - } + return nil +} - fmt.Fprintf(out, "COPY %s (%s) FROM stdin;\n", pgx.Identifier{table}.Sanitize(), colNames) - rows, err := tx.Query(ctx, - fmt.Sprintf("select %s from %s order by %s", - colNames, - table, - orderBy, - ), - pgx.QueryExecModeSimpleProtocol, - ) - if err != nil { - return fmt.Errorf("read data on '%s': %w", table, err) - } - defer rows.Close() - vals := make([]interface{}, len(columns)) +func dumpTableDataWith(ctx context.Context, db pgd.DBTX, out io.Writer, tableName string) error { + switch db := db.(type) { + case *pgx.Conn: + return dumpTableData(ctx, db, out, tableName) + case *pgxpool.Pool: + return db.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { + return dumpTableData(ctx, conn.Conn(), out, tableName) + }) + default: + return fmt.Errorf("unsupported DBTX type: %T", db) + } +} - for i := range vals { - vals[i] = new(scannable) - } - for rows.Next() { - err = rows.Scan(vals...) 
- if err != nil { - return fmt.Errorf("read data on '%s': %w", table, err) - } - for i, v := range vals { - if i > 0 { - if _, err := io.WriteString(out, "\t"); err != nil { - return err - } - } - if _, err := io.WriteString(out, string(*v.(*scannable))); err != nil { - return err - } - } - if _, err := io.WriteString(out, "\n"); err != nil { - return err - } - } - rows.Close() +func dumpTableData(ctx context.Context, conn *pgx.Conn, out io.Writer, tableName string) error { + _, err := fmt.Fprintf(out, "COPY %s FROM stdin WITH (FORMAT csv, HEADER MATCH, ENCODING utf8);\n", pgx.Identifier{tableName}.Sanitize()) + if err != nil { + return fmt.Errorf("write COPY statement: %w", err) + } + _, err = conn.PgConn().CopyTo(ctx, out, fmt.Sprintf("COPY %s TO STDOUT WITH (FORMAT csv, HEADER true, ENCODING utf8)", pgx.Identifier{tableName}.Sanitize())) + if err != nil { + return fmt.Errorf("copy data: %w", err) + } + _, err = fmt.Fprintf(out, "\\.\n\n") + if err != nil { + return fmt.Errorf("write end of COPY: %w", err) + } + + return nil +} - fmt.Fprintf(out, "\\.\n\n") +func DumpData(ctx context.Context, conn pgd.DBTX, out io.Writer, skipTables []string) error { + schema, err := DumpSchema(ctx, conn) + if err != nil { + return err } - return tx.Commit(ctx) + return DumpDataWithSchema(ctx, conn, out, skipTables, schema) } diff --git a/devtools/pgdump-lite/dumpschema.go b/devtools/pgdump-lite/dumpschema.go index df75c03c3f..14b078e00b 100644 --- a/devtools/pgdump-lite/dumpschema.go +++ b/devtools/pgdump-lite/dumpschema.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" - "github.com/jackc/pgx/v5" "github.com/target/goalert/devtools/pgdump-lite/pgd" ) @@ -57,14 +56,16 @@ type Index struct { Def string } -func (idx Index) String() string { return idx.Def + ";" } +func (idx Index) EntityName() string { return idx.Name } +func (idx Index) String() string { return idx.Def + ";" } type Trigger struct { Name string Def string } -func (t Trigger) String() string { return t.Def + ";" } +func 
(t Trigger) EntityName() string { return t.Name } +func (t Trigger) String() string { return t.Def + ";" } type Sequence struct { Name string @@ -77,6 +78,7 @@ type Sequence struct { OwnedBy string } +func (s Sequence) EntityName() string { return s.Name } func (s Sequence) String() string { def := fmt.Sprintf("CREATE SEQUENCE %s\n\tSTART WITH %d\n\tINCREMENT BY %d\n\tMINVALUE %d\n\tMAXVALUE %d\n\tCACHE %d", s.Name, s.StartValue, s.Increment, s.MinValue, s.MaxValue, s.Cache) @@ -95,6 +97,7 @@ type Extension struct { Name string } +func (e Extension) EntityName() string { return e.Name } func (e Extension) String() string { return fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", e.Name) } @@ -104,13 +107,15 @@ type Function struct { Def string } -func (f Function) String() string { return f.Def + ";" } +func (f Function) EntityName() string { return f.Name } +func (f Function) String() string { return f.Def + ";" } type Enum struct { Name string Values []string } +func (e Enum) EntityName() string { return e.Name } func (e Enum) String() string { return fmt.Sprintf("CREATE TYPE %s AS ENUM (\n\t'%s'\n);", e.Name, strings.Join(e.Values, "',\n\t'")) } @@ -125,6 +130,7 @@ type Table struct { Sequences []Sequence } +func (t Table) EntityName() string { return t.Name } func (t Table) String() string { var lines []string for _, c := range t.Columns { @@ -160,6 +166,7 @@ type Constraint struct { Def string } +func (c Constraint) EntityName() string { return c.Name } func (c Constraint) String() string { return fmt.Sprintf("CONSTRAINT %s %s", c.Name, c.Def) } @@ -171,6 +178,7 @@ type Column struct { DefaultValue string } +func (c Column) EntityName() string { return c.Name } func (c Column) String() string { var def string if c.DefaultValue != "" { @@ -182,7 +190,7 @@ func (c Column) String() string { return fmt.Sprintf("%s %s%s", c.Name, c.Type, def) } -func DumpSchema(ctx context.Context, conn *pgx.Conn) (*Schema, error) { +func DumpSchema(ctx context.Context, conn 
pgd.DBTX) (*Schema, error) { db := pgd.New(conn) var s Schema diff --git a/devtools/pgdump-lite/pgd/db.go b/devtools/pgdump-lite/pgd/db.go index f966bc9545..037305d722 100644 --- a/devtools/pgdump-lite/pgd/db.go +++ b/devtools/pgdump-lite/pgd/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.25.0 package pgd diff --git a/devtools/pgdump-lite/pgd/models.go b/devtools/pgdump-lite/pgd/models.go index 0c17183996..7ee37c351a 100644 --- a/devtools/pgdump-lite/pgd/models.go +++ b/devtools/pgdump-lite/pgd/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.25.0 package pgd diff --git a/devtools/pgdump-lite/pgd/queries.sql.go b/devtools/pgdump-lite/pgd/queries.sql.go index 207811ca27..ba5c2a9851 100644 --- a/devtools/pgdump-lite/pgd/queries.sql.go +++ b/devtools/pgdump-lite/pgd/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.21.0 +// sqlc v1.25.0 // source: queries.sql package pgd