Skip to content

Commit

Permalink
More clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Sep 8, 2024
1 parent 14ab2fe commit e60f8c2
Show file tree
Hide file tree
Showing 19 changed files with 73 additions and 99 deletions.
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (s *Store) putTable(ctx context.Context, bqTableID TableIdentifier, tableDa
defer managedStream.Close()

encoder := func(row map[string]any) ([]byte, error) {
message, err := rowToMessage(row, columns, *messageDescriptor, s.AdditionalDateFormats())
message, err := rowToMessage(row, columns, *messageDescriptor)
if err != nil {
return nil, fmt.Errorf("failed to convert row to message: %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions clients/bigquery/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ import (
func (s *Store) Merge(tableData *optimization.TableData) error {
var additionalEqualityStrings []string
if tableData.TopicConfig().BigQueryPartitionSettings != nil {
additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
distinctDates, err := tableData.DistinctDates(tableData.TopicConfig().BigQueryPartitionSettings.PartitionField, additionalDateFmts)
distinctDates, err := tableData.DistinctDates(tableData.TopicConfig().BigQueryPartitionSettings.PartitionField)
if err != nil {
return fmt.Errorf("failed to generate distinct dates: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion clients/bigquery/storagewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestRowToMessage(t *testing.T) {
desc, err := columnsToMessageDescriptor(columns)
assert.NoError(t, err)

message, err := rowToMessage(row, columns, *desc, []string{})
message, err := rowToMessage(row, columns, *desc)
assert.NoError(t, err)

bytes, err := protojson.Marshal(message)
Expand Down
3 changes: 1 addition & 2 deletions clients/mssql/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo

defer stmt.Close()

additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
for _, value := range tableData.Rows() {
var row []any
for _, col := range cols {
castedValue, castErr := parseValue(value[col.Name()], col, additionalDateFmts)
castedValue, castErr := parseValue(value[col.Name()], col)
if castErr != nil {
return castErr
}
Expand Down
30 changes: 15 additions & 15 deletions clients/mssql/values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,81 +15,81 @@ import (

func TestParseValue(t *testing.T) {
{
val, err := parseValue(nil, columns.Column{}, nil)
val, err := parseValue(nil, columns.Column{})
assert.NoError(t, err)
assert.Nil(t, val)
}
{
val, err := parseValue("string value", columns.NewColumn("foo", typing.String), nil)
val, err := parseValue("string value", columns.NewColumn("foo", typing.String))
assert.NoError(t, err)
assert.Equal(t, "string value", val)

// We don't need to escape backslashes.
val, err = parseValue(`dusty o\donald`, columns.NewColumn("foo", typing.String), nil)
val, err = parseValue(`dusty o\donald`, columns.NewColumn("foo", typing.String))
assert.NoError(t, err)
assert.Equal(t, `dusty o\donald`, val)

// If the string precision exceeds the value, we'll need to insert an exceeded value.
stringCol := columns.NewColumn("foo", typing.String)
stringCol.KindDetails.OptionalStringPrecision = ptr.ToInt32(25)

val, err = parseValue(`abcdefabcdefabcdefabcdef113321`, stringCol, nil)
val, err = parseValue(`abcdefabcdefabcdefabcdef113321`, stringCol)
assert.NoError(t, err)
assert.Equal(t, constants.ExceededValueMarker, val)
}
{
val, err := parseValue(map[string]any{"foo": "bar"}, columns.NewColumn("json", typing.Struct), nil)
val, err := parseValue(map[string]any{"foo": "bar"}, columns.NewColumn("json", typing.Struct))
assert.NoError(t, err)
assert.Equal(t, `{"foo":"bar"}`, val)
}
{
val, err := parseValue([]any{"foo", "bar"}, columns.NewColumn("array", typing.Array), nil)
val, err := parseValue([]any{"foo", "bar"}, columns.NewColumn("array", typing.Array))
assert.NoError(t, err)
assert.Equal(t, `["foo","bar"]`, val)
}
{
// Integers
val, err := parseValue(1234, columns.NewColumn("int", typing.Integer), nil)
val, err := parseValue(1234, columns.NewColumn("int", typing.Integer))
assert.NoError(t, err)
assert.Equal(t, 1234, val)

// Should be able to handle string ints
val, err = parseValue("1234", columns.NewColumn("float", typing.Integer), nil)
val, err = parseValue("1234", columns.NewColumn("float", typing.Integer))
assert.NoError(t, err)
assert.Equal(t, 1234, val)
}
{
// Floats
val, err := parseValue(1234.5678, columns.NewColumn("float", typing.Float), nil)
val, err := parseValue(1234.5678, columns.NewColumn("float", typing.Float))
assert.NoError(t, err)
assert.Equal(t, 1234.5678, val)

// Should be able to handle string floats
val, err = parseValue("1234.5678", columns.NewColumn("float", typing.Float), nil)
val, err = parseValue("1234.5678", columns.NewColumn("float", typing.Float))
assert.NoError(t, err)
assert.Equal(t, 1234.5678, val)
}
{
// Boolean, but the column is an integer column.
val, err := parseValue(true, columns.NewColumn("bigint", typing.Integer), nil)
val, err := parseValue(true, columns.NewColumn("bigint", typing.Integer))
assert.NoError(t, err)
assert.Equal(t, 1, val)

// Booleans
val, err = parseValue(true, columns.NewColumn("bool", typing.Boolean), nil)
val, err = parseValue(true, columns.NewColumn("bool", typing.Boolean))
assert.NoError(t, err)
assert.True(t, val.(bool))

val, err = parseValue(false, columns.NewColumn("bool", typing.Boolean), nil)
val, err = parseValue(false, columns.NewColumn("bool", typing.Boolean))
assert.NoError(t, err)
assert.False(t, val.(bool))

// Should be able to handle string booleans
val, err = parseValue("true", columns.NewColumn("bool", typing.Boolean), nil)
val, err = parseValue("true", columns.NewColumn("bool", typing.Boolean))
assert.NoError(t, err)
assert.True(t, val.(bool))

val, err = parseValue("false", columns.NewColumn("bool", typing.Boolean), nil)
val, err = parseValue("false", columns.NewColumn("bool", typing.Boolean))
assert.NoError(t, err)
assert.False(t, val.(bool))
}
Expand Down
4 changes: 2 additions & 2 deletions clients/redshift/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func replaceExceededValues(colVal string, colKind typing.KindDetails) string {
return colVal
}

func castColValStaging(colVal any, colKind typing.KindDetails, additionalDateFmts []string) (string, error) {
func castColValStaging(colVal any, colKind typing.KindDetails) (string, error) {
if colVal == nil {
if colKind == typing.Struct {
// Returning empty here because if it's a struct, it will go through JSON PARSE and JSON_PARSE("") = null
Expand All @@ -42,7 +42,7 @@ func castColValStaging(colVal any, colKind typing.KindDetails, additionalDateFmt
return `\N`, nil
}

colValString, err := values.ToString(colVal, colKind, additionalDateFmts)
colValString, err := values.ToString(colVal, colKind)
if err != nil {
return "", err
}
Expand Down
8 changes: 4 additions & 4 deletions clients/redshift/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,25 @@ func (r *RedshiftTestSuite) TestReplaceExceededValues() {
func (r *RedshiftTestSuite) TestCastColValStaging() {
{
// Masked
value, err := castColValStaging(stringutil.Random(int(maxRedshiftLength)+1), typing.String, nil)
value, err := castColValStaging(stringutil.Random(int(maxRedshiftLength)+1), typing.String)
assert.NoError(r.T(), err)
assert.Equal(r.T(), constants.ExceededValueMarker, value)
}
{
// Valid
value, err := castColValStaging("thisissuperlongbutnotlongenoughtogetmasked", typing.String, nil)
value, err := castColValStaging("thisissuperlongbutnotlongenoughtogetmasked", typing.String)
assert.NoError(r.T(), err)
assert.Equal(r.T(), "thisissuperlongbutnotlongenoughtogetmasked", value)
}
{
// Masked struct
value, err := castColValStaging(fmt.Sprintf(`{"foo": "%s"}`, stringutil.Random(int(maxRedshiftLength)+1)), typing.Struct, nil)
value, err := castColValStaging(fmt.Sprintf(`{"foo": "%s"}`, stringutil.Random(int(maxRedshiftLength)+1)), typing.Struct)
assert.NoError(r.T(), err)
assert.Equal(r.T(), fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker), value)
}
{
// Valid struct
value, err := castColValStaging(`{"foo": "bar"}`, typing.Struct, nil)
value, err := castColValStaging(`{"foo": "bar"}`, typing.Struct)
assert.NoError(r.T(), err)
assert.Equal(r.T(), `{"foo": "bar"}`, value)
}
Expand Down
3 changes: 1 addition & 2 deletions clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,11 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID
writer := csv.NewWriter(gzipWriter) // Create a CSV writer on top of the gzip writer
writer.Comma = '\t'

additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
columns := tableData.ReadOnlyInMemoryCols().ValidColumns()
for _, value := range tableData.Rows() {
var row []string
for _, col := range columns {
castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails, additionalDateFmts)
castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails)
if castErr != nil {
return "", castErr
}
Expand Down
3 changes: 1 addition & 2 deletions clients/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,12 @@ func (s *Store) Merge(tableData *optimization.TableData) error {
return fmt.Errorf("failed to instantiate parquet writer: %w", err)
}

additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
pw.CompressionType = parquet.CompressionCodec_GZIP
columns := tableData.ReadOnlyInMemoryCols().ValidColumns()
for _, val := range tableData.Rows() {
row := make(map[string]any)
for _, col := range columns {
value, err := parquetutil.ParseValue(val[col.Name()], col, additionalDateFmts)
value, err := parquetutil.ParseValue(val[col.Name()], col)
if err != nil {
return fmt.Errorf("failed to parse value, err: %w, value: %v, column: %q", err, val[col.Name()], col.Name())
}
Expand Down
8 changes: 4 additions & 4 deletions clients/shared/default_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var dialects = []sql.Dialect{

func TestColumn_DefaultValue(t *testing.T) {
birthday := time.Date(2022, time.September, 6, 3, 19, 24, 942000000, time.UTC)
birthdayExtDateTime, err := ext.ParseExtendedDateTime(birthday.Format(ext.ISO8601), nil)
birthdayExtDateTime, err := ext.ParseExtendedDateTime(birthday.Format(ext.ISO8601))
assert.NoError(t, err)

// date
Expand Down Expand Up @@ -94,7 +94,7 @@ func TestColumn_DefaultValue(t *testing.T) {

for _, testCase := range testCases {
for _, dialect := range dialects {
actualValue, actualErr := DefaultValue(testCase.col, dialect, nil)
actualValue, actualErr := DefaultValue(testCase.col, dialect)
assert.NoError(t, actualErr, fmt.Sprintf("%s %s", testCase.name, dialect))

expectedValue := testCase.expectedValue
Expand All @@ -113,14 +113,14 @@ func TestColumn_DefaultValue(t *testing.T) {
// Type *decimal.Decimal
decimalValue := decimal.NewDecimal(numbers.MustParseDecimal("3.14159"))
col := columns.NewColumnWithDefaultValue("", typing.EDecimal, decimalValue)
value, err := DefaultValue(col, redshiftDialect.RedshiftDialect{}, nil)
value, err := DefaultValue(col, redshiftDialect.RedshiftDialect{})
assert.NoError(t, err)
assert.Equal(t, "3.14159", value)
}
{
// Wrong type (string)
col := columns.NewColumnWithDefaultValue("", typing.EDecimal, "hello")
_, err := DefaultValue(col, redshiftDialect.RedshiftDialect{}, nil)
_, err := DefaultValue(col, redshiftDialect.RedshiftDialect{})
assert.ErrorContains(t, err, "expected type *decimal.Decimal, got string")
}
}
Expand Down
7 changes: 3 additions & 4 deletions clients/snowflake/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ func replaceExceededValues(colVal string, kindDetails typing.KindDetails) string
return colVal
}

func castColValStaging(colVal any, colKind typing.KindDetails, additionalDateFmts []string) (string, error) {
func castColValStaging(colVal any, colKind typing.KindDetails) (string, error) {
if colVal == nil {
// \\N needs to match NULL_IF(...) from ddl.go
return `\\N`, nil
}

value, err := values.ToString(colVal, colKind, additionalDateFmts)
value, err := values.ToString(colVal, colKind)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -117,12 +117,11 @@ func (s *Store) writeTemporaryTableFile(tableData *optimization.TableData, newTa
writer := csv.NewWriter(file)
writer.Comma = '\t'

additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
columns := tableData.ReadOnlyInMemoryCols().ValidColumns()
for _, value := range tableData.Rows() {
var row []string
for _, col := range columns {
castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails, additionalDateFmts)
castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails)
if castErr != nil {
return "", castErr
}
Expand Down
8 changes: 4 additions & 4 deletions clients/snowflake/staging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,26 @@ func (s *SnowflakeTestSuite) TestReplaceExceededValues() {
func (s *SnowflakeTestSuite) TestCastColValStaging() {
{
// Null
value, err := castColValStaging(nil, typing.String, nil)
value, err := castColValStaging(nil, typing.String)
assert.NoError(s.T(), err)
assert.Equal(s.T(), `\\N`, value)
}
{
// Struct field

// Did not exceed lob size
value, err := castColValStaging(map[string]any{"key": "value"}, typing.Struct, nil)
value, err := castColValStaging(map[string]any{"key": "value"}, typing.Struct)
assert.NoError(s.T(), err)
assert.Equal(s.T(), `{"key":"value"}`, value)

// Did exceed lob size
value, err = castColValStaging(map[string]any{"key": strings.Repeat("a", 16777216)}, typing.Struct, nil)
value, err = castColValStaging(map[string]any{"key": strings.Repeat("a", 16777216)}, typing.Struct)
assert.NoError(s.T(), err)
assert.Equal(s.T(), `{"key":"__artie_exceeded_value"}`, value)
}
{
// String field
value, err := castColValStaging("foo", typing.String, nil)
value, err := castColValStaging("foo", typing.String)
assert.NoError(s.T(), err)
assert.Equal(s.T(), "foo", value)
}
Expand Down
15 changes: 0 additions & 15 deletions lib/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,10 @@ import (

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/kafkalib"
"gopkg.in/yaml.v3"

"github.com/stretchr/testify/assert"
)

func TestSharedTransferConfig(t *testing.T) {
{
var sharedTransferCfg SharedTransferConfig
validBody := `
typingSettings:
additionalDateFormats: ["yyyy-MM-dd1"]
`
err := yaml.Unmarshal([]byte(validBody), &sharedTransferCfg)
assert.NoError(t, err)

assert.Equal(t, "yyyy-MM-dd1", sharedTransferCfg.TypingSettings.AdditionalDateFormats[0])
}
}

const (
validKafkaTopic = `
kafka:
Expand Down
4 changes: 2 additions & 2 deletions lib/optimization/table_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,15 @@ func (t *TableData) NumberOfRows() uint {
return uint(len(t.rowsData))
}

func (t *TableData) DistinctDates(colName string, additionalDateFmts []string) ([]string, error) {
func (t *TableData) DistinctDates(colName string) ([]string, error) {
retMap := make(map[string]bool)
for _, row := range t.rowsData {
val, isOk := row[colName]
if !isOk {
return nil, fmt.Errorf("col: %v does not exist on row: %v", colName, row)
}

extTime, err := ext.ParseFromInterface(val, additionalDateFmts)
extTime, err := ext.ParseFromInterface(val)
if err != nil {
return nil, fmt.Errorf("col: %v is not a time column, value: %v, err: %w", colName, val, err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/optimization/table_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestDistinctDates(t *testing.T) {
rowsData: testCase.rowData,
}

actualValues, actualErr := td.DistinctDates("ts", nil)
actualValues, actualErr := td.DistinctDates("ts")
if testCase.expectedErr != "" {
assert.ErrorContains(t, actualErr, testCase.expectedErr, testCase.name)
} else {
Expand Down
4 changes: 2 additions & 2 deletions lib/parquetutil/parse_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ import (
"github.com/artie-labs/transfer/lib/typing/ext"
)

func ParseValue(colVal any, colKind columns.Column, additionalDateFmts []string) (any, error) {
func ParseValue(colVal any, colKind columns.Column) (any, error) {
if colVal == nil {
return nil, nil
}

switch colKind.KindDetails.Kind {
case typing.ETime.Kind:
extTime, err := ext.ParseFromInterface(colVal, additionalDateFmts)
extTime, err := ext.ParseFromInterface(colVal)
if err != nil {
return "", fmt.Errorf("failed to cast colVal as time.Time, colVal: %v, err: %w", colVal, err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/parquetutil/parse_values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func TestParseValue(t *testing.T) {
}

for _, tc := range testCases {
actualValue, actualErr := ParseValue(tc.colVal, tc.colKind, nil)
actualValue, actualErr := ParseValue(tc.colVal, tc.colKind)
assert.NoError(t, actualErr, tc.name)
assert.Equal(t, tc.expectedValue, actualValue, tc.name)
}
Expand Down
Loading

0 comments on commit e60f8c2

Please sign in to comment.