-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathscanner.go
220 lines (198 loc) · 5.44 KB
/
scanner.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
package pgxscan
import (
"errors"
"fmt"
"github.com/jackc/pgx/v4"
"github.com/randallmlough/pgxscan/internal/sqlmapper"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
// Scanner is any value that can copy the current result row into the given
// destinations through a Scan(v ...interface{}) error method. Both pgx.Row and
// pgx.Rows provide such a method, which is what lets NewScanner accept either.
type Scanner interface {
	Scan(v ...interface{}) error
}

// scannerFunc is a bare function with the same shape as Scanner.Scan.
// ScanStruct receives one and calls it to perform the actual row read.
type scannerFunc func(i ...interface{}) error
// unableToFindFieldError builds the error reported when a result column named
// col has no corresponding struct field to receive its value.
func unableToFindFieldError(col string) error {
	const format = `unable to find corresponding field to column "%s" returned by query`
	return fmt.Errorf(format, col)
}
// NewScanner wraps src in a Scanner that is configured by opts.
//
// Since the pgx row and rows interface both have a `Scan(v ...interface{}) error` method,
// either one can be passed as the argument and scanner will take care of the rest.
// Both ReturnErrNoRowsForRows and MatchAllColumnsToStruct start out true; each
// option is then applied in order, so later options override earlier ones.
//
// A pgx.Rows source is wrapped as *rows and a pgx.Row source as *row. If src is
// neither (e.g. a nil Scanner), NewScanner returns nil.
// NOTE(review): pgx.Row's method set appears identical to Scanner's, so any
// non-nil src that is not a pgx.Rows likely ends up in the *row branch — confirm
// against the pgx version in use.
func NewScanner(src Scanner, opts ...Option) Scanner {
	cfg := new(Config)
	cfg.ReturnErrNoRowsForRows = true
	cfg.MatchAllColumnsToStruct = true
	for n := range opts {
		opts[n].apply(cfg)
	}

	// pgx.Rows must be tested first: it also has a Scan method and would
	// otherwise be treated as a single row.
	if many, ok := src.(pgx.Rows); ok {
		return &rows{rows: many, cfg: cfg}
	}
	if one, ok := src.(pgx.Row); ok {
		return &row{row: one, cfg: cfg}
	}
	return nil
}
// Config holds the behavior switches for scanners created by NewScanner.
// Callers normally adjust it through Option values such as ErrNoRowsQuery
// and MatchAllColumns rather than constructing it directly.
//
// ReturnErrNoRowsForRows (set via ErrNoRowsQuery) controls whether
// pgx.ErrNoRows is reported for a multi-row query that yields no rows.
// MatchAllColumnsToStruct (set via MatchAllColumns) controls whether a result
// column with no matching struct field is an error. NewScanner defaults both
// to true.
type Config struct {
	ReturnErrNoRowsForRows, MatchAllColumnsToStruct bool
}
// Option customizes a Config before a scanner is built. The only method is
// unexported, so Options can only be created inside this package (for example
// by ErrNoRowsQuery or MatchAllColumns).
type Option interface {
	apply(*Config)
}

// optionFunc adapts a plain func(*Config) into an Option.
type optionFunc func(*Config)

// apply runs the wrapped function against cfg.
func (fn optionFunc) apply(cfg *Config) {
	fn(cfg)
}
// ErrNoRowsQuery returns an Option that sets Config.ReturnErrNoRowsForRows to b.
// That setting controls whether a pgx.ErrNoRows error is returned for a query
// that produces no rows. NewScanner defaults it to true.
func ErrNoRowsQuery(b bool) Option {
	set := func(cfg *Config) { cfg.ReturnErrNoRowsForRows = b }
	return optionFunc(set)
}
// MatchAllColumns returns an Option that sets Config.MatchAllColumnsToStruct to b.
// When true (the NewScanner default), a query that returns a column with no
// corresponding struct field produces an unableToFindFieldError error. When
// false, such extra columns are tolerated. The setting is presumably forwarded
// to ScanStruct's matchAllColumnsToStruct argument by the row/rows wrappers.
func MatchAllColumns(b bool) Option {
	set := func(cfg *Config) { cfg.MatchAllColumnsToStruct = b }
	return optionFunc(set)
}
// ErrNoCols is returned by ScanStruct when it is given a nil column list.
var (
	ErrNoCols = errors.New("columns can not be nil")
)
// scanArgIndexRe extracts the bracketed destination index from a pgx
// "can't scan into dest[N]: ..." error. It is compiled once at package scope
// instead of on every failed scan.
var scanArgIndexRe = regexp.MustCompile(`\[(\d+)\]`)

// ScanStruct will scan the current row into i.
//
// Each entry in cols is given a scan destination:
//   - columns starting with QueryColumnNotatePrefix are read into a throwaway *int8;
//   - columns present in i's column map get a fresh value of that field's Go type;
//   - any other column is an error (unableToFindFieldError) when
//     matchAllColumnsToStruct is true, and otherwise gets a nil destination.
//
// When matchAllColumnsToStruct is false, it will not complain about extra columns
// in the result set that are not mapped to the columns in the struct, or, said
// another way, it will allow unmapped items, which can, sometimes, be convenient.
//
// scan performs the actual read. ScanStruct returns ErrNoCols when cols is nil,
// any error from building i's column map, or the scan error. A pgx
// "can't scan into dest[N]" error is annotated with the offending column name.
// On success the scanned values are assigned into i.
func ScanStruct(scan scannerFunc, i interface{}, cols []string, matchAllColumnsToStruct bool) error {
	if cols == nil {
		return ErrNoCols
	}
	cm, err := sqlmaper.GetColumnMap(i)
	if err != nil {
		return err
	}
	scans := make([]interface{}, len(cols))
	for idx, col := range cols {
		data, ok := cm[col]
		switch {
		case strings.HasPrefix(col, QueryColumnNotatePrefix):
			// Notated columns never map to a struct field.
			// NOTE(review): these are read into an *int8, not truly skipped, so the
			// column must be convertible to int8; a nil destination may be intended.
			scans[idx] = new(int8)
		case !ok:
			if matchAllColumnsToStruct {
				return unableToFindFieldError(col)
			}
			// Left nil on purpose: pgx documents nil scan destinations as skipped.
		default:
			scans[idx] = reflect.New(data.GoType).Interface()
		}
	}
	if err := scan(scans...); err != nil {
		// Identify the offending field when types do not match; very useful
		// when using this library.
		return annotateScanError(err, cols)
	}
	record := make(map[string]interface{}, len(cols))
	for index, col := range cols {
		record[col] = scans[index]
	}
	sqlmaper.AssignStructVals(i, record, cm)
	return nil
}

// annotateScanError rewrites a pgx "can't scan into dest[N]: reason" error as
// "can't scan into dest[N] (field 'col'): reason", naming the failing column.
// Any other error, or one whose index cannot be resolved against cols, is
// returned unchanged.
func annotateScanError(err error, cols []string) error {
	msg := err.Error()
	if !strings.HasPrefix(msg, "can't scan into dest[") {
		return err
	}
	// TODO: use ScanArgError if made public (see https://github.com/jackc/pgx/issues/931)
	// to avoid all this parsing.
	// Split only at the first colon: the underlying reason may itself contain
	// colons, and splitting on every colon would truncate it.
	parts := strings.SplitN(msg, ":", 2)
	if len(parts) != 2 {
		return err
	}
	head, reason := parts[0], parts[1]
	m := scanArgIndexRe.FindStringSubmatch(head)
	if m == nil {
		return err
	}
	idx, convErr := strconv.Atoi(m[1])
	if convErr != nil || idx < 0 || idx >= len(cols) {
		return err
	}
	return fmt.Errorf("%s (field '%s'):%s", head, cols[idx], reason)
}
// validate checks that i is a non-nil pointer whose target can be set. It then
// follows the chain of pointers, allocating any nil intermediate pointer it
// meets along the way, and returns the innermost non-pointer value.
//
// It returns an error if i is not a pointer, or if it is a nil pointer
// (e.g. `var foo *Foo`), because there is nothing addressable to write into.
//
// TODO: handle non-initialized nil values like `var foo *Foo`. An earlier
// version recursed through Scan to support this, but that prevented
// post-processing after close (such as counting rows returned).
func validate(i interface{}) (reflect.Value, error) {
	dest := reflect.ValueOf(i)
	switch {
	case dest.Kind() != reflect.Ptr:
		return reflect.Value{}, errors.New("destination must be a pointer")
	case !dest.Elem().CanSet():
		return reflect.Value{}, errors.New("destination must be initialized. Don't use var foo *Foo. Use foo := new(Foo) or foo := &Foo{}")
	}
	for dest.Kind() == reflect.Ptr {
		if dest.IsNil() {
			// Allocate the missing target, then revisit this same pointer,
			// which is now non-nil.
			dest.Set(reflect.New(dest.Type().Elem()))
			continue
		}
		dest = dest.Elem()
	}
	return dest, nil
}
// isVariadic reports how the given destinations should be classified: false
// when there are none, true when there are two or more, and, for exactly one,
// whatever isBuiltin says about it.
// NOTE(review): presumably callers use this to choose between scanning into
// plain values and scanning into a struct — confirm at the call sites.
func isVariadic(i ...interface{}) bool {
	if len(i) == 1 {
		return isBuiltin(i[0])
	}
	return len(i) > 1
}
// isBuiltin reports whether i's dynamic type is one of Go's basic scalar types
// (string, bool, every int/uint width, float32/64, complex64/128) or time.Time,
// in any of four shapes: T, *T, []T, or *[]T.
//
// A type switch matches exact dynamic types, so named types built on these
// (e.g. `type ID int`), uintptr, nested slices, and maps are not matched.
func isBuiltin(i interface{}) bool {
	switch i.(type) {
	// plain values
	case string, bool,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64,
		complex64, complex128,
		time.Time:
	// pointers to values
	case *string, *bool,
		*int, *int8, *int16, *int32, *int64,
		*uint, *uint8, *uint16, *uint32, *uint64,
		*float32, *float64,
		*complex64, *complex128,
		*time.Time:
	// slices of values
	case []string, []bool,
		[]int, []int8, []int16, []int32, []int64,
		[]uint, []uint8, []uint16, []uint32, []uint64,
		[]float32, []float64,
		[]complex64, []complex128,
		[]time.Time:
	// pointers to slices
	case *[]string, *[]bool,
		*[]int, *[]int8, *[]int16, *[]int32, *[]int64,
		*[]uint, *[]uint8, *[]uint16, *[]uint32, *[]uint64,
		*[]float32, *[]float64,
		*[]complex64, *[]complex128,
		*[]time.Time:
	default:
		return false
	}
	return true
}