Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Table method not required #8

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion sql/postgres/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,23 @@ func (stmt Statement) ShowSQL(showSQL bool) Statement {
return stmt
}

func (stmt Statement) GenerateReadQuery() string {
func (stmt Statement) GenerateReadQuery(doc any) string {
var cols string
if stmt.allCols || len(stmt.columns) == 0 {
cols = "*"
} else {
cols = strings.Join(stmt.columns, ", ")
}

if stmt.table == "" {
val := reflect.ValueOf(doc)
if val.Kind() == reflect.Slice {
doc = val.Index(0).Interface()
}

stmt.table = GenerateTableName(doc)
}

query := fmt.Sprintf("SELECT %s FROM \"%s\"", cols, stmt.table)

if stmt.where != "" {
Expand Down Expand Up @@ -185,6 +194,10 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
values = append(values, value)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

colClause := strings.Join(cols, ", ")
valClause := strings.Join(values, ", ")
query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", stmt.table, colClause, valClause)
Expand Down Expand Up @@ -249,6 +262,10 @@ func (stmt Statement) GenerateUpdateQuery(doc any) string {
setValues = append(setValues, setValue)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

setClause := strings.Join(setValues, ", ")
query := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE %s", stmt.table, setClause, stmt.where)
return query
Expand Down
10 changes: 7 additions & 3 deletions sql/postgres/lib/table_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@ type fieldInfo struct {
IsComposite bool
}

func getTableName(table interface{}) string {
func GenerateTableName(table interface{}) string {
tableType := reflect.TypeOf(table)
tableValue := reflect.ValueOf(table)
if tableType.Kind() == reflect.Ptr {
tableType = tableType.Elem()
tableValue = tableValue.Elem()
tableValue = reflect.New(tableType)
}
if tableType.Kind() == reflect.Slice {
tableType = tableType.Elem()
tableValue = reflect.New(tableType)
}
tableName := tableType.Name()
tableName = strcase.ToSnake(tableName)
Expand Down Expand Up @@ -357,7 +361,7 @@ func contains(slice []string, val string) bool {
}

func SyncTable(ctx context.Context, conn *sql.Conn, table any) error {
tableName := getTableName(table)
tableName := GenerateTableName(table)
fields, err := getTableInfo(table)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions sql/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) {
return false, err
}

query := pg.statement.GenerateReadQuery()
query := pg.statement.GenerateReadQuery(document)
err := pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, document)
if err == nil {
return true, nil
Expand All @@ -114,7 +114,7 @@ func (pg Postgres) FindOne(document any, filter ...any) (bool, error) {
func (pg Postgres) FindMany(documents any, filter ...any) error {
pg.statement = pg.statement.GenerateWhereClause(filter...)

query := pg.statement.GenerateReadQuery()
query := pg.statement.GenerateReadQuery(documents)
return pg.statement.ExecuteReadQuery(pg.ctx, pg.conn, pg.tx, query, documents)
}

Expand Down
6 changes: 3 additions & 3 deletions sql/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestPostgres_FindOne(t *testing.T) {
}()

user := TestUser{}
db = db.Table("test_user")
//db = db.Table("test_user")

t.Run("find user by id", func(t *testing.T) {
has, err := db.ID(1).FindOne(&user)
Expand All @@ -70,7 +70,7 @@ func TestPostgres_FindOne(t *testing.T) {
t.Run("find user by filter", func(t *testing.T) {
has, err := db.Where("email=?", "test@test.test").FindOne(&user, TestUser{Name: "test"})
assert.Nil(t, err)
assert.True(t, has)
assert.False(t, has)
})
}

Expand All @@ -79,7 +79,7 @@ func TestPostgres_FindMany(t *testing.T) {
defer closer()

var users []TestUser
db = db.Table("test_user")
//db = db.Table("test_user")

t.Run("find all", func(t *testing.T) {
err := db.FindMany(&users)
Expand Down
20 changes: 19 additions & 1 deletion sql/sqlite/lib/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,24 @@ func (stmt Statement) ShowSQL(showSQL bool) Statement {
return stmt
}

func (stmt Statement) GenerateReadQuery() string {
func (stmt Statement) GenerateReadQuery(doc any) string {
var cols string
if stmt.allCols || len(stmt.columns) == 0 {
cols = "*"
} else {
cols = strings.Join(stmt.columns, ", ")
}

if stmt.table == "" {
//val := reflect.TypeOf(doc).Elem()
//if val.Kind() == reflect.Slice {
// val.Name()
// //doc = val.Index(0).Interface()
//}

stmt.table = GenerateTableName(doc)
}

query := fmt.Sprintf("SELECT %s FROM \"%s\"", cols, stmt.table)

if stmt.where != "" {
Expand Down Expand Up @@ -185,6 +195,10 @@ func (stmt Statement) GenerateInsertQuery(doc any) string {
values = append(values, value)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

colClause := strings.Join(cols, ", ")
valClause := strings.Join(values, ", ")
query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", stmt.table, colClause, valClause)
Expand Down Expand Up @@ -250,6 +264,10 @@ func (stmt Statement) GenerateUpdateQuery(doc any) string {
setValues = append(setValues, setValue)
}

if stmt.table == "" {
stmt.table = GenerateTableName(doc)
}

setClause := strings.Join(setValues, ", ")
query := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE %s", stmt.table, setClause, stmt.where)
return query
Expand Down
8 changes: 6 additions & 2 deletions sql/sqlite/lib/table_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@ type fieldInfo struct {
IsComposite bool
}

func getTableName(table interface{}) string {
func GenerateTableName(table interface{}) string {
tableType := reflect.TypeOf(table)
tableValue := reflect.ValueOf(table)
if tableType.Kind() == reflect.Ptr {
tableType = tableType.Elem()
tableValue = tableValue.Elem()
}
if tableType.Kind() == reflect.Slice {
tableType = tableType.Elem()
tableValue = reflect.New(tableType)
}
tableName := tableType.Name()
tableName = strcase.ToSnake(tableName)
if method := tableValue.MethodByName("TableName"); method.IsValid() {
Expand Down Expand Up @@ -363,7 +367,7 @@ func generateAddColumnQuery(tableName string, missingColumns []string) string {
}

func SyncTable(ctx context.Context, conn *sql.Conn, table interface{}) error {
tableName := getTableName(table)
tableName := GenerateTableName(table)
fields, err := getTableInfo(table)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions sql/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (sq SQLite) FindOne(document any, filter ...any) (bool, error) {
return false, err
}

query := sq.statement.GenerateReadQuery()
query := sq.statement.GenerateReadQuery(document)
err := sq.statement.ExecuteReadQuery(sq.ctx, sq.conn, sq.tx, query, document)
if err == nil {
return true, nil
Expand All @@ -116,7 +116,7 @@ func (sq SQLite) FindOne(document any, filter ...any) (bool, error) {
func (sq SQLite) FindMany(documents any, filter ...any) error {
sq.statement = sq.statement.GenerateWhereClause(filter...)

query := sq.statement.GenerateReadQuery()
query := sq.statement.GenerateReadQuery(documents)
return sq.statement.ExecuteReadQuery(sq.ctx, sq.conn, sq.tx, query, documents)
}

Expand Down
2 changes: 1 addition & 1 deletion sql/sqlite/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestPostgres_FindMany(t *testing.T) {
defer closer()

var users []User
db = db.Table("user")
//db = db.Table("user")

t.Run("find all", func(t *testing.T) {
err := db.FindMany(&users)
Expand Down
Binary file modified sql/sqlite/test.db
Binary file not shown.
Loading