Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Column Metadata #231

Merged
merged 1 commit into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ It only asserts that argument is of `time.Time` type.

## Change Log

- **2019-04-06** - added functionality to mock a sql MetaData request
- **2019-02-13** - added `go.mod` removed the references and suggestions using `gopkg.in`.
- **2018-12-11** - added expectation of Rows to be closed, while mocking expected query.
- **2018-12-11** - introduced an option to provide **QueryMatcher** in order to customize SQL query matching.
Expand Down
77 changes: 77 additions & 0 deletions column.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package sqlmock

import "reflect"

// Column is a mocked column Metadata for rows.ColumnTypes()
type Column struct {
name string
dbType string
nullable bool
nullableOk bool
length int64
lengthOk bool
precision int64
scale int64
psOk bool
scanType reflect.Type
}

func (c *Column) Name() string {
return c.name
}

func (c *Column) DbType() string {
return c.dbType
}

func (c *Column) IsNullable() (bool, bool) {
return c.nullable, c.nullableOk
}

func (c *Column) Length() (int64, bool) {
return c.length, c.lengthOk
}

func (c *Column) PrecisionScale() (int64, int64, bool) {
return c.precision, c.scale, c.psOk
}

func (c *Column) ScanType() reflect.Type {
return c.scanType
}

// NewColumn returns a Column with specified name
func NewColumn(name string) *Column {
return &Column{
name: name,
}
}

// Nullable returns the column with nullable metadata set
func (c *Column) Nullable(nullable bool) *Column {
c.nullable = nullable
c.nullableOk = true
return c
}

// OfType returns the column with type metadata set
func (c *Column) OfType(dbType string, sampleValue interface{}) *Column {
c.dbType = dbType
c.scanType = reflect.TypeOf(sampleValue)
return c
}

// WithLength returns the column with length metadata set.
func (c *Column) WithLength(length int64) *Column {
c.length = length
c.lengthOk = true
return c
}

// WithPrecisionAndScale returns the column with precision and scale metadata set.
func (c *Column) WithPrecisionAndScale(precision, scale int64) *Column {
c.precision = precision
c.scale = scale
c.psOk = true
return c
}
63 changes: 63 additions & 0 deletions column_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package sqlmock

import (
"reflect"
"testing"
"time"
)

func TestColumn(t *testing.T) {
now, _ := time.Parse(time.RFC3339, "2020-06-20T22:08:41Z")
column1 := NewColumn("test").OfType("VARCHAR", "").Nullable(true).WithLength(100)
column2 := NewColumn("number").OfType("DECIMAL", float64(0.0)).Nullable(false).WithPrecisionAndScale(10, 4)
column3 := NewColumn("when").OfType("TIMESTAMP", now)

if column1.ScanType().Kind() != reflect.String {
t.Errorf("string scanType mismatch: %v", column1.ScanType())
}
if column2.ScanType().Kind() != reflect.Float64 {
t.Errorf("float scanType mismatch: %v", column2.ScanType())
}
if column3.ScanType() != reflect.TypeOf(time.Time{}) {
t.Errorf("time scanType mismatch: %v", column3.ScanType())
}

nullable, ok := column1.IsNullable()
if !nullable || !ok {
t.Errorf("'test' column should be nullable")
}
nullable, ok = column2.IsNullable()
if nullable || !ok {
t.Errorf("'number' column should not be nullable")
}
nullable, ok = column3.IsNullable()
if ok {
t.Errorf("'when' column nullability should be unknown")
}

length, ok := column1.Length()
if length != 100 || !ok {
t.Errorf("'test' column wrong length")
}
length, ok = column2.Length()
if ok {
t.Errorf("'number' column is not of variable length type")
}
length, ok = column3.Length()
if ok {
t.Errorf("'when' column is not of variable length type")
}

_, _, ok = column1.PrecisionScale()
if ok {
t.Errorf("'test' column not applicable")
}
precision, scale, ok := column2.PrecisionScale()
if precision != 10 || scale != 4 || !ok {
t.Errorf("'number' column not applicable")
}
_, _, ok = column3.PrecisionScale()
if ok {
t.Errorf("'when' column not applicable")
}
}
10 changes: 9 additions & 1 deletion expectations_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@ import (
// WillReturnRows specifies the set of resulting rows that will be returned
// by the triggered query
func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery {
defs := 0
sets := make([]*Rows, len(rows))
for i, r := range rows {
sets[i] = r
if r.def != nil {
defs++
}
}
if defs > 0 && defs == len(sets) {
e.rows = &rowSetsWithDefinition{&rowSets{sets: sets, ex: e}}
} else {
e.rows = &rowSets{sets: sets, ex: e}
}
e.rows = &rowSets{sets: sets, ex: e}
return e
}

Expand Down
1 change: 1 addition & 0 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func (rs *rowSets) invalidateRaw() {
type Rows struct {
converter driver.ValueConverter
cols []string
def []*Column
rows [][]driver.Value
pos int
nextErr map[int]error
Expand Down
56 changes: 55 additions & 1 deletion rows_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

package sqlmock

import "io"
import (
"database/sql/driver"
"io"
"reflect"
)

// Implement the "RowsNextResultSet" interface
func (rs *rowSets) HasNextResultSet() bool {
Expand All @@ -18,3 +22,53 @@ func (rs *rowSets) NextResultSet() error {
rs.pos++
return nil
}

// type for rows with columns definition created with sqlmock.NewRowsWithColumnDefinition
type rowSetsWithDefinition struct {
*rowSets
}

// Implement the "RowsColumnTypeDatabaseTypeName" interface
func (rs *rowSetsWithDefinition) ColumnTypeDatabaseTypeName(index int) string {
return rs.getDefinition(index).DbType()
}

// Implement the "RowsColumnTypeLength" interface
func (rs *rowSetsWithDefinition) ColumnTypeLength(index int) (length int64, ok bool) {
return rs.getDefinition(index).Length()
}

// Implement the "RowsColumnTypeNullable" interface
func (rs *rowSetsWithDefinition) ColumnTypeNullable(index int) (nullable, ok bool) {
return rs.getDefinition(index).IsNullable()
}

// Implement the "RowsColumnTypePrecisionScale" interface
func (rs *rowSetsWithDefinition) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return rs.getDefinition(index).PrecisionScale()
}

// ColumnTypeScanType is defined from driver.RowsColumnTypeScanType
func (rs *rowSetsWithDefinition) ColumnTypeScanType(index int) reflect.Type {
return rs.getDefinition(index).ScanType()
}

// return column definition from current set metadata
func (rs *rowSetsWithDefinition) getDefinition(index int) *Column {
return rs.sets[rs.pos].def[index]
}

// NewRowsWithColumnDefinition return rows with columns metadata
func NewRowsWithColumnDefinition(columns ...*Column) *Rows {
cols := make([]string, len(columns))
for i, column := range columns {
cols[i] = column.Name()
}

return &Rows{
cols: cols,
def: columns,
nextErr: make(map[int]error),
converter: driver.DefaultParameterConverter,
}
}
Loading