Skip to content

Commit

Permalink
use pool for pgdump-lite, csv output for COPY (#3770)
Browse files Browse the repository at this point in the history
  • Loading branch information
mastercactapus authored Mar 26, 2024
1 parent b0e642b commit b0a10aa
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 149 deletions.
38 changes: 32 additions & 6 deletions devtools/pgdump-lite/cmd/pgdump-lite/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"strings"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/target/goalert/devtools/pgdump-lite"
"github.com/target/goalert/devtools/pgdump-lite/pgd"
)

func main() {
Expand All @@ -17,6 +19,7 @@ func main() {
db := flag.String("d", os.Getenv("DBURL"), "DB URL") // use same env var as pg_dump
dataOnly := flag.Bool("a", false, "dump only the data, not the schema")
schemaOnly := flag.Bool("s", false, "dump only the schema, no data")
parallel := flag.Bool("p", false, "dump data in parallel (note: separate tables will still be dumped in order, but not in the same transaction, so may be inconsistent between tables)")
skip := flag.String("T", "", "skip tables")
flag.Parse()

Expand All @@ -31,20 +34,39 @@ func main() {
}

ctx := context.Background()
cfg, err := pgx.ParseConfig(*db)
cfg, err := pgxpool.ParseConfig(*db)
if err != nil {
log.Fatalln("ERROR: invalid db url:", err)
}
cfg.RuntimeParams["client_encoding"] = "UTF8"
cfg.ConnConfig.RuntimeParams["client_encoding"] = "UTF8"

conn, err := pgx.ConnectConfig(ctx, cfg)
conn, err := pgxpool.NewWithConfig(ctx, cfg)
if err != nil {
log.Fatalln("ERROR: connect:", err)
}
defer conn.Close(ctx)
defer conn.Close()

dbtx := pgd.DBTX(conn)
if !*parallel {
tx, err := conn.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.Serializable,
AccessMode: pgx.ReadOnly,
DeferrableMode: pgx.Deferrable,
})
if err != nil {
log.Fatalln("ERROR: begin tx:", err)
}
defer func() {
err := tx.Commit(ctx)
if err != nil {
log.Fatalln("ERROR: commit tx:", err)
}
}()
}

var s *pgdump.Schema
if !*dataOnly {
s, err := pgdump.DumpSchema(ctx, conn)
s, err = pgdump.DumpSchema(ctx, dbtx)
if err != nil {
log.Fatalln("ERROR: dump data:", err)
}
Expand All @@ -59,7 +81,11 @@ func main() {
}

if !*schemaOnly {
err = pgdump.DumpData(ctx, conn, out, strings.Split(*skip, ","))
if *parallel {
err = pgdump.DumpDataWithSchemaParallel(ctx, conn, out, strings.Split(*skip, ","), s)
} else {
err = pgdump.DumpDataWithSchema(ctx, dbtx, out, strings.Split(*skip, ","), s)
}
if err != nil {
log.Fatalln("ERROR: dump data:", err)
}
Expand Down
231 changes: 96 additions & 135 deletions devtools/pgdump-lite/dumpdata.go
Original file line number Diff line number Diff line change
@@ -1,174 +1,135 @@
package pgdump

import (
"bufio"
"context"
"errors"
"fmt"
"io"
"sort"
"strings"
"slices"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/target/goalert/util/sqlutil"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/target/goalert/devtools/pgdump-lite/pgd"
)

func sortColumns(columns []string) {
// alphabetical, but with id first
sort.Slice(columns, func(i, j int) bool {
ci, cj := columns[i], columns[j]
if ci == cj {
return false
}
if ci == "id" {
return true
}
if cj == "id" {
return false
}

return ci < cj
})
}

func quoteNames(names []string) {
for i, n := range names {
names[i] = pgx.Identifier{n}.Sanitize()
}
// TableData groups a table name with column and row values held as strings.
// NOTE(review): nothing in the visible code populates or reads this type; the
// field descriptions below are inferred from the field names — confirm against callers.
type TableData struct {
// TableName is presumably the name of the table the data belongs to.
TableName string
// Columns is presumably the list of column names, in the same order as each row's values — TODO confirm.
Columns []string
// Rows is presumably one []string per row, with one entry per column — TODO confirm.
Rows [][]string
}

func queryStrings(ctx context.Context, tx pgx.Tx, sql string, args ...interface{}) ([]string, error) {
rows, err := tx.Query(ctx, sql, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var result []string
for rows.Next() {
var value string
err = rows.Scan(&value)
// DumpDataWithSchema writes the data from all tables in schema (except those in skipTables) to out, one table at a time. If schema is nil, it is loaded with DumpSchema first.
func DumpDataWithSchema(ctx context.Context, conn pgd.DBTX, out io.Writer, skipTables []string, schema *Schema) error {
var err error
if schema == nil {
schema, err = DumpSchema(ctx, conn)
if err != nil {
return nil, err
return fmt.Errorf("dump schema: %w", err)
}
result = append(result, value)
}

return result, nil
}

type scannable string
for _, t := range schema.Tables {
if slices.Contains(skipTables, t.Name) {
continue
}

func (s *scannable) ScanText(v pgtype.Text) error {
if !v.Valid {
*s = "\\N"
} else {
*s = scannable(strings.ReplaceAll(v.String, "\\", "\\\\"))
err = dumpTableDataWith(ctx, conn, out, t.Name)
if err != nil {
return fmt.Errorf("dump table data: %w", err)
}
}

return nil
}

func contains(s []string, e string) bool {
for _, a := range s {
if a == e {
return true
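// DumpDataWithSchemaParallel writes the data from all tables in schema (except those in skipTables) to out, dumping each table on its own connection acquired from the pool while still writing the tables to out in schema order. If schema is nil, it is loaded with DumpSchema first.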
func DumpDataWithSchemaParallel(ctx context.Context, conn *pgxpool.Pool, out io.Writer, skipTables []string, schema *Schema) error {
var err error
if schema == nil {
schema, err = DumpSchema(ctx, conn)
if err != nil {
return fmt.Errorf("dump schema: %w", err)
}
}
return false
}

func DumpData(ctx context.Context, conn *pgx.Conn, out io.Writer, skip []string) error {
tx, err := conn.BeginTx(ctx, pgx.TxOptions{
IsoLevel: pgx.Serializable,
DeferrableMode: pgx.Deferrable,
AccessMode: pgx.ReadOnly,
})
if err != nil {
return fmt.Errorf("begin tx: %w", err)
type streamW struct {
name string
pw *io.PipeWriter
}
defer sqlutil.RollbackContext(ctx, "pgdump-lite: dump data", tx)

tables, err := queryStrings(ctx, tx, "select table_name from information_schema.tables where table_schema = 'public'")
if err != nil {
return fmt.Errorf("read tables: %w", err)
}
sort.Strings(tables)

for _, table := range tables {
if contains(skip, table) {
streams := make(chan io.Reader, len(schema.Tables))
inputs := make(chan streamW, len(schema.Tables))
for _, t := range schema.Tables {
if slices.Contains(skipTables, t.Name) {
continue
}
pr, pw := io.Pipe()
streams <- pr
inputs <- streamW{name: t.Name, pw: pw}

go func() {
err := conn.AcquireFunc(ctx, func(conn *pgxpool.Conn) error {
s := <-inputs
w := bufio.NewWriterSize(s.pw, 65535)
err := dumpTableData(ctx, conn.Conn(), w, s.name)
if err != nil {
return s.pw.CloseWithError(fmt.Errorf("dump table data: %w", err))
}
defer s.pw.Close()
return w.Flush()
})
if err != nil {
panic(err)
}
}()
}
close(streams)

columns, err := queryStrings(ctx, tx, "select column_name from information_schema.columns where table_schema = 'public' and table_name = $1 order by ordinal_position", table)
for r := range streams {
_, err = io.Copy(out, r)
if err != nil {
return fmt.Errorf("read columns for '%s': %w", table, err)
}
quoteNames(columns)

primaryCols, err := queryStrings(ctx, tx, `
select col.column_name
from information_schema.table_constraints tbl
join information_schema.constraint_column_usage col on
col.table_schema = 'public' and
col.constraint_name = tbl.constraint_name
where
tbl.table_schema = 'public' and
tbl.table_name = $1 and
constraint_type = 'PRIMARY KEY'
`, table)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf("read primary key for '%s': %w", table, err)
return fmt.Errorf("write table data: %w", err)
}
sortColumns(primaryCols)
quoteNames(primaryCols)
}

colNames := strings.Join(columns, ", ")
orderBy := strings.Join(primaryCols, ",")
if orderBy == "" {
orderBy = colNames
}
return nil
}

fmt.Fprintf(out, "COPY %s (%s) FROM stdin;\n", pgx.Identifier{table}.Sanitize(), colNames)
rows, err := tx.Query(ctx,
fmt.Sprintf("select %s from %s order by %s",
colNames,
table,
orderBy,
),
pgx.QueryExecModeSimpleProtocol,
)
if err != nil {
return fmt.Errorf("read data on '%s': %w", table, err)
}
defer rows.Close()
vals := make([]interface{}, len(columns))
// dumpTableDataWith writes the contents of tableName to out using whichever
// connection backs db.
//
// dumpTableData requires a *pgx.Conn, so the concrete type behind the DBTX is
// unwrapped here:
//   - *pgx.Conn is used directly.
//   - pgx.Tx uses the connection the transaction runs on, so the copy happens
//     inside that transaction.
//   - *pgxpool.Pool acquires a connection for the duration of the copy.
//
// Any other DBTX implementation results in an "unsupported DBTX type" error.
func dumpTableDataWith(ctx context.Context, db pgd.DBTX, out io.Writer, tableName string) error {
	switch db := db.(type) {
	case *pgx.Conn:
		return dumpTableData(ctx, db, out, tableName)
	case pgx.Tx:
		// Reuse the transaction's own connection rather than acquiring a new
		// one, otherwise the copy would run outside the transaction.
		return dumpTableData(ctx, db.Conn(), out, tableName)
	case *pgxpool.Pool:
		return db.AcquireFunc(ctx, func(conn *pgxpool.Conn) error {
			return dumpTableData(ctx, conn.Conn(), out, tableName)
		})
	default:
		return fmt.Errorf("unsupported DBTX type: %T", db)
	}
}

for i := range vals {
vals[i] = new(scannable)
}
for rows.Next() {
err = rows.Scan(vals...)
if err != nil {
return fmt.Errorf("read data on '%s': %w", table, err)
}
for i, v := range vals {
if i > 0 {
if _, err := io.WriteString(out, "\t"); err != nil {
return err
}
}
if _, err := io.WriteString(out, string(*v.(*scannable))); err != nil {
return err
}
}
if _, err := io.WriteString(out, "\n"); err != nil {
return err
}
}
rows.Close()
// dumpTableData writes a single table to out as a restorable COPY block:
// a "COPY ... FROM stdin" statement in CSV format, followed by the CSV output
// (including a header row) that the server produces for the table, followed
// by the "\." terminator line and a blank line.
//
// Any write or copy failure is returned wrapped with the step that failed.
func dumpTableData(ctx context.Context, conn *pgx.Conn, out io.Writer, tableName string) error {
	// Quote the identifier once; it is used in both the emitted statement and
	// the query sent to the server.
	quoted := pgx.Identifier{tableName}.Sanitize()

	if _, err := fmt.Fprintf(out, "COPY %s FROM stdin WITH (FORMAT csv, HEADER MATCH, ENCODING utf8);\n", quoted); err != nil {
		return fmt.Errorf("write COPY statement: %w", err)
	}

	copyOut := fmt.Sprintf("COPY %s TO STDOUT WITH (FORMAT csv, HEADER true, ENCODING utf8)", quoted)
	if _, err := conn.PgConn().CopyTo(ctx, out, copyOut); err != nil {
		return fmt.Errorf("copy data: %w", err)
	}

	if _, err := fmt.Fprint(out, "\\.\n\n"); err != nil {
		return fmt.Errorf("write end of COPY: %w", err)
	}

	return nil
}

fmt.Fprintf(out, "\\.\n\n")
func DumpData(ctx context.Context, conn pgd.DBTX, out io.Writer, skipTables []string) error {
schema, err := DumpSchema(ctx, conn)
if err != nil {
return err
}

return tx.Commit(ctx)
return DumpDataWithSchema(ctx, conn, out, skipTables, schema)
}
Loading

0 comments on commit b0a10aa

Please sign in to comment.