diff --git a/pkg/packetfilter/nftables/nftables.go b/pkg/packetfilter/nftables/nftables.go index e96a02ff7..de64ee145 100644 --- a/pkg/packetfilter/nftables/nftables.go +++ b/pkg/packetfilter/nftables/nftables.go @@ -85,9 +85,13 @@ func New() (packetfilter.Driver, error) { return nil, errors.Wrap(err, "error creating knftables") } + return NewWithNft(nft), nil +} + +func NewWithNft(nft knftables.Interface) packetfilter.Driver { return &packetFilter{ nftables: nft, - }, nil + } } func (p *packetFilter) ChainExists(_ packetfilter.TableType, chain string) (bool, error) { diff --git a/pkg/packetfilter/nftables/nftables_test.go b/pkg/packetfilter/nftables/nftables_test.go index a0dea10a6..dcb986792 100644 --- a/pkg/packetfilter/nftables/nftables_test.go +++ b/pkg/packetfilter/nftables/nftables_test.go @@ -19,10 +19,13 @@ limitations under the License. package nftables_test import ( + "context" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/submariner-io/submariner/pkg/packetfilter" "github.com/submariner-io/submariner/pkg/packetfilter/nftables" + "sigs.k8s.io/knftables" ) var _ = Describe("Rule conversion", func() { @@ -97,6 +100,245 @@ var _ = Describe("Rule conversion", func() { }) }) +var _ = Describe("Interface", func() { + const ( + chainName = "egress" + setName = "my-set" + ) + + var ( + fakeKnftables *fakeKnftablesWrapper + pf packetfilter.Driver + + setInfo = &packetfilter.SetInfo{ + Name: setName, + Table: packetfilter.TableTypeNAT, + Family: packetfilter.SetFamilyV4, + } + ) + + BeforeEach(func() { + fakeKnftables = &fakeKnftablesWrapper{knftables.NewFake(knftables.IPv4Family, "submariner")} + pf = nftables.NewWithNft(fakeKnftables) + }) + + assertRules := func(r ...*packetfilter.Rule) { + rules, err := pf.List(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + + if len(r) == 0 { + Expect(rules).To(BeEmpty()) + } else { + Expect(rules).To(Equal(r)) + } + } + + assertSets := func(s ...string) { + sets, err := fakeKnftables.List(context.TODO(), "set") + Expect(err).To(Succeed()) + + if len(s) == 0 { + Expect(sets).To(BeEmpty()) + } else { + Expect(sets).To(Equal(s)) + } + } + + assertEntries := func(set packetfilter.NamedSet, e ...string) { + entries, err := set.ListEntries() + Expect(err).To(Succeed()) + + if len(e) == 0 { + Expect(entries).To(BeEmpty()) + } else { + Expect(entries).To(Equal(e)) + } + } + + Specify("Creating and deleting a chain", func() { + err := pf.CreateChainIfNotExists(packetfilter.TableTypeNAT, &packetfilter.Chain{ + Name: chainName, + }) + Expect(err).To(Succeed()) + + exists, err := pf.ChainExists(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + Expect(exists).To(BeTrue()) + + // Already exists - should succeed. + err = pf.CreateChainIfNotExists(packetfilter.TableTypeNAT, &packetfilter.Chain{ + Name: chainName, + }) + Expect(err).To(Succeed()) + + err = pf.DeleteChain(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + + exists, err = pf.ChainExists(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + Expect(exists).To(BeFalse()) + + // After deletion, these should be a no-op. + err = pf.DeleteChain(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + + err = pf.ClearChain(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + }) + + Specify("Creating and deleting an IP hook chain", func() { + chainIPHook := &packetfilter.ChainIPHook{ + Name: chainName, + Type: packetfilter.ChainTypeNAT, + Hook: packetfilter.ChainHookPrerouting, + Priority: packetfilter.ChainPriorityFirst, + } + + err := pf.CreateIPHookChainIfNotExists(chainIPHook) + Expect(err).To(Succeed()) + + exists, err := pf.ChainExists(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + Expect(exists).To(BeTrue()) + + // Already exists - should succeed. + err = pf.CreateIPHookChainIfNotExists(chainIPHook) + Expect(err).To(Succeed()) + + err = pf.DeleteIPHookChain(chainIPHook) + Expect(err).To(Succeed()) + + exists, err = pf.ChainExists(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + Expect(exists).To(BeFalse()) + + // After deletion, these should be a no-op. + err = pf.DeleteIPHookChain(chainIPHook) + Expect(err).To(Succeed()) + + err = pf.ClearChain(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + }) + + Specify("Adding and deleting rules", func() { + err := pf.CreateChainIfNotExists(packetfilter.TableTypeNAT, &packetfilter.Chain{ + Name: chainName, + }) + Expect(err).To(Succeed()) + + By("Append the first rule") + + rule1 := &packetfilter.Rule{ + Proto: packetfilter.RuleProtoICMP, + DestCIDR: "171.254.1.0/24", + DnatCIDR: "172.254.1.0/24", + Action: packetfilter.RuleActionDNAT, + } + + err = pf.Append(packetfilter.TableTypeNAT, chainName, rule1) + Expect(err).To(Succeed()) + + assertRules(rule1) + + By("Prepend the second rule") + + rule2 := &packetfilter.Rule{ + Proto: packetfilter.RuleProtoUDP, + DestCIDR: "170.254.1.0/24", + SrcCIDR: "171.254.1.0/24", + DPort: "d-port", + Action: packetfilter.RuleActionAccept, + } + + err = pf.Insert(packetfilter.TableTypeNAT, chainName, 1, rule2) + Expect(err).To(Succeed()) + + assertRules(rule2, rule1) + + By("Insert the third rule") + + rule3 := &packetfilter.Rule{ + Proto: packetfilter.RuleProtoUDP, + DestCIDR: "170.254.1.0/24", + SrcCIDR: "171.254.1.0/24", + DPort: "d-port", + Action: packetfilter.RuleActionAccept, + } + + err = pf.Insert(packetfilter.TableTypeNAT, chainName, 2, rule3) + Expect(err).To(Succeed()) + + assertRules(rule2, rule3, rule1) + + By("Delete the first rule") + + err = pf.Delete(packetfilter.TableTypeNAT, chainName, rule1) + Expect(err).To(Succeed()) + + assertRules(rule2, rule3) + + // Try to delete again - should succeed. + err = pf.Delete(packetfilter.TableTypeNAT, chainName, rule1) + Expect(err).To(Succeed()) + + By("Clear the chain") + + err = pf.ClearChain(packetfilter.TableTypeNAT, chainName) + Expect(err).To(Succeed()) + + assertRules() + }) + + Specify("Creating and deleting sets", func() { + set := pf.NewNamedSet(setInfo) + + err := set.Create(true) + Expect(err).To(Succeed()) + + assertSets(set.Name()) + + err = set.Destroy() + Expect(err).To(Succeed()) + + assertSets() + + err = set.Create(true) + Expect(err).To(Succeed()) + + assertSets(set.Name()) + + err = pf.DestroySets(func(s string) bool { + return s == setName + }) + Expect(err).To(Succeed()) + + assertSets() + }) + + Specify("Adding and deleting entries from a set", func() { + set := pf.NewNamedSet(setInfo) + + err := set.Create(true) + Expect(err).To(Succeed()) + + err = set.AddEntry("entry1", false) + Expect(err).To(Succeed()) + assertEntries(set, "entry1") + + err = set.AddEntry("entry2", false) + Expect(err).To(Succeed()) + assertEntries(set, "entry1", "entry2") + + err = set.DelEntry("entry1") + Expect(err).To(Succeed()) + assertEntries(set, "entry2") + + err = set.Flush() + Expect(err).To(Succeed()) + assertEntries(set) + }) +}) + func testRuleConversion(rule *packetfilter.Rule) { spec := nftables.ToRuleSpec(rule) parsed := nftables.FromRuleSpec(spec) @@ -108,3 +350,23 @@ func testRuleConversion(rule *packetfilter.Rule) { Expect(parsed).To(Equal(rule)) } + +type fakeKnftablesWrapper struct { + *knftables.Fake +} + +func (f *fakeKnftablesWrapper) ListRules(ctx context.Context, chain string) ([]*knftables.Rule, error) { + rules, err := f.Fake.ListRules(ctx, chain) + + // The docs for ListRules interface says "the Rule objects will have their Comment and Handle fields filled in, + // but not the actual Rule field.". However, the fake implementation doesn't honor this so clear out the Rule field. + newRules := make([]*knftables.Rule, len(rules)) + + for i := range rules { + nr := *rules[i] + nr.Rule = "" + newRules[i] = &nr + } + + return newRules, err +}