-
Notifications
You must be signed in to change notification settings - Fork 253
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use pool for pgdump-lite, csv output for COPY (#3770)
- Loading branch information
1 parent
b0e642b
commit b0a10aa
Showing
6 changed files
with
144 additions
and
149 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,174 +1,135 @@ | ||
package pgdump | ||
|
||
import ( | ||
"bufio" | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"sort" | ||
"strings" | ||
"slices" | ||
|
||
"github.com/jackc/pgx/v5" | ||
"github.com/jackc/pgx/v5/pgtype" | ||
"github.com/target/goalert/util/sqlutil" | ||
"github.com/jackc/pgx/v5/pgxpool" | ||
"github.com/target/goalert/devtools/pgdump-lite/pgd" | ||
) | ||
|
||
// sortColumns sorts columns in place in alphabetical (byte-wise) order, with
// the one exception that a column named exactly "id" always sorts first.
func sortColumns(columns []string) {
	sort.Slice(columns, func(i, j int) bool {
		ci, cj := columns[i], columns[j]
		switch {
		case ci == cj:
			// Equal names are never "less" than each other; this also keeps
			// the comparator consistent when "id" appears more than once.
			return false
		case ci == "id":
			return true
		case cj == "id":
			return false
		default:
			return ci < cj
		}
	})
}
|
||
func quoteNames(names []string) { | ||
for i, n := range names { | ||
names[i] = pgx.Identifier{n}.Sanitize() | ||
} | ||
type TableData struct { | ||
TableName string | ||
Columns []string | ||
Rows [][]string | ||
} | ||
|
||
func queryStrings(ctx context.Context, tx pgx.Tx, sql string, args ...interface{}) ([]string, error) { | ||
rows, err := tx.Query(ctx, sql, args...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer rows.Close() | ||
var result []string | ||
for rows.Next() { | ||
var value string | ||
err = rows.Scan(&value) | ||
// DumpDataWithSchema writes all data from all tables (except those in skipTables) to out as COPY statements. If schema is nil, it is loaded with DumpSchema first. | ||
func DumpDataWithSchema(ctx context.Context, conn pgd.DBTX, out io.Writer, skipTables []string, schema *Schema) error { | ||
var err error | ||
if schema == nil { | ||
schema, err = DumpSchema(ctx, conn) | ||
if err != nil { | ||
return nil, err | ||
return fmt.Errorf("dump schema: %w", err) | ||
} | ||
result = append(result, value) | ||
} | ||
|
||
return result, nil | ||
} | ||
|
||
type scannable string | ||
for _, t := range schema.Tables { | ||
if slices.Contains(skipTables, t.Name) { | ||
continue | ||
} | ||
|
||
func (s *scannable) ScanText(v pgtype.Text) error { | ||
if !v.Valid { | ||
*s = "\\N" | ||
} else { | ||
*s = scannable(strings.ReplaceAll(v.String, "\\", "\\\\")) | ||
err = dumpTableDataWith(ctx, conn, out, t.Name) | ||
if err != nil { | ||
return fmt.Errorf("dump table data: %w", err) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func contains(s []string, e string) bool { | ||
for _, a := range s { | ||
if a == e { | ||
return true | ||
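// DumpDataWithSchemaParallel writes all data from all tables (except those in skipTables) to out as COPY statements, dumping each table in its own goroutine on a connection acquired from the pool. If schema is nil, it is loaded with DumpSchema first. | ||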
func DumpDataWithSchemaParallel(ctx context.Context, conn *pgxpool.Pool, out io.Writer, skipTables []string, schema *Schema) error { | ||
var err error | ||
if schema == nil { | ||
schema, err = DumpSchema(ctx, conn) | ||
if err != nil { | ||
return fmt.Errorf("dump schema: %w", err) | ||
} | ||
} | ||
return false | ||
} | ||
|
||
func DumpData(ctx context.Context, conn *pgx.Conn, out io.Writer, skip []string) error { | ||
tx, err := conn.BeginTx(ctx, pgx.TxOptions{ | ||
IsoLevel: pgx.Serializable, | ||
DeferrableMode: pgx.Deferrable, | ||
AccessMode: pgx.ReadOnly, | ||
}) | ||
if err != nil { | ||
return fmt.Errorf("begin tx: %w", err) | ||
type streamW struct { | ||
name string | ||
pw *io.PipeWriter | ||
} | ||
defer sqlutil.RollbackContext(ctx, "pgdump-lite: dump data", tx) | ||
|
||
tables, err := queryStrings(ctx, tx, "select table_name from information_schema.tables where table_schema = 'public'") | ||
if err != nil { | ||
return fmt.Errorf("read tables: %w", err) | ||
} | ||
sort.Strings(tables) | ||
|
||
for _, table := range tables { | ||
if contains(skip, table) { | ||
streams := make(chan io.Reader, len(schema.Tables)) | ||
inputs := make(chan streamW, len(schema.Tables)) | ||
for _, t := range schema.Tables { | ||
if slices.Contains(skipTables, t.Name) { | ||
continue | ||
} | ||
pr, pw := io.Pipe() | ||
streams <- pr | ||
inputs <- streamW{name: t.Name, pw: pw} | ||
|
||
go func() { | ||
err := conn.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { | ||
s := <-inputs | ||
w := bufio.NewWriterSize(s.pw, 65535) | ||
err := dumpTableData(ctx, conn.Conn(), w, s.name) | ||
if err != nil { | ||
return s.pw.CloseWithError(fmt.Errorf("dump table data: %w", err)) | ||
} | ||
defer s.pw.Close() | ||
return w.Flush() | ||
}) | ||
if err != nil { | ||
panic(err) | ||
} | ||
}() | ||
} | ||
close(streams) | ||
|
||
columns, err := queryStrings(ctx, tx, "select column_name from information_schema.columns where table_schema = 'public' and table_name = $1 order by ordinal_position", table) | ||
for r := range streams { | ||
_, err = io.Copy(out, r) | ||
if err != nil { | ||
return fmt.Errorf("read columns for '%s': %w", table, err) | ||
} | ||
quoteNames(columns) | ||
|
||
primaryCols, err := queryStrings(ctx, tx, ` | ||
select col.column_name | ||
from information_schema.table_constraints tbl | ||
join information_schema.constraint_column_usage col on | ||
col.table_schema = 'public' and | ||
col.constraint_name = tbl.constraint_name | ||
where | ||
tbl.table_schema = 'public' and | ||
tbl.table_name = $1 and | ||
constraint_type = 'PRIMARY KEY' | ||
`, table) | ||
if err != nil && !errors.Is(err, pgx.ErrNoRows) { | ||
return fmt.Errorf("read primary key for '%s': %w", table, err) | ||
return fmt.Errorf("write table data: %w", err) | ||
} | ||
sortColumns(primaryCols) | ||
quoteNames(primaryCols) | ||
} | ||
|
||
colNames := strings.Join(columns, ", ") | ||
orderBy := strings.Join(primaryCols, ",") | ||
if orderBy == "" { | ||
orderBy = colNames | ||
} | ||
return nil | ||
} | ||
|
||
fmt.Fprintf(out, "COPY %s (%s) FROM stdin;\n", pgx.Identifier{table}.Sanitize(), colNames) | ||
rows, err := tx.Query(ctx, | ||
fmt.Sprintf("select %s from %s order by %s", | ||
colNames, | ||
table, | ||
orderBy, | ||
), | ||
pgx.QueryExecModeSimpleProtocol, | ||
) | ||
if err != nil { | ||
return fmt.Errorf("read data on '%s': %w", table, err) | ||
} | ||
defer rows.Close() | ||
vals := make([]interface{}, len(columns)) | ||
func dumpTableDataWith(ctx context.Context, db pgd.DBTX, out io.Writer, tableName string) error { | ||
switch db := db.(type) { | ||
case *pgx.Conn: | ||
return dumpTableData(ctx, db, out, tableName) | ||
case *pgxpool.Pool: | ||
return db.AcquireFunc(ctx, func(conn *pgxpool.Conn) error { | ||
return dumpTableData(ctx, conn.Conn(), out, tableName) | ||
}) | ||
default: | ||
return fmt.Errorf("unsupported DBTX type: %T", db) | ||
} | ||
} | ||
|
||
for i := range vals { | ||
vals[i] = new(scannable) | ||
} | ||
for rows.Next() { | ||
err = rows.Scan(vals...) | ||
if err != nil { | ||
return fmt.Errorf("read data on '%s': %w", table, err) | ||
} | ||
for i, v := range vals { | ||
if i > 0 { | ||
if _, err := io.WriteString(out, "\t"); err != nil { | ||
return err | ||
} | ||
} | ||
if _, err := io.WriteString(out, string(*v.(*scannable))); err != nil { | ||
return err | ||
} | ||
} | ||
if _, err := io.WriteString(out, "\n"); err != nil { | ||
return err | ||
} | ||
} | ||
rows.Close() | ||
func dumpTableData(ctx context.Context, conn *pgx.Conn, out io.Writer, tableName string) error { | ||
_, err := fmt.Fprintf(out, "COPY %s FROM stdin WITH (FORMAT csv, HEADER MATCH, ENCODING utf8);\n", pgx.Identifier{tableName}.Sanitize()) | ||
if err != nil { | ||
return fmt.Errorf("write COPY statement: %w", err) | ||
} | ||
_, err = conn.PgConn().CopyTo(ctx, out, fmt.Sprintf("COPY %s TO STDOUT WITH (FORMAT csv, HEADER true, ENCODING utf8)", pgx.Identifier{tableName}.Sanitize())) | ||
if err != nil { | ||
return fmt.Errorf("copy data: %w", err) | ||
} | ||
_, err = fmt.Fprintf(out, "\\.\n\n") | ||
if err != nil { | ||
return fmt.Errorf("write end of COPY: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
fmt.Fprintf(out, "\\.\n\n") | ||
func DumpData(ctx context.Context, conn pgd.DBTX, out io.Writer, skipTables []string) error { | ||
schema, err := DumpSchema(ctx, conn) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return tx.Commit(ctx) | ||
return DumpDataWithSchema(ctx, conn, out, skipTables, schema) | ||
} |
Oops, something went wrong.