diff --git a/tempdb/create.go b/tempdb/create.go index 2d273c3..ef80467 100644 --- a/tempdb/create.go +++ b/tempdb/create.go @@ -15,24 +15,23 @@ type Querier interface { DeleteDatabase(ctx context.Context, name string) error } -func New(dbUrl string) (Querier, error) { - parsedUrl, err := url.Parse(dbUrl) +func New(dbURL string) (Querier, error) { + parsedURL, err := url.Parse(dbURL) if err != nil { return nil, err } - querier, err := querierFromUrl(parsedUrl) + querier, err := querierFromURL(parsedURL) if err != nil { return nil, fmt.Errorf("cannot create querier: %w", err) } return querier, nil - } -func querierFromUrl(url *url.URL) (Querier, error) { +func querierFromURL(url *url.URL) (Querier, error) { dialect := url.Scheme switch dialect { - case "postgresql": + case "postgresql", "postgres": q, err := postgres.New(url) if err != nil { return nil, fmt.Errorf("cannot create postgres: %w", err) diff --git a/tempdb/create_test.go b/tempdb/create_test.go index c50a97c..8568bc6 100644 --- a/tempdb/create_test.go +++ b/tempdb/create_test.go @@ -3,10 +3,10 @@ package tempdb import ( "testing" - "github.com/Alviner/drillerfy/tempdb/postgres" - "github.com/Alviner/drillerfy/utils_test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/Alviner/drillerfy/tempdb/postgres" ) func TestCreate(t *testing.T) { @@ -15,17 +15,25 @@ func TestCreate(t *testing.T) { require := require.New(t) assert := assert.New(t) + t.Run("postgresql", func(t *testing.T) { + t.Parallel() + // act + db, err := New("postgresql://pguser:pgpass@localhost:5432/pgdb") + require.NoError(err) + // assert + assert.IsType(new(postgres.Postgres), db) + }) t.Run("postgres", func(t *testing.T) { t.Parallel() - //act - db, err := New(utils_test.PostgresDNS(t)) + // act + db, err := New("postgres://pguser:pgpass@localhost:5432/pgdb") require.NoError(err) // assert assert.IsType(new(postgres.Postgres), db) }) t.Run("unknown", func(t *testing.T) { t.Parallel() - //act + // act _, err := New("unknown://database.url") // assert require.EqualErrorf( @@ -35,5 +43,4 @@ func TestCreate(t *testing.T) { "unknown", ) }) - } diff --git a/tempdb/postgres/postgres_test.go b/tempdb/postgres/postgres_test.go index 3e2e5d4..9e40aed 100644 --- a/tempdb/postgres/postgres_test.go +++ b/tempdb/postgres/postgres_test.go @@ -7,10 +7,11 @@ import ( "testing" "time" - "github.com/Alviner/drillerfy/utils_test" - "github.com/Alviner/drillerfy/utils_test/postgres" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/Alviner/drillerfy/utils_test" + "github.com/Alviner/drillerfy/utils_test/postgres" ) func TestPostgres(t *testing.T) { @@ -19,7 +20,7 @@ func TestPostgres(t *testing.T) { require := require.New(t) assert := assert.New(t) - dbUrl, err := url.Parse(utils_test.PostgresDNS(t)) + dbURL, err := url.Parse(utils_test.PostgresDNS(t)) require.NoError(err) t.Run("create", func(t *testing.T) { @@ -29,13 +30,13 @@ func TestPostgres(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 2*time.Second) defer done() - pg, err := New(dbUrl) + pg, err := New(dbURL) require.NoError(err) - //act + // act require.NoError(pg.CreateDatabase(ctx, dbName, "")) defer func() { require.NoError(pg.DeleteDatabase(ctx, dbName)) }() // assert - names, err := postgres.DBNames(t, dbUrl) + names, err := postgres.DBNames(t, dbURL) require.NoError(err) assert.Contains(names, dbName) }) @@ -47,17 +48,17 @@ func TestPostgres(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 5*time.Second) defer done() - pg, err := New(dbUrl) + pg, err := New(dbURL) require.NoError(err) require.NoError(pg.CreateDatabase(ctx, templateName, "")) defer func() { require.NoError(pg.DeleteDatabase(ctx, templateName)) }() - //act + // act require.NoError(pg.CreateDatabase(ctx, dbName, templateName)) defer func() { require.NoError(pg.DeleteDatabase(ctx, dbName)) }() // assert - names, err := postgres.DBNames(t, dbUrl) + names, err := postgres.DBNames(t, dbURL) require.NoError(err) assert.Contains(names, dbName) }) @@ -69,15 +70,15 @@ func TestPostgres(t *testing.T) { ctx, done := context.WithTimeout(context.Background(), 5*time.Second) defer done() - pg, err := New(dbUrl) + pg, err := New(dbURL) require.NoError(err) require.NoError(pg.CreateDatabase(ctx, dbName, "")) - //act + // act require.NoError(pg.DeleteDatabase(ctx, dbName)) // assert - names, err := postgres.DBNames(t, dbUrl) + names, err := postgres.DBNames(t, dbURL) require.NoError(err) assert.NotContains(names, dbName) })