Skip to content

Commit

Permalink
feat: add struct field validation
Browse files Browse the repository at this point in the history
  • Loading branch information
tauslim committed Feb 27, 2024
1 parent a65ca09 commit 9ad9f90
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 41 deletions.
10 changes: 1 addition & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Config struct {
//
// In order to create a parser for the given resource, you will do it like so:
//
// var QueryParser = rql.MustNewParser(
// var QueryParser = rql.NewParser(&rql.Config{
// Model: User{},
// })
//
Expand All @@ -45,14 +45,6 @@ type Config struct {
//
// We assume the schema for this struct contains a column named "address_city". Therefore, the default
// separator is underscore ("_"). But, you can change it to "." for convenience or readability reasons.
// Then you will be able to query your resource like this:
//
// {
// "filter": {
// "address.city": "DC"
// }
// }
//
// The parser will automatically convert it to underscore ("_"). If you want to control the name of
// the column, use the "column" option in the struct definition. For example:
//
Expand Down
22 changes: 21 additions & 1 deletion converter.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
package gorql

import (
"strconv"
"time"
)

// convert float to int.
func convertInt(v interface{}) interface{} {
return int(v.(float64))
if s, err := strconv.Atoi(v.(string)); err == nil {
return s
}
return v
}

// convert string to float.
func convertFloat(v interface{}) interface{} {
if s, err := strconv.ParseFloat(v.(string), 64); err == nil {
return s
}
return v
}

// convert string to time object.
Expand All @@ -17,6 +29,14 @@ func convertTime(layout string) func(interface{}) interface{} {
}
}

// convert string to bool.
func convertBool(v interface{}) interface{} {
if s, err := strconv.ParseBool(v.(string)); err == nil {
return s
}
return v
}

// nop converter.
func valueFn(v interface{}) interface{} {
return v
Expand Down
51 changes: 21 additions & 30 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"io"
"reflect"
"strconv"
"strings"
"time"
)
Expand Down Expand Up @@ -88,22 +87,6 @@ func (r *RqlRootNode) parseSpecialOps() {
}
}

func (r *RqlRootNode) validateSpecialOps() error {
if r.limit != "" {
_, err := strconv.Atoi(r.limit)
if err != nil {
return fmt.Errorf("invalid format for limit: %s", err)
}
}
if r.offset != "" {
_, err := strconv.Atoi(r.offset)
if err != nil {
return fmt.Errorf("invalid format for offset: %s", err)
}
}
return nil
}

