diff --git a/pkg/snowflake/database_test.go b/pkg/snowflake/database_test.go index 5cb84add17..8c8ed6beb9 100644 --- a/pkg/snowflake/database_test.go +++ b/pkg/snowflake/database_test.go @@ -41,6 +41,12 @@ func TestDatabase(t *testing.T) { c.SetBool("bam", false) q = c.Statement() r.Equal(`CREATE DATABASE "db1" FOO='bar' BAM=false`, q) + + // test escaping + c2 := db.Create() + c2.SetString("foo", "ba'r") + q = c2.Statement() + r.Equal(`CREATE DATABASE "db1" FOO='ba\'r'`, q) } func TestDatabaseCreateFromShare(t *testing.T) { diff --git a/pkg/snowflake/schema.go b/pkg/snowflake/schema.go index 203c4782e4..d6c676d1bd 100644 --- a/pkg/snowflake/schema.go +++ b/pkg/snowflake/schema.go @@ -101,7 +101,7 @@ func (sb *SchemaBuilder) Create() string { } if sb.comment != "" { - q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, sb.comment)) + q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(sb.comment))) } return q.String() @@ -120,7 +120,7 @@ func (sb *SchemaBuilder) Swap(targetSchema string) string { // ChangeComment returns the SQL query that will update the comment on the schema. func (sb *SchemaBuilder) ChangeComment(c string) string { - return fmt.Sprintf(`ALTER SCHEMA %v SET COMMENT = '%v'`, sb.QualifiedName(), c) + return fmt.Sprintf(`ALTER SCHEMA %v SET COMMENT = '%v'`, sb.QualifiedName(), EscapeString(c)) } // RemoveComment returns the SQL query that will remove the comment on the schema. diff --git a/pkg/snowflake/schema_test.go b/pkg/snowflake/schema_test.go index 56ddeb5c28..bc9a20b895 100644 --- a/pkg/snowflake/schema_test.go +++ b/pkg/snowflake/schema_test.go @@ -25,8 +25,8 @@ func TestSchemaCreate(t *testing.T) { s.WithDataRetentionDays(7) r.Equal(s.Create(), `CREATE TRANSIENT SCHEMA "db"."test" WITH MANAGED ACCESS DATA_RETENTION_TIME_IN_DAYS = 7`) - s.WithComment("Yeehaw") - r.Equal(s.Create(), `CREATE TRANSIENT SCHEMA "db"."test" WITH MANAGED ACCESS DATA_RETENTION_TIME_IN_DAYS = 7 COMMENT = 'Yeehaw'`) + s.WithComment("Yee'haw") + r.Equal(`CREATE TRANSIENT SCHEMA "db"."test" WITH MANAGED ACCESS DATA_RETENTION_TIME_IN_DAYS = 7 COMMENT = 'Yee\'haw'`, s.Create()) } func TestSchemaRename(t *testing.T) { @@ -44,7 +44,7 @@ func TestSchemaSwap(t *testing.T) { func TestSchemaChangeComment(t *testing.T) { r := require.New(t) s := Schema("test") - r.Equal(s.ChangeComment("worst schema ever"), `ALTER SCHEMA "test" SET COMMENT = 'worst schema ever'`) + r.Equal(`ALTER SCHEMA "test" SET COMMENT = 'worst\' schema ever'`, s.ChangeComment("worst' schema ever")) } func TestSchemaRemoveComment(t *testing.T) { diff --git a/pkg/snowflake/stage.go b/pkg/snowflake/stage.go index 474c78c2f0..8223d3b272 100644 --- a/pkg/snowflake/stage.go +++ b/pkg/snowflake/stage.go @@ -123,7 +123,7 @@ func (sb *StageBuilder) Create() string { } if sb.comment != "" { - q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, sb.comment)) + q.WriteString(fmt.Sprintf(` COMMENT = '%v'`, EscapeString(sb.comment))) } return q.String() diff --git a/pkg/snowflake/stage_test.go b/pkg/snowflake/stage_test.go index e40933d1d2..17e9a0521a 100644 --- a/pkg/snowflake/stage_test.go +++ b/pkg/snowflake/stage_test.go @@ -28,11 +28,11 @@ func TestStageCreate(t *testing.T) { s.WithCopyOptions("on_error='skip_file'") r.Equal(s.Create(), `CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file')`) - s.WithComment("Yeehaw") - r.Equal(s.Create(), `CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yeehaw'`) + s.WithComment("Yee'haw") + r.Equal(`CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yee\'haw'`, s.Create()) s.WithStorageIntegration("MY_INTEGRATION") - r.Equal(s.Create(), `CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') STORAGE_INTEGRATION = MY_INTEGRATION ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yeehaw'`) + r.Equal(`CREATE STAGE "test_db"."test_schema"."test_stage" URL = 's3://load/encrypted_files/' CREDENTIALS = (aws_role='arn:aws:iam::001234567890:role/mysnowflakerole') STORAGE_INTEGRATION = MY_INTEGRATION ENCRYPTION = (type='AWS_SSE_KMS' kms_key_id = 'aws/key') FILE_FORMAT = (format_name=my_csv_format) COPY_OPTIONS = (on_error='skip_file') COMMENT = 'Yee\'haw'`, s.Create()) } func TestStageRename(t *testing.T) { diff --git a/pkg/snowflake/validation_test.go b/pkg/snowflake/validation_test.go index 0a319ad616..f20b36c83f 100644 --- a/pkg/snowflake/validation_test.go +++ b/pkg/snowflake/validation_test.go @@ -1 +1,3 @@ package snowflake_test + +// TODO write tests here diff --git a/pkg/snowflake/view.go b/pkg/snowflake/view.go index a4635dd616..ecc1fff230 100644 --- a/pkg/snowflake/view.go +++ b/pkg/snowflake/view.go @@ -110,7 +110,7 @@ func (vb *ViewBuilder) Create() string { q.WriteString(fmt.Sprintf(` VIEW %v`, vb.QualifiedName())) if vb.comment != "" { - q.WriteString(fmt.Sprintf(" COMMENT = '%v'", vb.comment)) + q.WriteString(fmt.Sprintf(" COMMENT = '%v'", EscapeString(vb.comment))) } q.WriteString(fmt.Sprintf(" AS %v", vb.statement)) @@ -139,7 +139,7 @@ func (vb *ViewBuilder) Unsecure() string { // Note that comment is the only parameter, if more are released this should be // abstracted as per the generic builder. func (vb *ViewBuilder) ChangeComment(c string) string { - return fmt.Sprintf(`ALTER VIEW %v SET COMMENT = '%v'`, vb.QualifiedName(), c) + return fmt.Sprintf(`ALTER VIEW %v SET COMMENT = '%v'`, vb.QualifiedName(), EscapeString(c)) } // RemoveComment returns the SQL query that will remove the comment on the view. diff --git a/pkg/snowflake/view_test.go b/pkg/snowflake/view_test.go index 23aaa10a92..3529374caa 100644 --- a/pkg/snowflake/view_test.go +++ b/pkg/snowflake/view_test.go @@ -16,16 +16,14 @@ func TestView(t *testing.T) { v.WithSecure() r.True(v.secure) - v.WithComment("great comment") - r.Equal("great comment", v.comment) - + v.WithComment("great' comment") v.WithStatement("SELECT * FROM DUMMY LIMIT 1") r.Equal("SELECT * FROM DUMMY LIMIT 1", v.statement) v.WithStatement("SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1") q := v.Create() - r.Equal(`CREATE SECURE VIEW "test" COMMENT = 'great comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q) + r.Equal(`CREATE SECURE VIEW "test" COMMENT = 'great\' comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q) q = v.Secure() r.Equal(`ALTER VIEW "test" SET SECURE`, q) @@ -33,8 +31,8 @@ func TestView(t *testing.T) { q = v.Unsecure() r.Equal(`ALTER VIEW "test" UNSET SECURE`, q) - q = v.ChangeComment("bad comment") - r.Equal(`ALTER VIEW "test" SET COMMENT = 'bad comment'`, q) + q = v.ChangeComment("bad' comment") + r.Equal(`ALTER VIEW "test" SET COMMENT = 'bad\' comment'`, q) q = v.RemoveComment() r.Equal(`ALTER VIEW "test" UNSET COMMENT`, q) @@ -49,7 +47,7 @@ func TestView(t *testing.T) { r.Equal(v.QualifiedName(), `"mydb".."test"`) q = v.Create() - r.Equal(`CREATE SECURE VIEW "mydb".."test" COMMENT = 'great comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q) + r.Equal(`CREATE SECURE VIEW "mydb".."test" COMMENT = 'great\' comment' AS SELECT * FROM DUMMY WHERE blah = 'blahblah' LIMIT 1`, q) q = v.Secure() r.Equal(`ALTER VIEW "mydb".."test" SET SECURE`, q) @@ -58,7 +56,7 @@ func TestView(t *testing.T) { r.Equal(`SHOW VIEWS LIKE 'test' IN DATABASE "mydb"`, q) q = v.Drop() - r.Equal(`DROP VIEW "mydb".."test"`, q) + r.Equal(`DROP VIEW "mydb".."test"`, q) // FIXME invalid query } func TestQualifiedName(t *testing.T) {