Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dual stack support. #1584

Merged
merged 8 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ linters-settings:
threshold: 150
funlen:
Lines: 120
Statements: 50
Statements: 60
goconst:
min-len: 2
min-occurrences: 2
Expand Down
16 changes: 4 additions & 12 deletions pkg/networkservice/connectioncontext/ipcontext/vl3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type vl3Client struct {
// NewClient - returns a new vL3 client instance that manages connection.context.ipcontext for vL3 scenario.
//
// Produces refresh on prefix update.
// Requires begin and metdata chain elements.
// Requires begin and metadata chain elements.
func NewClient(chainContext context.Context, pool *IPAM) networkservice.NetworkServiceClient {
if chainContext == nil {
panic("chainContext can not be nil")
Expand Down Expand Up @@ -97,17 +97,9 @@ func (n *vl3Client) Request(ctx context.Context, request *networkservice.Network

var address, prefix = n.pool.selfAddress().String(), n.pool.selfPrefix().String()

conn.GetContext().GetIpContext().SrcIpAddrs = []string{address}
conn.GetContext().GetIpContext().DstRoutes = []*networkservice.Route{
{
Prefix: address,
NextHop: n.pool.selfAddress().IP.String(),
},
{
Prefix: prefix,
NextHop: n.pool.selfAddress().IP.String(),
},
}
addAddr(&conn.GetContext().GetIpContext().SrcIpAddrs, address)
addRoute(&conn.GetContext().GetIpContext().DstRoutes, address, n.pool.selfAddress().IP.String())
addRoute(&conn.GetContext().GetIpContext().DstRoutes, prefix, n.pool.selfAddress().IP.String())

return next.Client(ctx).Request(ctx, request, opts...)
}
Expand Down
93 changes: 90 additions & 3 deletions pkg/networkservice/connectioncontext/ipcontext/vl3/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ import (
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/networkservicemesh/sdk/pkg/ipam/strictvl3ipam"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/common/excludedprefixes"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/ipcontext/vl3"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/adapters"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/chain"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
)
Expand Down Expand Up @@ -133,6 +135,91 @@ func Test_VL3NSE_ConnectsToVl3NSE(t *testing.T) {
require.Equal(t, "10.0.1.0/24", resp.GetContext().GetIpContext().GetDstRoutes()[1].GetPrefix())
}

func Test_VL3NSE_ConnectsToVl3NSE_DualStack(t *testing.T) {
t.Cleanup(func() {
goleak.VerifyNone(t)
})

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

var ipams []*vl3.IPAM
ipam1 := vl3.NewIPAM("10.0.0.1/24")
ipams = append(ipams, ipam1)
ipam2 := vl3.NewIPAM("2001:db8::/112")
ipams = append(ipams, ipam2)

var clients []networkservice.NetworkServiceClient
for _, ipam := range ipams {
clients = append(clients, vl3.NewClient(ctx, ipam))
}

var server = next.NewNetworkServiceServer(
adapters.NewClientToServer(
next.NewNetworkServiceClient(
begin.NewClient(),
metadata.NewClient(),
chain.NewNetworkServiceClient(clients...),
),
),
metadata.NewServer(),
strictvl3ipam.NewServer(ctx, vl3.NewServer, ipams...),
)

resp, err := server.Request(ctx, &networkservice.NetworkServiceRequest{Connection: &networkservice.Connection{Id: t.Name()}})

require.NoError(t, err)

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetSrcIpAddrs()[0])
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetSrcIpAddrs()[1])
require.Equal(t, "10.0.0.1/32", resp.GetContext().GetIpContext().GetSrcIpAddrs()[2])
require.Equal(t, "2001:db8::1/128", resp.GetContext().GetIpContext().GetSrcIpAddrs()[3])

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetDstIpAddrs()[0])
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetDstIpAddrs()[1])

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetSrcRoutes()[0].GetPrefix())
require.Equal(t, "10.0.0.0/24", resp.GetContext().GetIpContext().GetSrcRoutes()[1].GetPrefix())
require.Equal(t, "10.0.0.0/16", resp.GetContext().GetIpContext().GetSrcRoutes()[5].GetPrefix())
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetSrcRoutes()[2].GetPrefix())
require.Equal(t, "2001:db8::/112", resp.GetContext().GetIpContext().GetSrcRoutes()[3].GetPrefix())
require.Equal(t, "2001:db8::/64", resp.GetContext().GetIpContext().GetSrcRoutes()[4].GetPrefix())

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetDstRoutes()[0].GetPrefix())
require.Equal(t, "10.0.0.0/24", resp.GetContext().GetIpContext().GetDstRoutes()[1].GetPrefix())
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetDstRoutes()[2].GetPrefix())
require.Equal(t, "2001:db8::/112", resp.GetContext().GetIpContext().GetDstRoutes()[3].GetPrefix())
require.Equal(t, "10.0.0.1/32", resp.GetContext().GetIpContext().GetDstRoutes()[4].GetPrefix())
require.Equal(t, "2001:db8::1/128", resp.GetContext().GetIpContext().GetDstRoutes()[5].GetPrefix())