func parseLimit(n *RqlNode, root *RqlRootNode) (isLimitOp bool) {
if n == nil {
return false
Expand Down Expand Up @@ -183,23 +166,21 @@ type field struct {
CovertFn func(interface{}) interface{}
}

func NewParser() *Parser {
return &Parser{s: NewScanner()}
}

func NewParserWithConfig(c *Config) (*Parser, error) {
err := c.defaults()
if err != nil {
return nil, err
}
func NewParser(c *Config) (*Parser, error) {
p := &Parser{
s: NewScanner(),
c: c,
fields: make(map[string]*field),
}
err = p.init()
if err != nil {
return nil, err
if c != nil {
err := c.defaults()
if err != nil {
return nil, err
}
err = p.init()
if err != nil {
return nil, err
}
}
return p, nil
}
Expand Down Expand Up @@ -274,6 +255,7 @@ func (p *Parser) parseField(sf reflect.StructField) error {
switch typ := indirect(sf.Type); typ.Kind() {
case reflect.Bool:
f.ValidateFn = validateBool
f.CovertFn = convertBool
case reflect.String:
f.ValidateFn = validateString
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
Expand All @@ -284,17 +266,20 @@ func (p *Parser) parseField(sf reflect.StructField) error {
f.CovertFn = convertInt
case reflect.Float32, reflect.Float64:
f.ValidateFn = validateFloat
f.CovertFn = convertFloat
case reflect.Struct:
switch v := reflect.Zero(typ); v.Interface().(type) {
case sql.NullBool:
f.ValidateFn = validateBool
f.CovertFn = convertBool
case sql.NullString:
f.ValidateFn = validateString
case sql.NullInt64:
f.ValidateFn = validateInt
f.CovertFn = convertInt
case sql.NullFloat64:
f.ValidateFn = validateFloat
f.CovertFn = convertFloat
case time.Time:
f.ValidateFn = validateTime(layout)
f.CovertFn = convertTime(layout)
Expand Down Expand Up @@ -324,10 +309,16 @@ func (p *Parser) Parse(r io.Reader) (root *RqlRootNode, err error) {
return nil, err
}
root.parseSpecialOps()
err = root.validateSpecialOps()
err = p.validateSpecialOps(root)
if err != nil {
return nil, err
}
if p.c != nil {
err := p.validateFields(root.Node)
if err != nil {
return nil, err
}
}
return
}

Expand Down
5 changes: 4 additions & 1 deletion pkg/driver/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ type Test struct {
}

func (test *Test) Run(t *testing.T) {
p := gorql.NewParser()
p, err := gorql.NewParser(nil)
if err != nil {
t.Fatalf("(%s) New parser error :%v\n", test.Name, err)
}

rqlNode, err := p.Parse(strings.NewReader(test.RQL))
if test.WantParseError != (err != nil) {
Expand Down
101 changes: 101 additions & 0 deletions validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ import (
"fmt"
"math"
"reflect"
"strconv"
"time"
)

type ValidationFunc func(*RqlNode) error

func errorType(v interface{}, expected string) error {
actual := "nil"
if v != nil {
Expand Down Expand Up @@ -74,3 +77,101 @@ func validateTime(layout string) func(interface{}) error {
return err
}
}

func (p *Parser) validateFields(n *RqlNode) error {
if n == nil {
return nil
}
fn := p.GetFieldValidationFunc()
if fn == nil {
return fmt.Errorf("no field validation op '%s'", n.Op)
}
return fn(n)
}

func (p *Parser) GetFieldValidationFunc() ValidationFunc {
return func(n *RqlNode) (err error) {
var field *field
for i, a := range n.Args {
switch v := a.(type) {
case string:
if i == 0 {
f, ok := p.fields[v]
if !ok || !f.Filterable {
return fmt.Errorf("field name (arg: %s) is not filterable", v)
}
field = f
} else {
if field == nil {
return fmt.Errorf("no field is found for node value %s", v)
}
newVal := field.CovertFn(v)
n.Args[i] = newVal
}
case *RqlNode:
err = p.validateFields(v)
if err != nil {
return err
}
}
}
return nil
}
}

func (p *Parser) validateSpecialOps(r *RqlRootNode) error {
if r.Limit() != "" {
err := p.validateLimit(r.Limit())
if err != nil {
return err
}
}
if r.Offset() != "" {
err := p.validateOffset(r.Offset())
if err != nil {
return err
}
}
if len(r.Sort()) > 0 {
err := p.validateSort(r.Sort())
if err != nil {
return err
}
}
return nil
}

func (p *Parser) validateSort(sortItems []Sort) error {
for _, s := range sortItems {
f, ok := p.fields[s.By]
if !ok || !f.Sortable {
return fmt.Errorf("field %s is not sortable", s.By)
}
}
return nil
}

func (p *Parser) validateOffset(o string) error {
offset, err := strconv.Atoi(o)
if err != nil {
return fmt.Errorf("invalid format for offset: %s", err)
}
if offset < 0 {
return fmt.Errorf("offset is less than zero")
}
return nil
}

func (p *Parser) validateLimit(l string) error {
limit, err := strconv.Atoi(l)
if err != nil {
return fmt.Errorf("invalid format for limit: %s", err)
}
if limit < 0 {
return fmt.Errorf("specified limit is less than zero")
}
if limit > p.c.LimitMaxValue {
return fmt.Errorf("specified limit is more than the max limit %d allowed", p.c.DefaultLimit)
}
return nil
}

0 comments on commit 9ad9f90

Please sign in to comment.