Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support retry #68

Merged
merged 3 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions coredb/engine_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,63 @@ package coredb
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"strings"
"time"
)

type IsNonRetryableErrorFunc func(err error) bool

// RetryConfig encapsulates retry parameters.
type RetryConfig struct {
MaxRetries int
InitialBackoff time.Duration
IsNonRetryableErrorFunc IsNonRetryableErrorFunc
}

// DefaultRetryConfig provides a reasonable default configuration
var DefaultRetryConfig = RetryConfig{
MaxRetries: 5,
InitialBackoff: 200 * time.Millisecond,
IsNonRetryableErrorFunc: IsNonRetryableError,
}

// IsNonRetryableError checks if an error is non-retryable.
func IsNonRetryableError(err error) bool {
if err == nil {
return false
}
// Example (Replace with your database's non-retryable errors)

// SQL specific errors that are not retryable
if errors.Is(err, sql.ErrNoRows) {
return true
}

// Example: Invalid SQL syntax
if strings.Contains(err.Error(), "syntax error") {
return true
}

if strings.Contains(err.Error(), "1146") { // Table doesn't exists
return true
}
if strings.Contains(err.Error(), "1064") { // No database selected
return true
}
if strings.Contains(err.Error(), "1149") { // Invalid SQL statement
return true
}
// Example: Authentication issues
if strings.Contains(err.Error(), "Access denied") {
return true
}

return false // Default is retryable
}

