Skip to content

Commit

Permalink
fix: warehouse size validation (#1873)
Browse files Browse the repository at this point in the history
* warehouse size

* Update pkg/sdk/warehouses.go

Co-authored-by: Nathan Gaberel <nathan.gaberel@snowflake.com>

* update

---------

Co-authored-by: Nathan Gaberel <nathan.gaberel@snowflake.com>
  • Loading branch information
sfc-gh-swinkler and sfc-gh-ngaberel authored Jun 13, 2023
1 parent 30d0017 commit 5bbe460
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 63 deletions.
10 changes: 8 additions & 2 deletions pkg/resources/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@ func dataTypeValidateFunc(val interface{}, _ string) (warns []string, errs []err
}

func dataTypeDiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool {
oldDT := sdk.DataTypeFromString(old)
newDT := sdk.DataTypeFromString(new)
oldDT, err := sdk.ToDataType(old)
if err != nil {
return false
}
newDT, err := sdk.ToDataType(new)
if err != nil {
return false
}
return oldDT == newDT
}

Expand Down
13 changes: 9 additions & 4 deletions pkg/resources/masking_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,21 @@ func CreateMaskingPolicy(d *schema.ResourceData, meta interface{}) error {
columns := m["column"].([]interface{})
for _, c := range columns {
cm := c.(map[string]interface{})
dt := sdk.DataTypeFromString(cm["type"].(string))
dt, err := sdk.ToDataType(cm["type"].(string))
if err != nil {
return err
}
signature = append(signature, sdk.TableColumnSignature{
Name: cm["name"].(string),
Type: dt,
})
}
}

returns := sdk.DataTypeFromString(returnDataType)

returns, err := sdk.ToDataType(returnDataType)
if err != nil {
return err
}
opts := &sdk.CreateMaskingPolicyOptions{}
if comment, ok := d.Get("comment").(string); ok {
opts.Comment = sdk.String(comment)
Expand All @@ -170,7 +175,7 @@ func CreateMaskingPolicy(d *schema.ResourceData, meta interface{}) error {
opts.ExemptOtherPolicies = sdk.Bool(exemptOtherPolicies)
}

err := client.MaskingPolicies.Create(ctx, objectIdentifier, signature, returns, expression, opts)
err = client.MaskingPolicies.Create(ctx, objectIdentifier, signature, returns, expression, opts)
if err != nil {
return err
}
Expand Down
44 changes: 22 additions & 22 deletions pkg/resources/warehouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package resources
import (
"context"
"database/sql"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
snowflakevalidation "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/validation"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
)
Expand All @@ -23,26 +23,20 @@ var warehouseSchema = map[string]*schema.Schema{
Default: "",
},
"warehouse_size": {
Type: schema.TypeString,
Optional: true,
Computed: true,
ValidateFunc: validation.StringInSlice([]string{
string(sdk.WarehouseSizeXSmall),
string(sdk.WarehouseSizeSmall),
string(sdk.WarehouseSizeMedium),
string(sdk.WarehouseSizeLarge),
string(sdk.WarehouseSizeXLarge),
string(sdk.WarehouseSizeXXLarge),
string(sdk.WarehouseSizeXXXLarge),
string(sdk.WarehouseSizeX4Large),
string(sdk.WarehouseSizeX5Large),
string(sdk.WarehouseSizeX6Large),
}, false),
Type: schema.TypeString,
Optional: true,
Computed: true,
ValidateFunc: snowflakevalidation.ValidateWarehouseSize,
DiffSuppressFunc: func(k, old, new string, d *schema.ResourceData) bool {
normalize := func(s string) string {
return strings.ToUpper(strings.ReplaceAll(s, "-", ""))
oldSize, err := sdk.ToWarehouseSize(old)
if err != nil {
return false
}
newSize, err := sdk.ToWarehouseSize(new)
if err != nil {
return false
}
return normalize(old) == normalize(new)
return oldSize == newSize
},
Description: "Specifies the size of the virtual warehouse. Larger warehouse sizes 5X-Large and 6X-Large are currently in preview and only available on Amazon Web Services (AWS).",
},
Expand Down Expand Up @@ -168,7 +162,6 @@ func CreateWarehouse(d *schema.ResourceData, meta interface{}) error {

name := d.Get("name").(string)
objectIdentifier := sdk.NewAccountObjectIdentifier(name)

whType := sdk.WarehouseType(d.Get("warehouse_type").(string))
createOptions := &sdk.CreateWarehouseOptions{
Comment: sdk.String(d.Get("comment").(string)),
Expand All @@ -181,7 +174,10 @@ func CreateWarehouse(d *schema.ResourceData, meta interface{}) error {
}

if v, ok := d.GetOk("warehouse_size"); ok {
size := sdk.WarehouseSize(strings.ReplaceAll(v.(string), "-", ""))
size, err := sdk.ToWarehouseSize(v.(string))
if err != nil {
return err
}
createOptions.WarehouseSize = &size
}
if v, ok := d.GetOk("max_cluster_count"); ok {
Expand Down Expand Up @@ -306,7 +302,11 @@ func UpdateWarehouse(d *schema.ResourceData, meta interface{}) error {
}
if d.HasChange("warehouse_size") {
runSet = true
size := sdk.WarehouseSize(strings.ReplaceAll(d.Get("warehouse_size").(string), "-", ""))
v := d.Get("warehouse_size")
size, err := sdk.ToWarehouseSize(v.(string))
if err != nil {
return err
}
set.WarehouseSize = &size
}
if d.HasChange("max_cluster_count") {
Expand Down
11 changes: 5 additions & 6 deletions pkg/resources/warehouse_acceptance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"strings"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/acctest"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
)
Expand Down Expand Up @@ -44,12 +43,12 @@ func TestAcc_Warehouse(t *testing.T) {
},
// CHANGE PROPERTIES
{
Config: wConfig2(prefix2),
Config: wConfig2(prefix2, "X-LARGE"),
Check: resource.ComposeTestCheckFunc(
resource.TestCheckResourceAttr("snowflake_warehouse.w", "name", prefix2),
resource.TestCheckResourceAttr("snowflake_warehouse.w", "comment", "test comment 2"),
resource.TestCheckResourceAttr("snowflake_warehouse.w", "auto_suspend", "60"),
resource.TestCheckResourceAttr("snowflake_warehouse.w", "warehouse_size", string(sdk.WarehouseSizeSmall)),
resource.TestCheckResourceAttr("snowflake_warehouse.w", "warehouse_size", "X-LARGE"),
),
},
// IMPORT
Expand Down Expand Up @@ -110,12 +109,12 @@ resource "snowflake_warehouse" "w" {
return fmt.Sprintf(s, prefix)
}

func wConfig2(prefix string) string {
func wConfig2(prefix string, size string) string {
s := `
resource "snowflake_warehouse" "w" {
name = "%s"
comment = "test comment 2"
warehouse_size = "SMALL"
warehouse_size = "%s"
auto_suspend = 60
max_cluster_count = 1
Expand All @@ -126,7 +125,7 @@ resource "snowflake_warehouse" "w" {
wait_for_provisioning = false
}
`
return fmt.Sprintf(s, prefix)
return fmt.Sprintf(s, prefix, size)
}

func wConfigPattern(prefix string) string {
Expand Down
38 changes: 18 additions & 20 deletions pkg/sdk/data_types.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sdk

import (
"fmt"
"strings"

"golang.org/x/exp/slices"
Expand All @@ -24,70 +25,67 @@ const (
DataTypeArray DataType = "ARRAY"
DataTypeGeography DataType = "GEOGRAPHY"
DataTypeGeometry DataType = "GEOMETRY"

// DataTypeUnknown is used for testing purposes only.
DataTypeUnknown DataType = "UNKNOWN"
)

func DataTypeFromString(s string) DataType {
func ToDataType(s string) (DataType, error) {
dType := strings.ToUpper(s)

switch dType {
case "DATE":
return DataTypeDate
return DataTypeDate, nil
case "VARIANT":
return DataTypeVariant
return DataTypeVariant, nil
case "OBJECT":
return DataTypeObject
return DataTypeObject, nil
case "ARRAY":
return DataTypeArray
return DataTypeArray, nil
case "GEOGRAPHY":
return DataTypeGeography
return DataTypeGeography, nil
case "GEOMETRY":
return DataTypeGeometry
return DataTypeGeometry, nil
}

numberSynonyms := []string{"NUMBER", "DECIMAL", "NUMERIC", "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"}
if slices.ContainsFunc(numberSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeNumber
return DataTypeNumber, nil
}

floatSynonyms := []string{"FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DOUBLE PRECISION", "REAL"}
if slices.ContainsFunc(floatSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeFloat
return DataTypeFloat, nil
}
varcharSynonyms := []string{"VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"}
if slices.ContainsFunc(varcharSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeVARCHAR
return DataTypeVARCHAR, nil
}
binarySynonyms := []string{"BINARY", "VARBINARY"}
if slices.ContainsFunc(binarySynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeBinary
return DataTypeBinary, nil
}
booleanSynonyms := []string{"BOOLEAN", "BOOL"}
if slices.Contains(booleanSynonyms, dType) {
return DataTypeBoolean
return DataTypeBoolean, nil
}

timestampLTZSynonyms := []string{"TIMESTAMP_LTZ"}
if slices.ContainsFunc(timestampLTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeTimestampLTZ
return DataTypeTimestampLTZ, nil
}

timestampTZSynonyms := []string{"TIMESTAMP_TZ"}
if slices.ContainsFunc(timestampTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeTimestampTZ
return DataTypeTimestampTZ, nil
}

timestampNTZSynonyms := []string{"DATETIME", "TIMESTAMP", "TIMESTAMP_NTZ"}
if slices.ContainsFunc(timestampNTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeTimestampNTZ
return DataTypeTimestampNTZ, nil
}

timeSynonyms := []string{"TIME"}
if slices.ContainsFunc(timeSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) {
return DataTypeTime
return DataTypeTime, nil
}

return DataTypeUnknown
return "", fmt.Errorf("invalid data type: %s", s)
}
6 changes: 3 additions & 3 deletions pkg/sdk/data_types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestDataTypeFromString(t *testing.T) {
func TestToDataType(t *testing.T) {
type test struct {
input string
want DataType
Expand Down Expand Up @@ -77,12 +77,12 @@ func TestDataTypeFromString(t *testing.T) {
{input: "array", want: DataTypeArray},
{input: "geography", want: DataTypeGeography},
{input: "geometry", want: DataTypeGeometry},
{input: "invalid", want: DataTypeUnknown},
}

for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
got := DataTypeFromString(tc.input)
got, err := ToDataType(tc.input)
require.NoError(t, err)
require.Equal(t, tc.want, got)
})
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/sdk/masking_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ type maskingPolicyDetailsRow struct {
}

func (row maskingPolicyDetailsRow) toMaskingPolicyDetails() *MaskingPolicyDetails {
dataType := DataTypeFromString(row.ReturnType)
dataType, err := ToDataType(row.ReturnType)
if err != nil {
return nil
}
v := &MaskingPolicyDetails{
Name: row.Name,
Signature: []TableColumnSignature{},
Expand All @@ -344,7 +347,10 @@ func (row maskingPolicyDetailsRow) toMaskingPolicyDetails() *MaskingPolicyDetail
if len(p) != 2 {
continue
}
dType := DataTypeFromString(p[1])
dType, err := ToDataType(p[1])
if err != nil {
continue
}
v.Signature = append(v.Signature, TableColumnSignature{
Name: p[0],
Type: dType,
Expand Down
9 changes: 7 additions & 2 deletions pkg/sdk/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ import (
)

func IsValidDataType(v string) bool {
dt := DataTypeFromString(v)
return dt != DataTypeUnknown
_, err := ToDataType(v)
return err == nil
}

func IsValidWarehouseSize(v string) bool {
_, err := ToWarehouseSize(v)
return err == nil
}

func validObjectidentifier(objectIdentifier ObjectIdentifier) bool {
Expand Down
16 changes: 14 additions & 2 deletions pkg/sdk/validations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,24 @@ import (
func TestIsValidDataType(t *testing.T) {
t.Run("with valid data type", func(t *testing.T) {
ok := IsValidDataType("VARCHAR")
assert.Equal(t, ok, true)
assert.True(t, ok)
})

t.Run("with invalid data type", func(t *testing.T) {
ok := IsValidDataType("foo")
assert.Equal(t, ok, false)
assert.False(t, ok)
})
}

func TestIsValidWarehouseSize(t *testing.T) {
t.Run("with valid warehouse size", func(t *testing.T) {
ok := IsValidWarehouseSize("XSMALL")
assert.True(t, ok)
})

t.Run("with invalid warehouse size", func(t *testing.T) {
ok := IsValidWarehouseSize("foo")
assert.False(t, ok)
})
}

Expand Down
Loading

0 comments on commit 5bbe460

Please sign in to comment.