// refresh
resp, err = server.Request(ctx, &networkservice.NetworkServiceRequest{Connection: resp})

require.NoError(t, err)

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetSrcIpAddrs()[0])
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetSrcIpAddrs()[1])
require.Equal(t, "10.0.0.1/32", resp.GetContext().GetIpContext().GetSrcIpAddrs()[2])
require.Equal(t, "2001:db8::1/128", resp.GetContext().GetIpContext().GetSrcIpAddrs()[3])

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetDstIpAddrs()[0])
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetDstIpAddrs()[1])

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetSrcRoutes()[0].GetPrefix())
require.Equal(t, "10.0.0.0/24", resp.GetContext().GetIpContext().GetSrcRoutes()[1].GetPrefix())
require.Equal(t, "10.0.0.0/16", resp.GetContext().GetIpContext().GetSrcRoutes()[5].GetPrefix())
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetSrcRoutes()[2].GetPrefix())
require.Equal(t, "2001:db8::/112", resp.GetContext().GetIpContext().GetSrcRoutes()[3].GetPrefix())
require.Equal(t, "2001:db8::/64", resp.GetContext().GetIpContext().GetSrcRoutes()[4].GetPrefix())

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetDstRoutes()[0].GetPrefix())
require.Equal(t, "10.0.0.0/24", resp.GetContext().GetIpContext().GetDstRoutes()[1].GetPrefix())
require.Equal(t, "2001:db8::/128", resp.GetContext().GetIpContext().GetDstRoutes()[2].GetPrefix())
require.Equal(t, "2001:db8::/112", resp.GetContext().GetIpContext().GetDstRoutes()[3].GetPrefix())
require.Equal(t, "10.0.0.1/32", resp.GetContext().GetIpContext().GetDstRoutes()[4].GetPrefix())
require.Equal(t, "2001:db8::1/128", resp.GetContext().GetIpContext().GetDstRoutes()[5].GetPrefix())
}

