Skip to content

Commit

Permalink
Addressing review
Browse files Browse the repository at this point in the history
  • Loading branch information
marco6 committed Jan 14, 2025
1 parent 13de1f1 commit fa549e3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
6 changes: 0 additions & 6 deletions pkg/database/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ var errDBClosed = errors.New("sql: database is closed")
type Interface interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)

PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)

BeginTx(ctx context.Context, opts *sql.TxOptions) (Transaction, error)
Conn(ctx context.Context) (*sql.Conn, error)
Close() error
Expand All @@ -22,9 +20,7 @@ type Interface interface {
type Transaction interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)

PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)

StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
Commit() error
Rollback() error
Expand All @@ -33,9 +29,7 @@ type Transaction interface {
type Wrapped[T Transaction] interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error)

PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)

BeginTx(ctx context.Context, opts *sql.TxOptions) (T, error)
Conn(ctx context.Context) (*sql.Conn, error)
Close() error
Expand Down
18 changes: 10 additions & 8 deletions pkg/database/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func init() {
type preparedDb[T Transaction] struct {
underlying Wrapped[T]
mu sync.RWMutex
store map[string]*sql.Stmt
cache map[string]*sql.Stmt
}

// NewPrepared creates a new Interface that wraps the given database and
Expand All @@ -32,7 +32,7 @@ type preparedDb[T Transaction] struct {
func NewPrepared[T Transaction](db Wrapped[T]) Interface {
return &preparedDb[T]{
underlying: db,
store: make(map[string]*sql.Stmt),
cache: make(map[string]*sql.Stmt),
}
}

Expand Down Expand Up @@ -76,7 +76,7 @@ func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.S
}()

db.mu.RLock()
stmt = db.store[query]
stmt = db.cache[query]
db.mu.RUnlock()
if stmt != nil {
return stmt, nil
Expand All @@ -89,8 +89,10 @@ func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.S
return nil, errDBClosed
}

// Check again if the query was prepared during locking
stmt = db.store[query]
// Given that some time has passed since the unlock of the read lock, and the lock of the
// write lock, another goroutine might have already prepared this query, so we should check
// again to avoid preparing the same query twice.
stmt = db.cache[query]
if stmt != nil {
return stmt, nil
}
Expand All @@ -99,7 +101,7 @@ func (db *preparedDb[T]) prepare(ctx context.Context, query string) (stmt *sql.S
if err != nil {
return nil, err
}
db.store[query] = prepared
db.cache[query] = prepared
return prepared, nil
}

Expand All @@ -124,12 +126,12 @@ func (db *preparedDb[T]) Close() error {
defer db.mu.Unlock()

errs := []error{}
for _, stmt := range db.store {
for _, stmt := range db.cache {
if err := stmt.Close(); err != nil {
errs = append(errs, err)
}
}
db.store = nil
db.cache = nil

if err := db.underlying.Close(); err != nil {
errs = append(errs, err)
Expand Down

0 comments on commit fa549e3

Please sign in to comment.