// FetchByPKCtx returns a row of T type with given primary key value
func FetchByPKCtx[T any](ctx context.Context, dbname string, tableName string, pkName []string, val ...any) (*T, error) {
sql := "WHERE `" + pkName[0] + "` = ?"
Expand Down Expand Up @@ -57,6 +110,56 @@ func ExecCtx(ctx context.Context, dbname string, query string, params ...any) (s
return mydb.ExecContext(ctx, query, params...)
}

// ExecWithRetry executes a query with retry logic on failure.
func ExecWithRetry(ctx context.Context, dbname string, query string, retryConfig RetryConfig, params ...any) (sql.Result, error) {
// Set defaults for invalid config
if retryConfig.MaxRetries <= 0 {
retryConfig.MaxRetries = DefaultRetryConfig.MaxRetries
}

if retryConfig.InitialBackoff <= 0 {
retryConfig.InitialBackoff = DefaultRetryConfig.InitialBackoff
}

// Use the default if NonRetryableErrorFunc is nil
nonRetryableErrorFunc := retryConfig.IsNonRetryableErrorFunc
if nonRetryableErrorFunc == nil {
nonRetryableErrorFunc = IsNonRetryableError
}

var result sql.Result
var err error
retryCount := 0
currentBackoff := retryConfig.InitialBackoff

for {
select {
case <-ctx.Done():
return result, fmt.Errorf("context cancelled during retry: %w", ctx.Err())
default:
result, err = ExecCtx(ctx, dbname, query, params...)
if err == nil {
return result, nil // Success!
}

if nonRetryableErrorFunc(err) {
return result, err // Fail immediately for non-retryable errors
}

retryCount++
if retryCount > retryConfig.MaxRetries {
log.Printf("Max retries (%d) exceeded for: %s, last error: %v", retryConfig.MaxRetries, query, err)
return result, fmt.Errorf("max retries exceeded, last error: %w", err)
}

delay := currentBackoff
log.Printf("Retrying attempt %d with delay %v. Last error: %v", retryCount, delay, err)
time.Sleep(delay)
currentBackoff *= 2
}
}
}

// FindOneCtx returns a row from given table type with where query.
// If no rows found, *T will be nil. No error will be returned.
func FindOneCtx[T any](ctx context.Context, dbname string, tableName string, where WhereQuery) (*T, error) {
Expand Down
50 changes: 50 additions & 0 deletions coredb/txengine/tx_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,56 @@ func RunTransaction(ctx context.Context, dbName string, fn func(ctx context.Cont
return
}

// RunTxWithRetry runs a transaction with retry logic on failure.
func RunTxWithRetry(ctx context.Context, dbName string, retryConfig coredb.RetryConfig, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
// Set defaults for invalid config
if retryConfig.MaxRetries <= 0 {
retryConfig.MaxRetries = coredb.DefaultRetryConfig.MaxRetries
}

if retryConfig.InitialBackoff <= 0 {
retryConfig.InitialBackoff = coredb.DefaultRetryConfig.InitialBackoff
}

nonRetryableErrorFunc := retryConfig.IsNonRetryableErrorFunc
if nonRetryableErrorFunc == nil {
nonRetryableErrorFunc = coredb.IsNonRetryableError
}

var resultErr error
retryCount := 0
currentBackoff := retryConfig.InitialBackoff

for {
select {
case <-ctx.Done():
return fmt.Errorf("context cancelled during retry: %w", ctx.Err())
default:
resultErr = RunTransaction(ctx, dbName, fn)
if resultErr == nil {
return nil // Success!
}

if nonRetryableErrorFunc(resultErr) {
log.Printf("Non-retryable error: %v", resultErr)
return resultErr // Fail immediately for non-retryable errors
}

retryCount++
if retryCount > retryConfig.MaxRetries {
log.Printf("Max retries (%d) exceeded, last error: %v", retryConfig.MaxRetries, resultErr)
return fmt.Errorf("max retries exceeded, last error: %w", resultErr)

}

delay := currentBackoff
log.Printf("Retrying attempt %d with delay %v. Last error: %v", retryCount, delay, resultErr)
time.Sleep(delay)
currentBackoff *= 2
}
}
}

func runTransaction(ctx context.Context, tx *sql.Tx, conn *sql.Conn, fn func(ctx context.Context, sqlTx *sql.Tx) error) (err error) {
if tx == nil && conn == nil {
return errors.New("wrong usage. tx and conn cannot both be nil")
Expand Down
33 changes: 33 additions & 0 deletions tests/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package tests

import (
"context"
"testing"
"time"

_ "github.com/go-sql-driver/mysql"
"github.com/olachat/gola/v2/coredb"
)

func TestExecWithRetry_Success(t *testing.T) {
ctx := context.Background()

_, err := coredb.ExecWithRetry(ctx, testDBName, "INSERT INTO test_table (name, email) VALUES (?, ?)", coredb.DefaultRetryConfig, "test", "test@example.com")
if err != nil {
t.Fatalf("Expected success, but got error: %v", err)
}
}

func TestExecWithRetry_Fail(t *testing.T) {
ctx := context.Background()
now := time.Now()
_, err := coredb.ExecWithRetry(ctx, testDBName, "INSERT INTO no_such_table (name, email) VALUES (?, ?)", coredb.DefaultRetryConfig, "test", "test@example.com")
if err == nil {
t.Fatalf("Expected error, but got success")
}
elapsed := time.Since(now)
if elapsed < (200+400+800+1600+3200)*time.Millisecond {
t.Fatalf("Expected retry to take at least 100ms, but took: %v", elapsed)
}
t.Logf("retry took: %v", elapsed)
}
16 changes: 9 additions & 7 deletions tests/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,18 @@ func init() {
panic(err)
}

// realdb, err := open()

// if err != nil {
// panic(err)
// }

coredb.Setup(func(dbname string, mode coredb.DBMode) *sql.DB {
return db
if dbname == testDBName {
return db
}
return nil
})

_, err = db.Exec("CREATE TABLE IF NOT EXISTS test_table (name VARCHAR(255), email VARCHAR(255))")
if err != nil {
panic(err)
}

// create tables
for _, tableName := range tableNames {
query, _ := testdata.Fixtures.ReadFile(tableName + ".sql")
Expand Down
Loading