func Test_VL3NSE_ConnectsToVl3NSE_ChangePrefix(t *testing.T) {
t.Cleanup(func() {
goleak.VerifyNone(t)
Expand Down Expand Up @@ -178,14 +265,14 @@ func Test_VL3NSE_ConnectsToVl3NSE_ChangePrefix(t *testing.T) {

require.NoError(t, err)

require.Equal(t, "10.0.5.0/32", resp.GetContext().GetIpContext().GetSrcIpAddrs()[0])
require.Equal(t, "10.0.5.0/32", resp.GetContext().GetIpContext().GetSrcIpAddrs()[2])
require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetDstIpAddrs()[0])

require.Equal(t, "10.0.0.0/32", resp.GetContext().GetIpContext().GetSrcRoutes()[0].GetPrefix())
require.Equal(t, "10.0.0.0/24", resp.GetContext().GetIpContext().GetSrcRoutes()[1].GetPrefix())
require.Equal(t, "10.0.0.0/16", resp.GetContext().GetIpContext().GetSrcRoutes()[2].GetPrefix())
require.Equal(t, "10.0.5.0/32", resp.GetContext().GetIpContext().GetDstRoutes()[0].GetPrefix())
require.Equal(t, "10.0.5.0/24", resp.GetContext().GetIpContext().GetDstRoutes()[1].GetPrefix())
require.Equal(t, "10.0.5.0/32", resp.GetContext().GetIpContext().GetDstRoutes()[3].GetPrefix())
require.Equal(t, "10.0.5.0/24", resp.GetContext().GetIpContext().GetDstRoutes()[4].GetPrefix())
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/networkservice/connectioncontext/ipcontext/vl3/ipam.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (p *IPAM) isExcluded(ipNet string) bool {
}

// Reset resets IPAM's ippol by setting new prefix
func (p *IPAM) Reset(prefix string, excludePrefies ...string) error {
func (p *IPAM) Reset(prefix string, excludePrefixes ...string) error {
p.Lock()
defer p.Unlock()

Expand All @@ -172,7 +172,7 @@ func (p *IPAM) Reset(prefix string, excludePrefies ...string) error {
p.excludedPrefixes[selfAddress.String()] = struct{}{}
p.ipPool.Exclude(selfAddress)

for _, excludePrefix := range excludePrefies {
for _, excludePrefix := range excludePrefixes {
p.ipPool.ExcludeString(excludePrefix)
p.excludedPrefixes[excludePrefix] = struct{}{}
}
Expand Down
15 changes: 0 additions & 15 deletions pkg/networkservice/connectioncontext/ipcontext/vl3/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,6 @@ import (
"github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata"
)

type addressKey struct{}

func loadAddress(ctx context.Context) (string, bool) {
v, ok := metadata.Map(ctx, false).Load(addressKey{})
if ok {
return v.(string), true
}

return "", false
}

func storeAddress(ctx context.Context, address string) {
metadata.Map(ctx, false).Store(addressKey{}, address)
}

type cancelKey struct{}

func storeCancel(ctx context.Context, cancel context.CancelFunc) {
Expand Down
80 changes: 61 additions & 19 deletions pkg/networkservice/connectioncontext/ipcontext/vl3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@ import (
"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/edwarnicke/genericsync"

"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/tools/ippool"
"github.com/networkservicemesh/sdk/pkg/tools/log"
)

type vl3Server struct {
pool *IPAM
pool *IPAM
subnetMap genericsync.Map[string, string] // Map connectionId:subnet
}

// NewServer - returns a new vL3 server instance that manages connection.context.ipcontext for vL3 scenario.
Expand All @@ -54,26 +59,28 @@ func (v *vl3Server) Request(ctx context.Context, request *networkservice.Network
conn.GetContext().IpContext = new(networkservice.IPContext)
}

var ipContext = &networkservice.IPContext{
SrcIpAddrs: request.GetConnection().Context.GetIpContext().GetSrcIpAddrs(),
DstRoutes: request.GetConnection().Context.GetIpContext().GetDstRoutes(),
ExcludedPrefixes: request.GetConnection().Context.GetIpContext().GetExcludedPrefixes(),
}

shouldAllocate := len(ipContext.SrcIpAddrs) == 0

if prevAddress, ok := loadAddress(ctx); ok && !shouldAllocate {
shouldAllocate = !v.pool.isExcluded(prevAddress)
}

if shouldAllocate {
var ipContext = conn.GetContext().GetIpContext()

if prevAddress, ok := v.subnetMap.Load(conn.GetId()); ok {
// Remove previous prefix from IP Context if a current server prefix has changed
if v.pool.globalIPNet().String() != prevAddress {
srcNet, err := v.pool.allocate()
log.FromContext(ctx).Infof("Server Request. Allocated net: %+v for connection: %+v", srcNet.String(), conn.GetId())
if err != nil {
return nil, err
}
removePreviousPrefixFromIPContext(ipContext, prevAddress)
ipContext.SrcIpAddrs = append(ipContext.SrcIpAddrs, srcNet.String())
v.subnetMap.Store(conn.GetId(), v.pool.globalIPNet().String())
}
} else {
srcNet, err := v.pool.allocate()
log.FromContext(ctx).Infof("Server Request. Allocated initial net: %+v for connection: %+v", srcNet.String(), conn.GetId())
if err != nil {
return nil, err
}
ipContext.DstRoutes = nil
ipContext.SrcIpAddrs = append([]string(nil), srcNet.String())
storeAddress(ctx, srcNet.String())
ipContext.SrcIpAddrs = append(ipContext.SrcIpAddrs, srcNet.String())
v.subnetMap.Store(conn.GetId(), v.pool.globalIPNet().String())
}

addRoute(&ipContext.SrcRoutes, v.pool.selfAddress().String(), v.pool.selfAddress().IP.String())
Expand All @@ -83,8 +90,6 @@ func (v *vl3Server) Request(ctx context.Context, request *networkservice.Network
}
addAddr(&ipContext.DstIpAddrs, v.pool.selfAddress().String())

conn.GetContext().IpContext = ipContext

resp, err := next.Server(ctx).Request(ctx, request)
if err == nil {
addRoute(&resp.GetContext().GetIpContext().SrcRoutes, v.pool.globalIPNet().String(), v.pool.selfAddress().IP.String())
Expand All @@ -96,6 +101,7 @@ func (v *vl3Server) Close(ctx context.Context, conn *networkservice.Connection)
for _, srcAddr := range conn.GetContext().GetIpContext().GetSrcIpAddrs() {
v.pool.freeIfAllocated(srcAddr)
}
v.subnetMap.Delete(conn.GetId())
return next.Server(ctx).Close(ctx, conn)
}

Expand All @@ -119,3 +125,39 @@ func addAddr(addrs *[]string, addr string) {
}
*addrs = append(*addrs, addr)
}

func removePreviousPrefixFromIPContext(ipContext *networkservice.IPContext, prevAddress string) {
prevIPPool := ippool.NewWithNetString(prevAddress)

var srcIPAddrs []string
for _, ip := range ipContext.SrcIpAddrs {
if !prevIPPool.ContainsNetString(ip) {
srcIPAddrs = append(srcIPAddrs, ip)
}
}
ipContext.SrcIpAddrs = srcIPAddrs

var dstIPAddrs []string
for _, ip := range ipContext.DstIpAddrs {
if !prevIPPool.ContainsNetString(ip) {
dstIPAddrs = append(dstIPAddrs, ip)
}
}
ipContext.DstIpAddrs = dstIPAddrs

var srcRoutes []*networkservice.Route
for _, r := range ipContext.SrcRoutes {
if !prevIPPool.ContainsNetString(r.Prefix) {
srcRoutes = append(srcRoutes, r)
}
}
ipContext.SrcRoutes = srcRoutes

var dstRoutes []*networkservice.Route
for _, r := range ipContext.DstRoutes {
if !prevIPPool.ContainsNetString(r.Prefix) {
dstRoutes = append(dstRoutes, r)
}
}
ipContext.DstRoutes = dstRoutes
}
50 changes: 48 additions & 2 deletions pkg/networkservice/connectioncontext/ipcontext/vl3/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/networkservicemesh/sdk/pkg/ipam/strictvl3ipam"
"github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/ipcontext/vl3"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"

Expand Down Expand Up @@ -125,7 +126,11 @@ func Test_NSC_ConnectsToVl3NSE_Close(t *testing.T) {
)

for i := 0; i < 10; i++ {
resp, err := server.Request(context.Background(), new(networkservice.NetworkServiceRequest))
resp, err := server.Request(context.Background(), &networkservice.NetworkServiceRequest{
Connection: &networkservice.Connection{
Id: "1",
},
})

require.NoError(t, err)

Expand All @@ -137,7 +142,11 @@ func Test_NSC_ConnectsToVl3NSE_Close(t *testing.T) {
require.Equal(t, "10.0.0.0/16", resp.GetContext().GetIpContext().GetSrcRoutes()[2].GetPrefix(), i)
require.Equal(t, "10.0.0.1/32", resp.GetContext().GetIpContext().GetDstRoutes()[0].GetPrefix(), i)

resp1, err1 := server.Request(context.Background(), new(networkservice.NetworkServiceRequest))
resp1, err1 := server.Request(context.Background(), &networkservice.NetworkServiceRequest{
Connection: &networkservice.Connection{
Id: "2",
},
})

require.NoError(t, err1)

Expand All @@ -155,3 +164,40 @@ func Test_NSC_ConnectsToVl3NSE_Close(t *testing.T) {
require.NoError(t, err, i)
}
}

func Test_NSC_ConnectsToVl3NSE_DualStack(t *testing.T) {
t.Cleanup(func() {
goleak.VerifyNone(t)
})

var ipams []*vl3.IPAM
ipam1 := vl3.NewIPAM("10.0.0.1/24")
ipams = append(ipams, ipam1)
ipam2 := vl3.NewIPAM("2001:db8::/112")
ipams = append(ipams, ipam2)

var server = next.NewNetworkServiceServer(
metadata.NewServer(),
strictvl3ipam.NewServer(context.Background(), vl3.NewServer, ipams...),
)

resp, err := server.Request(context.Background(), new(networkservice.NetworkServiceRequest))
require.NoError(t, err)
ipContext := resp.GetContext().GetIpContext()

require.Equal(t, "10.0.0.1/32", ipContext.GetSrcIpAddrs()[0])
require.Equal(t, "2001:db8::1/128", ipContext.GetSrcIpAddrs()[1])

require.Equal(t, "10.0.0.0/32", ipContext.GetDstIpAddrs()[0])
require.Equal(t, "2001:db8::/128", ipContext.GetDstIpAddrs()[1])

require.Equal(t, "10.0.0.0/32", ipContext.GetSrcRoutes()[0].GetPrefix())
require.Equal(t, "10.0.0.0/24", ipContext.GetSrcRoutes()[1].GetPrefix())
require.Equal(t, "10.0.0.0/16", ipContext.GetSrcRoutes()[5].GetPrefix())
require.Equal(t, "2001:db8::/128", ipContext.GetSrcRoutes()[2].GetPrefix())
require.Equal(t, "2001:db8::/112", ipContext.GetSrcRoutes()[3].GetPrefix())
require.Equal(t, "2001:db8::/64", ipContext.GetSrcRoutes()[4].GetPrefix())

require.Equal(t, "10.0.0.1/32", ipContext.GetDstRoutes()[0].GetPrefix())
require.Equal(t, "2001:db8::1/128", ipContext.GetDstRoutes()[1].GetPrefix())
}
Loading