Skip to content

Commit

Permalink
netmap: Use CID to prevent incorrect usage in netmap
Browse files Browse the repository at this point in the history
Signed-off-by: Evgenii Baidakov <evgenii@nspcc.io>
  • Loading branch information
smallhive committed Apr 19, 2023
1 parent c6bda42 commit 079fd8a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 15 deletions.
26 changes: 21 additions & 5 deletions netmap/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path/filepath"
"testing"

cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -62,7 +63,10 @@ func TestPlacementPolicy_Interopability(t *testing.T) {

for name, tt := range tc.Tests {
t.Run(name, func(t *testing.T) {
v, err := nm.ContainerNodes(tt.Policy, tt.Pivot)
var pivot cid.ID
copy(pivot[:], tt.Pivot)

v, err := nm.ContainerNodes(tt.Policy, pivot)
if tt.Result == nil {
require.Error(t, err)
require.Contains(t, err.Error(), tt.Error)
Expand All @@ -73,7 +77,10 @@ func TestPlacementPolicy_Interopability(t *testing.T) {
compareNodes(t, tt.Result, tc.Nodes, v)

if tt.Placement.Result != nil {
res, err := nm.PlacementVectors(v, tt.Placement.Pivot)
var placementPivot cid.ID
copy(placementPivot[:], tt.Placement.Pivot)

res, err := nm.PlacementVectors(v, placementPivot)
require.NoError(t, err)
compareNodes(t, tt.Placement.Result, tc.Nodes, res)
require.Equal(t, srcNodes, tc.Nodes)
Expand Down Expand Up @@ -108,11 +115,14 @@ func BenchmarkPlacementPolicyInteropability(b *testing.B) {

for name, tt := range tc.Tests {
b.Run(name, func(b *testing.B) {
var pivot cid.ID
copy(pivot[:], tt.Pivot)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StartTimer()
v, err := nm.ContainerNodes(tt.Policy, tt.Pivot)
v, err := nm.ContainerNodes(tt.Policy, pivot)
b.StopTimer()
if tt.Result == nil {
require.Error(b, err)
Expand All @@ -123,8 +133,11 @@ func BenchmarkPlacementPolicyInteropability(b *testing.B) {
compareNodes(b, tt.Result, tc.Nodes, v)

if tt.Placement.Result != nil {
var placementPivot cid.ID
copy(placementPivot[:], tt.Placement.Pivot)

b.StartTimer()
res, err := nm.PlacementVectors(v, tt.Placement.Pivot)
res, err := nm.PlacementVectors(v, placementPivot)
b.StopTimer()
require.NoError(b, err)
compareNodes(b, tt.Placement.Result, tc.Nodes, res)
Expand All @@ -150,11 +163,14 @@ func BenchmarkManySelects(b *testing.B) {
var nm NetMap
nm.SetNodes(tc.Nodes)

var pivot cid.ID
copy(pivot[:], tt.Pivot)

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
_, err = nm.ContainerNodes(tt.Policy, tt.Pivot)
_, err = nm.ContainerNodes(tt.Policy, pivot)
if err != nil {
b.FailNow()
}
Expand Down
4 changes: 2 additions & 2 deletions netmap/json_tests/many_selects.json
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@
"policy": {"replicas":[{"count":1,"selector":"SameRU"},{"count":1,"selector":"DistinctRU"},{"count":1,"selector":"Good"},{"count":1,"selector":"Main"}],"containerBackupFactor":2,"selectors":[{"name":"SameRU","count":2,"clause":"SAME","attribute":"City","filter":"FromRU"},{"name":"DistinctRU","count":2,"clause":"DISTINCT","attribute":"City","filter":"FromRU"},{"name":"Good","count":2,"clause":"DISTINCT","attribute":"Country","filter":"Good"},{"name":"Main","count":3,"clause":"DISTINCT","attribute":"Country","filter":"*"}],"filters":[{"name":"FromRU","key":"Country","op":"EQ","value":"Russia"},{"name":"Good","key":"Rating","op":"GE","value":"4"}]},
"result": [
[0, 5, 9, 10],
[2, 6, 0, 5],
[0, 5, 2, 6],
[1, 8, 2, 5],
[3, 4, 1, 7, 0, 2]
[0, 2, 1, 7, 3, 4]
]
}
}
Expand Down
6 changes: 3 additions & 3 deletions netmap/json_tests/multiple_rep_asymmetric.json
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@
4
],
[
8,
12,
5,
10
10,
8,
12
]
]
}
Expand Down
14 changes: 11 additions & 3 deletions netmap/netmap.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package netmap

import (
"crypto/sha256"
"fmt"

"github.com/nspcc-dev/hrw"
"github.com/nspcc-dev/neofs-api-go/v2/netmap"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
)

// NetMap represents NeoFS network map. It includes information about all
Expand Down Expand Up @@ -144,7 +146,10 @@ func flattenNodes(ns []nodes) nodes {
// For example, in order to build node list to store the object, binary-encoded
// object identifier can be used as pivot. Result is deterministic for
// the fixed NetMap and parameters.
func (m NetMap) PlacementVectors(vectors [][]NodeInfo, pivot []byte) ([][]NodeInfo, error) {
func (m NetMap) PlacementVectors(vectors [][]NodeInfo, containerID cid.ID) ([][]NodeInfo, error) {
pivot := make([]byte, sha256.Size)
containerID.Encode(pivot)

h := hrw.Hash(pivot)
wf := defaultWeightFunc(m.nodes)
result := make([][]NodeInfo, len(vectors))
Expand All @@ -166,11 +171,14 @@ func (m NetMap) PlacementVectors(vectors [][]NodeInfo, pivot []byte) ([][]NodeIn
// the fixed NetMap and parameters.
//
// Result can be used in PlacementVectors.
func (m NetMap) ContainerNodes(p PlacementPolicy, pivot []byte) ([][]NodeInfo, error) {
func (m NetMap) ContainerNodes(p PlacementPolicy, containerID cid.ID) ([][]NodeInfo, error) {
c := newContext(m)
c.setPivot(pivot)
c.setCBF(p.backupFactor)

pivot := make([]byte, sha256.Size)
containerID.Encode(pivot)
c.setPivot(pivot)

if err := c.processFilters(p); err != nil {
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions netmap/selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/nspcc-dev/hrw"
"github.com/nspcc-dev/neofs-api-go/v2/netmap"
cid "github.com/nspcc-dev/neofs-sdk-go/container/id"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -129,7 +130,7 @@ func BenchmarkPolicyHRWType(b *testing.B) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := nm.ContainerNodes(p, []byte{1})
_, err := nm.ContainerNodes(p, cid.ID{1})
if err != nil {
b.Fatal()
}
Expand Down Expand Up @@ -173,7 +174,7 @@ func TestPlacementPolicy_DeterministicOrder(t *testing.T) {
nm.SetNodes(nodeList)

getIndices := func(t *testing.T) (uint64, uint64) {
v, err := nm.ContainerNodes(p, []byte{1})
v, err := nm.ContainerNodes(p, cid.ID{1})
require.NoError(t, err)

nss := make([]nodes, len(v))
Expand Down

0 comments on commit 079fd8a

Please sign in to comment.