diff --git a/pkg/sdk/network_policies_def.go b/pkg/sdk/network_policies_def.go index 68ae2f1434..cfca46ffa3 100644 --- a/pkg/sdk/network_policies_def.go +++ b/pkg/sdk/network_policies_def.go @@ -8,6 +8,16 @@ var ( ip = g.NewQueryStruct("IP"). Text("IP", g.KeywordOptions().SingleQuotes().Required()) + networkPoliciesAddNetworkRule = g.NewQueryStruct("AddNetworkRule"). + ListAssignment("ALLOWED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). + ListAssignment("BLOCKED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). + WithValidation(g.ExactlyOneValueSet, "AllowedNetworkRuleList", "BlockedNetworkRuleList") + + networkPoliciesRemoveNetworkRule = g.NewQueryStruct("RemoveNetworkRule"). + ListAssignment("ALLOWED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). + ListAssignment("BLOCKED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). + WithValidation(g.ExactlyOneValueSet, "AllowedNetworkRuleList", "BlockedNetworkRuleList") + NetworkPoliciesDef = g.NewInterface( "NetworkPolicies", "NetworkPolicy", @@ -20,6 +30,8 @@ var ( OrReplace(). SQL("NETWORK POLICY"). Name(). + ListAssignment("ALLOWED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). + ListAssignment("BLOCKED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). ListQueryStructField("AllowedIpList", ip, g.ParameterOptions().SQL("ALLOWED_IP_LIST").Parentheses()). ListQueryStructField("BlockedIpList", ip, g.ParameterOptions().SQL("BLOCKED_IP_LIST").Parentheses()). OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes()). @@ -35,16 +47,28 @@ var ( OptionalQueryStructField( "Set", g.NewQueryStruct("NetworkPolicySet"). + ListAssignment("ALLOWED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). + ListAssignment("BLOCKED_NETWORK_RULE_LIST", "SchemaObjectIdentifier", g.ParameterOptions().Parentheses()). ListQueryStructField("AllowedIpList", ip, g.ParameterOptions().SQL("ALLOWED_IP_LIST").Parentheses()). ListQueryStructField("BlockedIpList", ip, g.ParameterOptions().SQL("BLOCKED_IP_LIST").Parentheses()). OptionalTextAssignment("COMMENT", g.ParameterOptions().SingleQuotes()). - WithValidation(g.AtLeastOneValueSet, "AllowedIpList", "BlockedIpList", "Comment"), + WithValidation(g.AtLeastOneValueSet, "AllowedIpList", "BlockedIpList", "Comment", "AllowedNetworkRuleList", "BlockedNetworkRuleList"), g.KeywordOptions().SQL("SET"), ). + OptionalQueryStructField( + "Add", + networkPoliciesAddNetworkRule, + g.KeywordOptions().SQL("ADD"), + ). + OptionalQueryStructField( + "Remove", + networkPoliciesRemoveNetworkRule, + g.KeywordOptions().SQL("REMOVE"), + ). OptionalSQL("UNSET COMMENT"). Identifier("RenameTo", g.KindOfTPointer[AccountObjectIdentifier](), g.IdentifierOptions().SQL("RENAME TO")). WithValidation(g.ValidIdentifier, "name"). - WithValidation(g.ExactlyOneValueSet, "Set", "UnsetComment", "RenameTo"). + WithValidation(g.ExactlyOneValueSet, "Set", "UnsetComment", "RenameTo", "Add", "Remove"). WithValidation(g.ValidIdentifierIfSet, "RenameTo"), ). DropOperation( @@ -63,13 +87,17 @@ var ( Field("name", "string"). Field("comment", "string"). Field("entries_in_allowed_ip_list", "int"). - Field("entries_in_blocked_ip_list", "int"), + Field("entries_in_blocked_ip_list", "int"). + Field("entries_in_allowed_network_rules", "int"). + Field("entries_in_blocked_network_rules", "int"), g.PlainStruct("NetworkPolicy"). Field("CreatedOn", "string"). Field("Name", "string"). Field("Comment", "string"). Field("EntriesInAllowedIpList", "int"). - Field("EntriesInBlockedIpList", "int"), + Field("EntriesInBlockedIpList", "int"). + Field("EntriesInAllowedNetworkRules", "int"). + Field("EntriesInBlockedNetworkRules", "int"), g.NewQueryStruct("ShowNetworkPolicies"). Show(). SQL("NETWORK POLICIES"), diff --git a/pkg/sdk/network_policies_dto_builders_gen.go b/pkg/sdk/network_policies_dto_builders_gen.go index eca91a4c25..330f57d8ee 100644 --- a/pkg/sdk/network_policies_dto_builders_gen.go +++ b/pkg/sdk/network_policies_dto_builders_gen.go @@ -2,6 +2,8 @@ package sdk +import () + func NewCreateNetworkPolicyRequest( name AccountObjectIdentifier, ) *CreateNetworkPolicyRequest { @@ -15,6 +17,16 @@ func (s *CreateNetworkPolicyRequest) WithOrReplace(OrReplace *bool) *CreateNetwo return s } +func (s *CreateNetworkPolicyRequest) WithAllowedNetworkRuleList(AllowedNetworkRuleList []SchemaObjectIdentifier) *CreateNetworkPolicyRequest { + s.AllowedNetworkRuleList = AllowedNetworkRuleList + return s +} + +func (s *CreateNetworkPolicyRequest) WithBlockedNetworkRuleList(BlockedNetworkRuleList []SchemaObjectIdentifier) *CreateNetworkPolicyRequest { + s.BlockedNetworkRuleList = BlockedNetworkRuleList + return s +} + func (s *CreateNetworkPolicyRequest) WithAllowedIpList(AllowedIpList []IPRequest) *CreateNetworkPolicyRequest { s.AllowedIpList = AllowedIpList return s @@ -56,6 +68,16 @@ func (s *AlterNetworkPolicyRequest) WithSet(Set *NetworkPolicySetRequest) *Alter return s } +func (s *AlterNetworkPolicyRequest) WithAdd(Add *AddNetworkRuleRequest) *AlterNetworkPolicyRequest { + s.Add = Add + return s +} + +func (s *AlterNetworkPolicyRequest) WithRemove(Remove *RemoveNetworkRuleRequest) *AlterNetworkPolicyRequest { + s.Remove = Remove + return s +} + func (s *AlterNetworkPolicyRequest) WithUnsetComment(UnsetComment *bool) *AlterNetworkPolicyRequest { s.UnsetComment = UnsetComment return s @@ -70,6 +92,16 @@ func NewNetworkPolicySetRequest() *NetworkPolicySetRequest { return &NetworkPolicySetRequest{} } +func (s *NetworkPolicySetRequest) WithAllowedNetworkRuleList(AllowedNetworkRuleList []SchemaObjectIdentifier) *NetworkPolicySetRequest { + s.AllowedNetworkRuleList = AllowedNetworkRuleList + return s +} + +func (s *NetworkPolicySetRequest) WithBlockedNetworkRuleList(BlockedNetworkRuleList []SchemaObjectIdentifier) *NetworkPolicySetRequest { + s.BlockedNetworkRuleList = BlockedNetworkRuleList + return s +} + func (s *NetworkPolicySetRequest) WithAllowedIpList(AllowedIpList []IPRequest) *NetworkPolicySetRequest { s.AllowedIpList = AllowedIpList return s @@ -85,6 +117,34 @@ func (s *NetworkPolicySetRequest) WithComment(Comment *string) *NetworkPolicySet return s } +func NewAddNetworkRuleRequest() *AddNetworkRuleRequest { + return &AddNetworkRuleRequest{} +} + +func (s *AddNetworkRuleRequest) WithAllowedNetworkRuleList(AllowedNetworkRuleList []SchemaObjectIdentifier) *AddNetworkRuleRequest { + s.AllowedNetworkRuleList = AllowedNetworkRuleList + return s +} + +func (s *AddNetworkRuleRequest) WithBlockedNetworkRuleList(BlockedNetworkRuleList []SchemaObjectIdentifier) *AddNetworkRuleRequest { + s.BlockedNetworkRuleList = BlockedNetworkRuleList + return s +} + +func NewRemoveNetworkRuleRequest() *RemoveNetworkRuleRequest { + return &RemoveNetworkRuleRequest{} +} + +func (s *RemoveNetworkRuleRequest) WithAllowedNetworkRuleList(AllowedNetworkRuleList []SchemaObjectIdentifier) *RemoveNetworkRuleRequest { + s.AllowedNetworkRuleList = AllowedNetworkRuleList + return s +} + +func (s *RemoveNetworkRuleRequest) WithBlockedNetworkRuleList(BlockedNetworkRuleList []SchemaObjectIdentifier) *RemoveNetworkRuleRequest { + s.BlockedNetworkRuleList = BlockedNetworkRuleList + return s +} + func NewDropNetworkPolicyRequest( name AccountObjectIdentifier, ) *DropNetworkPolicyRequest { diff --git a/pkg/sdk/network_policies_dto_gen.go b/pkg/sdk/network_policies_dto_gen.go index 8902f05440..192ec211ca 100644 --- a/pkg/sdk/network_policies_dto_gen.go +++ b/pkg/sdk/network_policies_dto_gen.go @@ -11,11 +11,13 @@ var ( ) type CreateNetworkPolicyRequest struct { - OrReplace *bool - name AccountObjectIdentifier // required - AllowedIpList []IPRequest - BlockedIpList []IPRequest - Comment *string + OrReplace *bool + name AccountObjectIdentifier // required + AllowedNetworkRuleList []SchemaObjectIdentifier + BlockedNetworkRuleList []SchemaObjectIdentifier + AllowedIpList []IPRequest + BlockedIpList []IPRequest + Comment *string } func (r *CreateNetworkPolicyRequest) GetName() AccountObjectIdentifier { @@ -30,14 +32,28 @@ type AlterNetworkPolicyRequest struct { IfExists *bool name AccountObjectIdentifier // required Set *NetworkPolicySetRequest + Add *AddNetworkRuleRequest + Remove *RemoveNetworkRuleRequest UnsetComment *bool RenameTo *AccountObjectIdentifier } type NetworkPolicySetRequest struct { - AllowedIpList []IPRequest - BlockedIpList []IPRequest - Comment *string + AllowedNetworkRuleList []SchemaObjectIdentifier + BlockedNetworkRuleList []SchemaObjectIdentifier + AllowedIpList []IPRequest + BlockedIpList []IPRequest + Comment *string +} + +type AddNetworkRuleRequest struct { + AllowedNetworkRuleList []SchemaObjectIdentifier + BlockedNetworkRuleList []SchemaObjectIdentifier +} + +type RemoveNetworkRuleRequest struct { + AllowedNetworkRuleList []SchemaObjectIdentifier + BlockedNetworkRuleList []SchemaObjectIdentifier } type DropNetworkPolicyRequest struct { diff --git a/pkg/sdk/network_policies_gen.go b/pkg/sdk/network_policies_gen.go index d1aba59894..b1d03a2dff 100644 --- a/pkg/sdk/network_policies_gen.go +++ b/pkg/sdk/network_policies_gen.go @@ -13,13 +13,15 @@ type NetworkPolicies interface { // CreateNetworkPolicyOptions is based on https://docs.snowflake.com/en/sql-reference/sql/create-network-policy. type CreateNetworkPolicyOptions struct { - create bool `ddl:"static" sql:"CREATE"` - OrReplace *bool `ddl:"keyword" sql:"OR REPLACE"` - networkPolicy bool `ddl:"static" sql:"NETWORK POLICY"` - name AccountObjectIdentifier `ddl:"identifier"` - AllowedIpList []IP `ddl:"parameter,parentheses" sql:"ALLOWED_IP_LIST"` - BlockedIpList []IP `ddl:"parameter,parentheses" sql:"BLOCKED_IP_LIST"` - Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` + create bool `ddl:"static" sql:"CREATE"` + OrReplace *bool `ddl:"keyword" sql:"OR REPLACE"` + networkPolicy bool `ddl:"static" sql:"NETWORK POLICY"` + name AccountObjectIdentifier `ddl:"identifier"` + AllowedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"ALLOWED_NETWORK_RULE_LIST"` + BlockedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"BLOCKED_NETWORK_RULE_LIST"` + AllowedIpList []IP `ddl:"parameter,parentheses" sql:"ALLOWED_IP_LIST"` + BlockedIpList []IP `ddl:"parameter,parentheses" sql:"BLOCKED_IP_LIST"` + Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` } type IP struct { @@ -33,14 +35,28 @@ type AlterNetworkPolicyOptions struct { IfExists *bool `ddl:"keyword" sql:"IF EXISTS"` name AccountObjectIdentifier `ddl:"identifier"` Set *NetworkPolicySet `ddl:"keyword" sql:"SET"` + Add *AddNetworkRule `ddl:"keyword" sql:"ADD"` + Remove *RemoveNetworkRule `ddl:"keyword" sql:"REMOVE"` UnsetComment *bool `ddl:"keyword" sql:"UNSET COMMENT"` RenameTo *AccountObjectIdentifier `ddl:"identifier" sql:"RENAME TO"` } type NetworkPolicySet struct { - AllowedIpList []IP `ddl:"parameter,parentheses" sql:"ALLOWED_IP_LIST"` - BlockedIpList []IP `ddl:"parameter,parentheses" sql:"BLOCKED_IP_LIST"` - Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` + AllowedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"ALLOWED_NETWORK_RULE_LIST"` + BlockedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"BLOCKED_NETWORK_RULE_LIST"` + AllowedIpList []IP `ddl:"parameter,parentheses" sql:"ALLOWED_IP_LIST"` + BlockedIpList []IP `ddl:"parameter,parentheses" sql:"BLOCKED_IP_LIST"` + Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` +} + +type AddNetworkRule struct { + AllowedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"ALLOWED_NETWORK_RULE_LIST"` + BlockedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"BLOCKED_NETWORK_RULE_LIST"` +} + +type RemoveNetworkRule struct { + AllowedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"ALLOWED_NETWORK_RULE_LIST"` + BlockedNetworkRuleList []SchemaObjectIdentifier `ddl:"parameter,parentheses" sql:"BLOCKED_NETWORK_RULE_LIST"` } // DropNetworkPolicyOptions is based on https://docs.snowflake.com/en/sql-reference/sql/drop-network-policy. @@ -58,19 +74,23 @@ type ShowNetworkPolicyOptions struct { } type showNetworkPolicyDBRow struct { - CreatedOn string `db:"created_on"` - Name string `db:"name"` - Comment string `db:"comment"` - EntriesInAllowedIpList int `db:"entries_in_allowed_ip_list"` - EntriesInBlockedIpList int `db:"entries_in_blocked_ip_list"` + CreatedOn string `db:"created_on"` + Name string `db:"name"` + Comment string `db:"comment"` + EntriesInAllowedIpList int `db:"entries_in_allowed_ip_list"` + EntriesInBlockedIpList int `db:"entries_in_blocked_ip_list"` + EntriesInAllowedNetworkRules int `db:"entries_in_allowed_network_rules"` + EntriesInBlockedNetworkRules int `db:"entries_in_blocked_network_rules"` } type NetworkPolicy struct { - CreatedOn string - Name string - Comment string - EntriesInAllowedIpList int - EntriesInBlockedIpList int + CreatedOn string + Name string + Comment string + EntriesInAllowedIpList int + EntriesInBlockedIpList int + EntriesInAllowedNetworkRules int + EntriesInBlockedNetworkRules int } // DescribeNetworkPolicyOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-network-policy. diff --git a/pkg/sdk/network_policies_gen_test.go b/pkg/sdk/network_policies_gen_test.go index 697661c4ff..a6103ca896 100644 --- a/pkg/sdk/network_policies_gen_test.go +++ b/pkg/sdk/network_policies_gen_test.go @@ -7,14 +7,18 @@ import ( func TestNetworkPolicies_Create(t *testing.T) { id := RandomAccountObjectIdentifier() + allowedNetworkRule := RandomSchemaObjectIdentifier() + blockedNetworkRule := RandomSchemaObjectIdentifier() // Minimal valid CreateNetworkPolicyOptions defaultOpts := func() *CreateNetworkPolicyOptions { return &CreateNetworkPolicyOptions{ - OrReplace: Bool(true), - name: id, - AllowedIpList: []IP{{IP: "123.0.0.1"}, {IP: "321.0.0.1"}}, - BlockedIpList: []IP{{IP: "123.0.0.1"}, {IP: "321.0.0.1"}}, - Comment: String("some_comment"), + OrReplace: Bool(true), + name: id, + AllowedIpList: []IP{{IP: "123.0.0.1"}, {IP: "321.0.0.1"}}, + BlockedIpList: []IP{{IP: "123.0.0.1"}, {IP: "321.0.0.1"}}, + AllowedNetworkRuleList: []SchemaObjectIdentifier{allowedNetworkRule}, + BlockedNetworkRuleList: []SchemaObjectIdentifier{blockedNetworkRule}, + Comment: String("some_comment"), } } @@ -31,7 +35,7 @@ func TestNetworkPolicies_Create(t *testing.T) { t.Run("all options", func(t *testing.T) { opts := defaultOpts() - assertOptsValidAndSQLEquals(t, opts, "CREATE OR REPLACE NETWORK POLICY %s ALLOWED_IP_LIST = ('123.0.0.1', '321.0.0.1') BLOCKED_IP_LIST = ('123.0.0.1', '321.0.0.1') COMMENT = 'some_comment'", opts.name.FullyQualifiedName()) + assertOptsValidAndSQLEquals(t, opts, "CREATE OR REPLACE NETWORK POLICY %s ALLOWED_NETWORK_RULE_LIST = (%s) BLOCKED_NETWORK_RULE_LIST = (%s) ALLOWED_IP_LIST = ('123.0.0.1', '321.0.0.1') BLOCKED_IP_LIST = ('123.0.0.1', '321.0.0.1') COMMENT = 'some_comment'", opts.name.FullyQualifiedName(), allowedNetworkRule.FullyQualifiedName(), blockedNetworkRule.FullyQualifiedName()) }) } @@ -58,15 +62,37 @@ func TestNetworkPolicies_Alter(t *testing.T) { assertOptsInvalidJoinedErrors(t, opts, ErrInvalidObjectIdentifier) }) - t.Run("validation: exactly one field from [opts.Set opts.UnsetComment opts.RenameTo] should be present", func(t *testing.T) { + t.Run("validation: exactly one field from [opts.Set opts.UnsetComment opts.RenameTo opts.Add opts.Remove] should be present", func(t *testing.T) { opts := defaultOpts() - assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterNetworkPolicyOptions", "Set", "UnsetComment", "RenameTo")) + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterNetworkPolicyOptions", "Set", "UnsetComment", "RenameTo", "Add", "Remove")) }) - t.Run("validation: at least one of the fields [opts.Set.AllowedIpList opts.Set.BlockedIpList opts.Set.Comment] should be set", func(t *testing.T) { + t.Run("validation: at least one of the fields [opts.Set.AllowedIpList opts.Set.BlockedIpList opts.Set.Comment opts.Set.AllowedNetworkRuleList opts.Set.BlockedNetworkRuleList] should be set", func(t *testing.T) { opts := defaultOpts() opts.Set = &NetworkPolicySet{} - assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterNetworkPolicyOptions.Set", "AllowedIpList", "BlockedIpList", "Comment")) + assertOptsInvalidJoinedErrors(t, opts, errAtLeastOneOf("AlterNetworkPolicyOptions.Set", "AllowedIpList", "BlockedIpList", "Comment", "AllowedNetworkRuleList", "BlockedNetworkRuleList")) + }) + + t.Run("validation: exactly one field from [opts.Add.AllowedNetworkRuleList opts.Add.BlockedNetworkRuleList] should be present", func(t *testing.T) { + allowedNetworkRule := RandomSchemaObjectIdentifier() + blockedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Add = &AddNetworkRule{ + AllowedNetworkRuleList: []SchemaObjectIdentifier{allowedNetworkRule}, + BlockedNetworkRuleList: []SchemaObjectIdentifier{blockedNetworkRule}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterNetworkPolicyOptions.Add", "AllowedNetworkRuleList", "BlockedNetworkRuleList")) + }) + + t.Run("validation: exactly one field from [opts.Remove.AllowedNetworkRuleList opts.Remove.BlockedNetworkRuleList] should be present", func(t *testing.T) { + allowedNetworkRule := RandomSchemaObjectIdentifier() + blockedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Remove = &RemoveNetworkRule{ + AllowedNetworkRuleList: []SchemaObjectIdentifier{allowedNetworkRule}, + BlockedNetworkRuleList: []SchemaObjectIdentifier{blockedNetworkRule}, + } + assertOptsInvalidJoinedErrors(t, opts, errExactlyOneOf("AlterNetworkPolicyOptions.Remove", "AllowedNetworkRuleList", "BlockedNetworkRuleList")) }) t.Run("set allowed ip list", func(t *testing.T) { @@ -85,6 +111,60 @@ func TestNetworkPolicies_Alter(t *testing.T) { assertOptsValidAndSQLEquals(t, opts, "ALTER NETWORK POLICY IF EXISTS %s SET BLOCKED_IP_LIST = ('123.0.0.1')", id.FullyQualifiedName()) }) + t.Run("set allowed network rule list", func(t *testing.T) { + allowedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Set = &NetworkPolicySet{ + AllowedNetworkRuleList: []SchemaObjectIdentifier{allowedNetworkRule}, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER NETWORK POLICY IF EXISTS %s SET ALLOWED_NETWORK_RULE_LIST = (%s)", id.FullyQualifiedName(), allowedNetworkRule.FullyQualifiedName()) + }) + + t.Run("set blocked network rule list", func(t *testing.T) { + blockedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Set = &NetworkPolicySet{ + BlockedNetworkRuleList: []SchemaObjectIdentifier{blockedNetworkRule}, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER NETWORK POLICY IF EXISTS %s SET BLOCKED_NETWORK_RULE_LIST = (%s)", id.FullyQualifiedName(), blockedNetworkRule.FullyQualifiedName()) + }) + + t.Run("add allowed network rule", func(t *testing.T) { + allowedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Add = &AddNetworkRule{ + AllowedNetworkRuleList: []SchemaObjectIdentifier{allowedNetworkRule}, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER NETWORK POLICY IF EXISTS %s ADD ALLOWED_NETWORK_RULE_LIST = (%s)", id.FullyQualifiedName(), allowedNetworkRule.FullyQualifiedName()) + }) + + t.Run("add blocked network rule", func(t *testing.T) { + blockedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Add = &AddNetworkRule{ + BlockedNetworkRuleList: []SchemaObjectIdentifier{blockedNetworkRule}, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER NETWORK POLICY IF EXISTS %s ADD BLOCKED_NETWORK_RULE_LIST = (%s)", id.FullyQualifiedName(), blockedNetworkRule.FullyQualifiedName()) + }) + + t.Run("remove allowed network rule", func(t *testing.T) { + allowedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Remove = &RemoveNetworkRule{ + AllowedNetworkRuleList: []SchemaObjectIdentifier{allowedNetworkRule}, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER NETWORK POLICY IF EXISTS %s REMOVE ALLOWED_NETWORK_RULE_LIST = (%s)", id.FullyQualifiedName(), allowedNetworkRule.FullyQualifiedName()) + }) + + t.Run("remove blocked network rule", func(t *testing.T) { + blockedNetworkRule := RandomSchemaObjectIdentifier() + opts := defaultOpts() + opts.Remove = &RemoveNetworkRule{ + BlockedNetworkRuleList: []SchemaObjectIdentifier{blockedNetworkRule}, + } + assertOptsValidAndSQLEquals(t, opts, "ALTER NETWORK POLICY IF EXISTS %s REMOVE BLOCKED_NETWORK_RULE_LIST = (%s)", id.FullyQualifiedName(), blockedNetworkRule.FullyQualifiedName()) + }) + t.Run("set comment", func(t *testing.T) { opts := defaultOpts() opts.Set = &NetworkPolicySet{ diff --git a/pkg/sdk/network_policies_impl_gen.go b/pkg/sdk/network_policies_impl_gen.go index b8b787c225..98aa07cf42 100644 --- a/pkg/sdk/network_policies_impl_gen.go +++ b/pkg/sdk/network_policies_impl_gen.go @@ -50,17 +50,19 @@ func (v *networkPolicies) Describe(ctx context.Context, id AccountObjectIdentifi opts := &DescribeNetworkPolicyOptions{ name: id, } - s, err := validateAndQuery[describeNetworkPolicyDBRow](v.client, ctx, opts) + rows, err := validateAndQuery[describeNetworkPolicyDBRow](v.client, ctx, opts) if err != nil { return nil, err } - return convertRows[describeNetworkPolicyDBRow, NetworkPolicyDescription](s), nil + return convertRows[describeNetworkPolicyDBRow, NetworkPolicyDescription](rows), nil } func (r *CreateNetworkPolicyRequest) toOpts() *CreateNetworkPolicyOptions { opts := &CreateNetworkPolicyOptions{ - OrReplace: r.OrReplace, - name: r.name, + OrReplace: r.OrReplace, + name: r.name, + AllowedNetworkRuleList: r.AllowedNetworkRuleList, + BlockedNetworkRuleList: r.BlockedNetworkRuleList, Comment: r.Comment, } @@ -91,6 +93,9 @@ func (r *AlterNetworkPolicyRequest) toOpts() *AlterNetworkPolicyOptions { } if r.Set != nil { opts.Set = &NetworkPolicySet{ + AllowedNetworkRuleList: r.Set.AllowedNetworkRuleList, + BlockedNetworkRuleList: r.Set.BlockedNetworkRuleList, + Comment: r.Set.Comment, } if r.Set.AllowedIpList != nil { @@ -108,6 +113,18 @@ func (r *AlterNetworkPolicyRequest) toOpts() *AlterNetworkPolicyOptions { opts.Set.BlockedIpList = s } } + if r.Add != nil { + opts.Add = &AddNetworkRule{ + AllowedNetworkRuleList: r.Add.AllowedNetworkRuleList, + BlockedNetworkRuleList: r.Add.BlockedNetworkRuleList, + } + } + if r.Remove != nil { + opts.Remove = &RemoveNetworkRule{ + AllowedNetworkRuleList: r.Remove.AllowedNetworkRuleList, + BlockedNetworkRuleList: r.Remove.BlockedNetworkRuleList, + } + } return opts } @@ -126,11 +143,13 @@ func (r *ShowNetworkPolicyRequest) toOpts() *ShowNetworkPolicyOptions { func (r showNetworkPolicyDBRow) convert() *NetworkPolicy { return &NetworkPolicy{ - CreatedOn: r.CreatedOn, - Name: r.Name, - Comment: r.Comment, - EntriesInAllowedIpList: r.EntriesInAllowedIpList, - EntriesInBlockedIpList: r.EntriesInBlockedIpList, + CreatedOn: r.CreatedOn, + Name: r.Name, + Comment: r.Comment, + EntriesInAllowedIpList: r.EntriesInAllowedIpList, + EntriesInBlockedIpList: r.EntriesInBlockedIpList, + EntriesInAllowedNetworkRules: r.EntriesInAllowedNetworkRules, + EntriesInBlockedNetworkRules: r.EntriesInBlockedNetworkRules, } } diff --git a/pkg/sdk/network_policies_validations_gen.go b/pkg/sdk/network_policies_validations_gen.go index 2b06d456e4..7da9f3ac31 100644 --- a/pkg/sdk/network_policies_validations_gen.go +++ b/pkg/sdk/network_policies_validations_gen.go @@ -1,7 +1,5 @@ package sdk -import "errors" - var ( _ validatable = new(CreateNetworkPolicyOptions) _ validatable = new(AlterNetworkPolicyOptions) @@ -12,63 +10,73 @@ var ( func (opts *CreateNetworkPolicyOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - return errors.Join(errs...) + return JoinErrors(errs...) } func (opts *AlterNetworkPolicyOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - if ok := exactlyOneValueSet(opts.Set, opts.UnsetComment, opts.RenameTo); !ok { - errs = append(errs, errExactlyOneOf("AlterNetworkPolicyOptions", "Set", "UnsetComment", "RenameTo")) + if !exactlyOneValueSet(opts.Set, opts.UnsetComment, opts.RenameTo, opts.Add, opts.Remove) { + errs = append(errs, errExactlyOneOf("AlterNetworkPolicyOptions", "Set", "UnsetComment", "RenameTo", "Add", "Remove")) } - if valueSet(opts.RenameTo) && !ValidObjectIdentifier(opts.RenameTo) { + if opts.RenameTo != nil && !ValidObjectIdentifier(opts.RenameTo) { errs = append(errs, ErrInvalidObjectIdentifier) } if valueSet(opts.Set) { - if ok := anyValueSet(opts.Set.AllowedIpList, opts.Set.BlockedIpList, opts.Set.Comment); !ok { - errs = append(errs, errAtLeastOneOf("AlterNetworkPolicyOptions.Set", "AllowedIpList", "BlockedIpList", "Comment")) + if !anyValueSet(opts.Set.AllowedIpList, opts.Set.BlockedIpList, opts.Set.Comment, opts.Set.AllowedNetworkRuleList, opts.Set.BlockedNetworkRuleList) { + errs = append(errs, errAtLeastOneOf("AlterNetworkPolicyOptions.Set", "AllowedIpList", "BlockedIpList", "Comment", "AllowedNetworkRuleList", "BlockedNetworkRuleList")) + } + } + if valueSet(opts.Add) { + if !exactlyOneValueSet(opts.Add.AllowedNetworkRuleList, opts.Add.BlockedNetworkRuleList) { + errs = append(errs, errExactlyOneOf("AlterNetworkPolicyOptions.Add", "AllowedNetworkRuleList", "BlockedNetworkRuleList")) + } + } + if valueSet(opts.Remove) { + if !exactlyOneValueSet(opts.Remove.AllowedNetworkRuleList, opts.Remove.BlockedNetworkRuleList) { + errs = append(errs, errExactlyOneOf("AlterNetworkPolicyOptions.Remove", "AllowedNetworkRuleList", "BlockedNetworkRuleList")) } } - return errors.Join(errs...) + return JoinErrors(errs...) } func (opts *DropNetworkPolicyOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - return errors.Join(errs...) + return JoinErrors(errs...) } func (opts *ShowNetworkPolicyOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error - return errors.Join(errs...) + return JoinErrors(errs...) } func (opts *DescribeNetworkPolicyOptions) validate() error { if opts == nil { - return errors.Join(ErrNilOptions) + return ErrNilOptions } var errs []error if !ValidObjectIdentifier(opts.name) { errs = append(errs, ErrInvalidObjectIdentifier) } - return errors.Join(errs...) + return JoinErrors(errs...) } diff --git a/pkg/sdk/testint/network_policies_gen_integration_test.go b/pkg/sdk/testint/network_policies_gen_integration_test.go index a0d8d219fa..9ac7a6d768 100644 --- a/pkg/sdk/testint/network_policies_gen_integration_test.go +++ b/pkg/sdk/testint/network_policies_gen_integration_test.go @@ -6,6 +6,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/collections" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/internal/random" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -17,6 +18,21 @@ func TestInt_NetworkPolicies(t *testing.T) { allowedIP := sdk.NewIPRequest("123.0.0.1") blockedIP := sdk.NewIPRequest("125.0.0.1") blockedIP2 := sdk.NewIPRequest("124.0.0.1") + + databaseTest, schemaTest := testDb(t), testSchema(t) + createNetworkRuleHandle := func(t *testing.T, client *sdk.Client) sdk.SchemaObjectIdentifier { + t.Helper() + + id := sdk.NewSchemaObjectIdentifier(databaseTest.Name, schemaTest.Name, random.AlphaN(4)) + err := client.NetworkRules.Create(ctx, sdk.NewCreateNetworkRuleRequest(id, sdk.NetworkRuleTypeIpv4, []sdk.NetworkRuleValue{}, sdk.NetworkRuleModeIngress)) + require.NoError(t, err) + t.Cleanup(func() { + err := client.NetworkRules.Drop(ctx, sdk.NewDropNetworkRuleRequest(id)) + require.NoError(t, err) + }) + return id + } + defaultCreateRequest := func() *sdk.CreateNetworkPolicyRequest { id := sdk.RandomAccountObjectIdentifier() comment := "some_comment" @@ -35,6 +51,11 @@ func TestInt_NetworkPolicies(t *testing.T) { t.Run("Create", func(t *testing.T) { req := defaultCreateRequest() + allowedNetworkRule := createNetworkRuleHandle(t, client) + blockedNetworkRule := createNetworkRuleHandle(t, client) + req = req.WithAllowedNetworkRuleList([]sdk.SchemaObjectIdentifier{allowedNetworkRule}) + req = req.WithBlockedNetworkRuleList([]sdk.SchemaObjectIdentifier{blockedNetworkRule}) + err, dropNetworkPolicy := createNetworkPolicy(t, client, req) require.NoError(t, err) t.Cleanup(dropNetworkPolicy) @@ -47,6 +68,8 @@ func TestInt_NetworkPolicies(t *testing.T) { assert.Equal(t, *req.Comment, np.Comment) assert.Equal(t, len(req.AllowedIpList), np.EntriesInAllowedIpList) assert.Equal(t, len(req.BlockedIpList), np.EntriesInBlockedIpList) + assert.Equal(t, len(req.AllowedNetworkRuleList), np.EntriesInAllowedNetworkRules) + assert.Equal(t, len(req.BlockedNetworkRuleList), np.EntriesInBlockedNetworkRules) }) t.Run("Alter - set allowed ip list", func(t *testing.T) { @@ -85,6 +108,90 @@ func TestInt_NetworkPolicies(t *testing.T) { assert.Equal(t, 1, np.EntriesInBlockedIpList) }) + t.Run("Alter - set allowed network rule list", func(t *testing.T) { + allowedNetworkRule := createNetworkRuleHandle(t, client) + + req := defaultCreateRequest() + err, dropNetworkPolicy := createNetworkPolicy(t, client, req) + require.NoError(t, err) + t.Cleanup(dropNetworkPolicy) + + err = client.NetworkPolicies.Alter(ctx, sdk.NewAlterNetworkPolicyRequest(req.GetName()). + WithSet(sdk.NewNetworkPolicySetRequest().WithAllowedNetworkRuleList([]sdk.SchemaObjectIdentifier{allowedNetworkRule}))) + require.NoError(t, err) + + np, err := client.NetworkPolicies.ShowByID(ctx, req.GetName()) + require.NoError(t, err) + assert.Equal(t, 1, np.EntriesInAllowedNetworkRules) + }) + + t.Run("Alter - set blocked network rule list", func(t *testing.T) { + blockedNetworkRule := createNetworkRuleHandle(t, client) + + req := defaultCreateRequest() + err, dropNetworkPolicy := createNetworkPolicy(t, client, req) + require.NoError(t, err) + t.Cleanup(dropNetworkPolicy) + + err = client.NetworkPolicies.Alter(ctx, sdk.NewAlterNetworkPolicyRequest(req.GetName()). + WithSet(sdk.NewNetworkPolicySetRequest().WithBlockedNetworkRuleList([]sdk.SchemaObjectIdentifier{blockedNetworkRule}))) + require.NoError(t, err) + + np, err := client.NetworkPolicies.ShowByID(ctx, req.GetName()) + require.NoError(t, err) + assert.Equal(t, 1, np.EntriesInBlockedNetworkRules) + }) + + t.Run("Alter - add and remove allowed network rule list", func(t *testing.T) { + allowedNetworkRule := createNetworkRuleHandle(t, client) + + req := defaultCreateRequest() + err, dropNetworkPolicy := createNetworkPolicy(t, client, req) + require.NoError(t, err) + t.Cleanup(dropNetworkPolicy) + + err = client.NetworkPolicies.Alter(ctx, sdk.NewAlterNetworkPolicyRequest(req.GetName()). + WithAdd(sdk.NewAddNetworkRuleRequest().WithAllowedNetworkRuleList([]sdk.SchemaObjectIdentifier{allowedNetworkRule}))) + require.NoError(t, err) + + np, err := client.NetworkPolicies.ShowByID(ctx, req.GetName()) + require.NoError(t, err) + assert.Equal(t, 1, np.EntriesInAllowedNetworkRules) + + err = client.NetworkPolicies.Alter(ctx, sdk.NewAlterNetworkPolicyRequest(req.GetName()). + WithRemove(sdk.NewRemoveNetworkRuleRequest().WithAllowedNetworkRuleList([]sdk.SchemaObjectIdentifier{allowedNetworkRule}))) + require.NoError(t, err) + + np, err = client.NetworkPolicies.ShowByID(ctx, req.GetName()) + require.NoError(t, err) + assert.Equal(t, 0, np.EntriesInAllowedNetworkRules) + }) + + t.Run("Alter - add and remove blocked network rule list", func(t *testing.T) { + blockedNetworkRule := createNetworkRuleHandle(t, client) + + req := defaultCreateRequest() + err, dropNetworkPolicy := createNetworkPolicy(t, client, req) + require.NoError(t, err) + t.Cleanup(dropNetworkPolicy) + + err = client.NetworkPolicies.Alter(ctx, sdk.NewAlterNetworkPolicyRequest(req.GetName()). + WithAdd(sdk.NewAddNetworkRuleRequest().WithBlockedNetworkRuleList([]sdk.SchemaObjectIdentifier{blockedNetworkRule}))) + require.NoError(t, err) + + np, err := client.NetworkPolicies.ShowByID(ctx, req.GetName()) + require.NoError(t, err) + assert.Equal(t, 1, np.EntriesInBlockedNetworkRules) + + err = client.NetworkPolicies.Alter(ctx, sdk.NewAlterNetworkPolicyRequest(req.GetName()). + WithRemove(sdk.NewRemoveNetworkRuleRequest().WithBlockedNetworkRuleList([]sdk.SchemaObjectIdentifier{blockedNetworkRule}))) + require.NoError(t, err) + + np, err = client.NetworkPolicies.ShowByID(ctx, req.GetName()) + require.NoError(t, err) + assert.Equal(t, 0, np.EntriesInAllowedNetworkRules) + }) + t.Run("Alter - set comment", func(t *testing.T) { req := defaultCreateRequest() err, dropNetworkPolicy := createNetworkPolicy(t, client, req)