From a2b928eaeb65f19ce78c150cc8fa511dccfe81fd Mon Sep 17 00:00:00 2001 From: Grant Zvolsky Date: Thu, 12 May 2022 12:29:17 +0200 Subject: [PATCH] feat: replace hydra's transaction impl with ory/popx/transaction --- hsm/manager_hsm_test.go | 3 +- persistence/sql/persister.go | 30 ++-- persistence/sql/transaction/transaction.go | 70 ---------- .../sql/transaction/transaction_test.go | 128 ------------------ 4 files changed, 16 insertions(+), 215 deletions(-) delete mode 100644 persistence/sql/transaction/transaction.go delete mode 100644 persistence/sql/transaction/transaction_test.go diff --git a/hsm/manager_hsm_test.go b/hsm/manager_hsm_test.go index ab7d7ec859c..0a80d21a714 100644 --- a/hsm/manager_hsm_test.go +++ b/hsm/manager_hsm_test.go @@ -14,6 +14,7 @@ import ( "testing" "github.com/ory/hydra/jwk" + "github.com/ory/hydra/x/contextx" "github.com/ory/hydra/driver" "github.com/ory/hydra/driver/config" @@ -47,7 +48,7 @@ func TestDefaultKeyManager_HsmEnabled(t *testing.T) { reg.WithLogger(l) reg.WithConfig(c) reg.WithHsmContext(mockHsmContext) - err := reg.Init(context.Background()) + err := reg.Init(context.Background(), false, true, &contextx.TestContextualizer{}) assert.NoError(t, err) assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager()) assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager()) diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index 1fb543ad771..84e30a6c0ec 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -6,9 +6,9 @@ import ( "reflect" "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/x/randx" "github.com/gofrs/uuid" - "github.com/gobuffalo/x/randx" "github.com/pkg/errors" "github.com/ory/fosite" @@ -16,7 +16,6 @@ import ( "github.com/ory/hydra/driver/config" "github.com/ory/hydra/jwk" "github.com/ory/hydra/persistence" - "github.com/ory/hydra/persistence/sql/transaction" "github.com/ory/hydra/x" "github.com/ory/hydra/x/contextx" "github.com/ory/x/errorsx" @@ -54,8 +53,8 @@ type ( ) func (p *Persister) BeginTX(ctx context.Context) (context.Context, error) { - _, ok := ctx.Value(transaction.TransactionContextKey).(*pop.Connection) - if ok { + fallback := &pop.Connection{TX: &pop.Tx{}} + if popx.GetConnection(ctx, fallback).TX != fallback.TX { return ctx, errorsx.WithStack(ErrTransactionOpen) } @@ -69,25 +68,27 @@ func (p *Persister) BeginTX(ctx context.Context) (context.Context, error) { ID: randx.String(30), Dialect: p.conn.Dialect, } - return context.WithValue(ctx, transaction.TransactionContextKey, c), err + return popx.WithTransaction(ctx, c), err } func (p *Persister) Commit(ctx context.Context) error { - c, ok := ctx.Value(transaction.TransactionContextKey).(*pop.Connection) - if !ok || c.TX == nil { + fallback := &pop.Connection{TX: &pop.Tx{}} + tx := popx.GetConnection(ctx, fallback) + if tx.TX == fallback.TX || tx.TX == nil { return errorsx.WithStack(ErrNoTransactionOpen) } - return errorsx.WithStack(c.TX.Commit()) + return errorsx.WithStack(tx.TX.Commit()) } func (p *Persister) Rollback(ctx context.Context) error { - c, ok := ctx.Value(transaction.TransactionContextKey).(*pop.Connection) - if !ok || c.TX == nil { + fallback := &pop.Connection{TX: &pop.Tx{}} + tx := popx.GetConnection(ctx, fallback) + if tx.TX == fallback.TX || tx.TX == nil { return errorsx.WithStack(ErrNoTransactionOpen) } - return errorsx.WithStack(c.TX.Rollback()) + return errorsx.WithStack(tx.TX.Rollback()) } func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config *config.Provider, l *logrusx.Logger) (*Persister, error) { @@ -142,10 +143,7 @@ func (p *Persister) QueryWithNetwork(ctx context.Context) *pop.Query { } func (p *Persister) Connection(ctx context.Context) *pop.Connection { - if c, ok := ctx.Value(transaction.TransactionContextKey).(*pop.Connection); ok { - return c.WithContext(ctx) - } - return p.conn.WithContext(ctx) + return popx.GetConnection(ctx, p.conn) } func (p *Persister) mustSetNetwork(nid uuid.UUID, v interface{}) interface{} { @@ -163,5 +161,5 @@ func (p *Persister) mustSetNetwork(nid uuid.UUID, v interface{}) interface{} { } func (p *Persister) transaction(ctx context.Context, f func(ctx context.Context, c *pop.Connection) error) error { - return transaction.Transaction(ctx, p.conn, f) + return popx.Transaction(ctx, p.conn, f) } diff --git a/persistence/sql/transaction/transaction.go b/persistence/sql/transaction/transaction.go deleted file mode 100644 index a0688946d3a..00000000000 --- a/persistence/sql/transaction/transaction.go +++ /dev/null @@ -1,70 +0,0 @@ -package transaction - -import ( - "context" - - "github.com/cockroachdb/cockroach-go/v2/crdb" - "github.com/gobuffalo/pop/v6" - "github.com/jmoiron/sqlx" - "github.com/ory/x/errorsx" -) - -type transactionContextType string - -const TransactionContextKey transactionContextType = "transactionConnection" - -func Transaction(ctx context.Context, conn *pop.Connection, f func(ctx context.Context, c *pop.Connection) error) error { - isNested := true - c, ok := ctx.Value(TransactionContextKey).(*pop.Connection) - if !ok { - isNested = false - - var err error - c, err = conn.WithContext(ctx).NewTransaction() - - if err != nil { - return errorsx.WithStack(err) - } - } - - if !isNested && c.Dialect.Name() == "cockroach" { // Only retry the outer closure of cockroach transactions - return crdb.ExecuteInTx(ctx, sqlxTxAdapter{c.TX.Tx}, func() error { - return f(context.WithValue(ctx, TransactionContextKey, c), c) - }) - } else { - if err := f(context.WithValue(ctx, TransactionContextKey, c), c); err != nil { - if !isNested { - if err := c.TX.Rollback(); err != nil { - return errorsx.WithStack(err) - } - } - return err - } - - // commit if there is no wrapping transaction - if !isNested { - return errorsx.WithStack(c.TX.Commit()) - } - } - - return nil -} - -type sqlxTxAdapter struct { - *sqlx.Tx -} - -var _ crdb.Tx = sqlxTxAdapter{} - -func (s sqlxTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) error { - _, err := s.Tx.ExecContext(ctx, query, args...) - return err -} - -func (s sqlxTxAdapter) Commit(ctx context.Context) error { - return s.Tx.Commit() -} - -func (s sqlxTxAdapter) Rollback(ctx context.Context) error { - return s.Tx.Rollback() -} diff --git a/persistence/sql/transaction/transaction_test.go b/persistence/sql/transaction/transaction_test.go deleted file mode 100644 index 00582a294fa..00000000000 --- a/persistence/sql/transaction/transaction_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package transaction - -import ( - "context" - "fmt" - "runtime" - "testing" - - "github.com/cockroachdb/cockroach-go/v2/crdb" - "github.com/cockroachdb/cockroach-go/v2/testserver" - "github.com/gobuffalo/pop/v6" - "github.com/stretchr/testify/require" - - "github.com/ory/x/sqlcon" -) - -func newDB(t *testing.T) *pop.Connection { - if runtime.GOOS == "windows" { - t.Skip("CockroachDB test suite does not support windows") - } - - ts, err := testserver.NewTestServer() - require.NoError(t, err) - - dsn := ts.PGURL() - dsn.Scheme = "cockroach:" - q := dsn.Query() - q.Set("search_path", "d,public") - dsn.RawQuery = q.Encode() - - c, err := pop.NewConnection(&pop.ConnectionDetails{URL: dsn.String()}) - require.NoError(t, err) - require.NoError(t, c.Open()) - return c -} - -func TestCockroachTransactionRetryExpectedFailure(t *testing.T) { - c := newDB(t) - require.Error(t, crdb.ExecuteTxGenericTest(context.Background(), popWriteSkewTest{c: c, t: t})) -} - -func TestCockroachTransactionRetrySuccess(t *testing.T) { - c := newDB(t) - require.NoError(t, crdb.ExecuteTxGenericTest(context.Background(), hydraWriteSkewTest{c: c, popWriteSkewTest: popWriteSkewTest{c: c, t: t}})) -} - -type table struct { - ID int `db:"id"` - Balance int `db:"balance"` -} - -func (t table) TableName() string { - return "t" -} - -type popWriteSkewTest struct { - t *testing.T - c *pop.Connection -} - -type hydraWriteSkewTest struct { - popWriteSkewTest - c *pop.Connection -} - -var _ crdb.WriteSkewTest = popWriteSkewTest{} -var _ crdb.WriteSkewTest = hydraWriteSkewTest{} - -// ExecuteTx is part of the crdb.WriteSkewTest interface. -func (t hydraWriteSkewTest) ExecuteTx(ctx context.Context, fn func(tx interface{}) error) error { - return Transaction(ctx, t.c, func(ctx context.Context, tx *pop.Connection) error { - return fn(tx.WithContext(ctx)) - }) -} - -func (t popWriteSkewTest) Init(ctx context.Context) error { - for _, s := range []string{ - "CREATE DATABASE d", - "CREATE TABLE d.t (id INT PRIMARY KEY, balance INT)", - "USE d", - "INSERT INTO d.t (id, balance) VALUES (1, 100), (2, 100)", - } { - if err := t.c.RawQuery(s).Exec(); err != nil { - return err - } - } - - return nil -} - -// ExecuteTx is part of the crdb.WriteSkewTest interface. -func (t popWriteSkewTest) ExecuteTx(ctx context.Context, fn func(tx interface{}) error) error { - fmt.Printf("entering...\n") - return t.c.Transaction(func(tx *pop.Connection) error { - return fn(tx) - }) -} - -// GetBalances is part of the crdb.WriteSkewTest interface. -func (t popWriteSkewTest) GetBalances(ctx context.Context, txi interface{}) (int, int, error) { - tx := txi.(*pop.Connection).WithContext(ctx) - var tables []table - - err := tx.RawQuery(`SELECT * FROM d.t WHERE id IN (1, 2);`).All(&tables) - if err != nil { - return 0, 0, sqlcon.HandleError(err) - } - - if len(tables) != 2 { - err := fmt.Errorf("expected two balances; got %d", len(tables)) - t.t.Logf("Got error: %+v", err) - return 0, 0, err - } - return tables[0].Balance, tables[1].Balance, nil -} - -// UpdateBalance is part of the crdb.WriteSkewInterface. -func (t popWriteSkewTest) UpdateBalance( - ctx context.Context, txi interface{}, acct, delta int, -) error { - tx := txi.(*pop.Connection).WithContext(ctx) - err := tx.RawQuery(`UPDATE d.t SET balance=balance+$1 WHERE id=$2;`, delta, acct).Exec() - t.t.Logf("Got error: %+v", err) - if err != nil { - return err - } - return nil -}