diff --git a/bearer/bearer.go b/bearer/bearer.go index e9fafddb..4b5f8827 100644 --- a/bearer/bearer.go +++ b/bearer/bearer.go @@ -20,11 +20,9 @@ import ( // // Instances can be created using built-in var declaration. type Token struct { - targetUserSet bool - targetUser user.ID + targetUser user.ID - issuerSet bool - issuer user.ID + issuer user.ID eaclTableSet bool eaclTable eacl.Table @@ -56,19 +54,23 @@ func (b *Token) readFromV2(m acl.BearerToken, checkFieldPresence bool) error { } targetUser := body.GetOwnerID() - if b.targetUserSet = targetUser != nil; b.targetUserSet { + if targetUser != nil { err = b.targetUser.ReadFromV2(*targetUser) if err != nil { return fmt.Errorf("invalid target user: %w", err) } + } else { + b.targetUser = user.ID{} } issuer := body.GetIssuer() - if b.issuerSet = issuer != nil; b.issuerSet { + if issuer != nil { err = b.issuer.ReadFromV2(*issuer) if err != nil { return fmt.Errorf("invalid issuer: %w", err) } + } else { + b.issuer = user.ID{} } lifetime := body.GetLifetime() @@ -98,7 +100,7 @@ func (b *Token) ReadFromV2(m acl.BearerToken) error { } func (b Token) fillBody() *acl.BearerTokenBody { - if !b.eaclTableSet && !b.targetUserSet && !b.lifetimeSet && !b.issuerSet { + if !b.eaclTableSet && b.targetUser.IsZero() && !b.lifetimeSet && b.issuer.IsZero() { return nil } @@ -108,14 +110,14 @@ func (b Token) fillBody() *acl.BearerTokenBody { body.SetEACL(b.eaclTable.ToV2()) } - if b.targetUserSet { + if !b.targetUser.IsZero() { var targetUser refs.OwnerID b.targetUser.WriteToV2(&targetUser) body.SetOwnerID(&targetUser) } - if b.issuerSet { + if !b.issuer.IsZero() { var issuer refs.OwnerID b.issuer.WriteToV2(&issuer) @@ -240,8 +242,8 @@ func (b Token) AssertContainer(cnr cid.ID) bool { return true } - cnrTable, set := b.eaclTable.CID() - return !set || cnrTable == cnr + cnrTable := b.eaclTable.GetCID() + return cnrTable.IsZero() || cnrTable == cnr } // ForUser specifies ID of the user who can use the Token for the operations @@ -252,7 +254,6 @@ func (b Token) AssertContainer(cnr cid.ID) bool { // See also AssertUser. func (b *Token) ForUser(id user.ID) { b.targetUser = id - b.targetUserSet = true } // AssertUser checks if the Token is issued to the given user. @@ -261,7 +262,7 @@ func (b *Token) ForUser(id user.ID) { // // See also ForUser. func (b Token) AssertUser(id user.ID) bool { - return !b.targetUserSet || b.targetUser == id + return b.targetUser.IsZero() || b.targetUser == id } // Sign calculates and writes signature of the [Token] data along with issuer ID @@ -403,7 +404,6 @@ func (b Token) SigningKeyBytes() []byte { // // See also [Token.Issuer], [Token.Sign]. func (b *Token) SetIssuer(usr user.ID) { - b.issuerSet = true b.issuer = usr } @@ -413,10 +413,7 @@ func (b *Token) SetIssuer(usr user.ID) { // // See also [Token.SetIssuer], [Token.Sign]. func (b Token) Issuer() user.ID { - if b.issuerSet { - return b.issuer - } - return user.ID{} + return b.issuer } // ResolveIssuer works like [Token.Issuer] with fallback to the public key @@ -425,7 +422,7 @@ func (b Token) Issuer() user.ID { // // See also [Token.SigningKeyBytes], [Token.Sign]. func (b Token) ResolveIssuer() user.ID { - if b.issuerSet { + if !b.issuer.IsZero() { return b.issuer } diff --git a/client/accounting.go b/client/accounting.go index da8b974b..503d3236 100644 --- a/client/accounting.go +++ b/client/accounting.go @@ -21,15 +21,13 @@ var ( type PrmBalanceGet struct { prmCommonMeta - accountSet bool - account user.ID + account user.ID } // SetAccount sets identifier of the NeoFS account for which the balance is requested. // Required parameter. func (x *PrmBalanceGet) SetAccount(id user.ID) { x.account = id - x.accountSet = true } // BalanceGet requests current balance of the NeoFS account. @@ -48,7 +46,7 @@ func (c *Client) BalanceGet(ctx context.Context, prm PrmBalanceGet) (accounting. }() switch { - case !prm.accountSet: + case prm.account.IsZero(): err = ErrMissingAccount return accounting.Decimal{}, err } diff --git a/client/container.go b/client/container.go index 49c89de5..c27d7725 100644 --- a/client/container.go +++ b/client/container.go @@ -158,7 +158,7 @@ type PrmContainerGet struct { prmCommonMeta } -// ContainerGet reads NeoFS container by ID. +// ContainerGet reads NeoFS container by ID. The ID must not be zero. // // Any errors (local or remote, including returned status codes) are returned as Go errors, // see [apistatus] package for NeoFS-specific error types. @@ -509,8 +509,7 @@ func (c *Client) ContainerSetEACL(ctx context.Context, table eacl.Table, signer return ErrMissingSigner } - _, isCIDSet := table.CID() - if !isCIDSet { + if table.GetCID().IsZero() { err = ErrMissingEACLContainer return err } diff --git a/client/container_statistic_test.go b/client/container_statistic_test.go index 476f696d..537079b4 100644 --- a/client/container_statistic_test.go +++ b/client/container_statistic_test.go @@ -362,7 +362,7 @@ func TestClientStatistic_ContainerSetEacl(t *testing.T) { c.prm.statisticCallback = collector.Collect var prm PrmContainerSetEACL - table := testEaclTable(cid.ID{}) + table := testEaclTable(cidtest.ID()) err := c.ContainerSetEACL(ctx, table, usr, prm) require.NoError(t, err) diff --git a/client/object_replicate_test.go b/client/object_replicate_test.go index a4f27da7..aa50038d 100644 --- a/client/object_replicate_test.go +++ b/client/object_replicate_test.go @@ -106,8 +106,8 @@ func (x *testReplicationServer) Replicate(_ context.Context, req *objectgrpc.Rep return &resp, nil } - id, ok := obj.ID() - if !ok { + id := obj.GetID() + if id.IsZero() { st.Code = 1024 // internal error st.Message = "missing object ID" resp.Status = &st diff --git a/container/id/id.go b/container/id/id.go index 91d70cb7..963dd520 100644 --- a/container/id/id.go +++ b/container/id/id.go @@ -2,6 +2,7 @@ package cid import ( "crypto/sha256" + "errors" "fmt" "github.com/mr-tron/base58" @@ -11,7 +12,8 @@ import ( // Size is the size of an [ID] in bytes. const Size = sha256.Size -// ID represents NeoFS container identifier. +// ID represents NeoFS container identifier. Zero ID is usually prohibited, see +// docs for details. // // ID implements built-in comparable interface. // @@ -19,6 +21,9 @@ const Size = sha256.Size // message. See ReadFromV2 / WriteToV2 methods. type ID [Size]byte +// ErrZero is an error returned on zero [ID] encounter. +var ErrZero = errors.New("zero container ID") + // NewFromMarshalledContainer returns new ID calculated from the given NeoFS // container encoded into Protocol Buffers V3 with ascending order of fields by // number. It's callers responsibility to ensure the format of b. See @@ -43,7 +48,11 @@ func DecodeString(s string) (ID, error) { // // See also WriteToV2. func (id *ID) ReadFromV2(m refs.ContainerID) error { - return id.Decode(m.GetValue()) + err := id.Decode(m.GetValue()) + if err == nil && id.IsZero() { + err = ErrZero + } + return err } // WriteToV2 writes ID to the refs.ContainerID message. @@ -139,3 +148,13 @@ func (id ID) String() string { // See also [container.Container.CalculateID], [container.Container.AssertID]. // Deprecated: use [NewFromContainerBinary]. func (id *ID) FromBinary(cnr []byte) { *id = NewFromMarshalledContainer(cnr) } + +// IsZero checks whether ID is zero. +func (id ID) IsZero() bool { + for i := range id { + if id[i] != 0 { + return false + } + } + return true +} diff --git a/container/id/id_test.go b/container/id/id_test.go index 0774e2e4..0dc29bd8 100644 --- a/container/id/id_test.go +++ b/container/id/id_test.go @@ -18,11 +18,13 @@ var validBytes = [cid.Size]byte{231, 189, 121, 7, 173, 134, 254, 165, 63, 186, 6 // corresponds to validBytes. const validString = "GbckSBPEdM2P41Gkb9cVapFYb5HmRPDTZZp9JExGnsCF" -var invalidValueTestcases = []struct { +type invalidValueTestCase struct { name string err string val []byte -}{ +} + +var invalidValueTestcases = []invalidValueTestCase{ {name: "nil value", err: "invalid length 0", val: nil}, {name: "empty value", err: "invalid length 0", val: []byte{}}, {name: "undersized value", err: "invalid length 31", val: make([]byte, 31)}, @@ -37,7 +39,9 @@ func TestID_ReadFromV2(t *testing.T) { require.EqualValues(t, validBytes, id) t.Run("invalid", func(t *testing.T) { - for _, tc := range invalidValueTestcases { + for _, tc := range append(invalidValueTestcases, invalidValueTestCase{ + name: "zero value", err: "zero container ID", val: make([]byte, cid.Size), + }) { t.Run(tc.name, func(t *testing.T) { var m refs.ContainerID m.SetValue(tc.val) @@ -166,3 +170,13 @@ func TestID_String(t *testing.T) { require.Equal(t, id.String(), id.String()) require.NotEqual(t, id.String(), cidtest.OtherID(id).String()) } + +func TestID_IsZero(t *testing.T) { + var id cid.ID + require.True(t, id.IsZero()) + for i := 0; i < cid.Size; i++ { + var id2 cid.ID + id2[i]++ + require.False(t, id2.IsZero()) + } +} diff --git a/container/size_test.go b/container/size_test.go index 1d8ed404..9cbe84a6 100644 --- a/container/size_test.go +++ b/container/size_test.go @@ -6,7 +6,6 @@ import ( v2container "github.com/nspcc-dev/neofs-api-go/v2/container" "github.com/nspcc-dev/neofs-api-go/v2/refs" "github.com/nspcc-dev/neofs-sdk-go/container" - cid "github.com/nspcc-dev/neofs-sdk-go/container/id" cidtest "github.com/nspcc-dev/neofs-sdk-go/container/id/test" "github.com/stretchr/testify/require" ) @@ -77,10 +76,8 @@ func TestSizeEstimation_ReadFromV2(t *testing.T) { require.Error(t, val.ReadFromV2(msg)) - cnrMsg.SetValue(make([]byte, cid.Size)) - - var cnr cid.ID - require.NoError(t, cnr.ReadFromV2(cnrMsg)) + cnr := cidtest.ID() + cnrMsg.SetValue(cnr[:]) msg.SetEpoch(epoch) msg.SetUsedSpace(value) diff --git a/eacl/table.go b/eacl/table.go index e56f933d..3743d1ed 100644 --- a/eacl/table.go +++ b/eacl/table.go @@ -14,7 +14,7 @@ import ( // Table is compatible with v2 acl.EACLTable message. type Table struct { version version.Version - cid *cid.ID + cid cid.ID records []Record } @@ -22,13 +22,7 @@ type Table struct { func (t Table) CopyTo(dst *Table) { ver := t.version dst.version = ver - - if t.cid != nil { - id := *t.cid - dst.cid = &id - } else { - dst.cid = nil - } + dst.cid = t.cid dst.records = make([]Record, len(t.records)) for i := range t.records { @@ -37,18 +31,17 @@ func (t Table) CopyTo(dst *Table) { } // CID returns identifier of the container that should use given access control rules. -func (t Table) CID() (cID cid.ID, isSet bool) { - if t.cid != nil { - cID = *t.cid - isSet = true - } +// Deprecated: use [Table.GetCID] instead. +func (t Table) CID() (cid.ID, bool) { return t.cid, !t.cid.IsZero() } - return -} +// GetCID returns identifier of the NeoFS container to which the eACL scope is +// limited. Zero return means the eACL may be applied to any container. +func (t Table) GetCID() cid.ID { return t.cid } -// SetCID sets identifier of the container that should use given access control rules. +// SetCID limits scope of the eACL to a referenced container. By default, if ID +// is zero, the eACL is applicable to any container. func (t *Table) SetCID(cid cid.ID) { - t.cid = &cid + t.cid = cid } // Version returns version of eACL format. @@ -87,12 +80,11 @@ func (t *Table) AddRecord(r *Record) { func (t *Table) ReadFromV2(m v2acl.Table) error { // set container id if id := m.GetContainerID(); id != nil { - if t.cid == nil { - t.cid = new(cid.ID) - } if err := t.cid.ReadFromV2(*id); err != nil { return fmt.Errorf("invalid container ID: %w", err) } + } else { + t.cid = cid.ID{} } // set version @@ -128,7 +120,7 @@ func (t *Table) ToV2() *v2acl.Table { v2 := new(v2acl.Table) var cidV2 refs.ContainerID - if t.cid != nil { + if !t.cid.IsZero() { t.cid.WriteToV2(&cidV2) v2.SetContainerID(&cidV2) } @@ -194,10 +186,6 @@ func NewTableFromV2(table *v2acl.Table) *Table { // set container id if id := table.GetContainerID(); id != nil { - if t.cid == nil { - t.cid = new(cid.ID) - } - copy(t.cid[:], id.GetValue()) } @@ -249,10 +237,7 @@ func (t *Table) UnmarshalJSON(data []byte) error { // EqualTables compares Table with each other. func EqualTables(t1, t2 Table) bool { - cID1, set1 := t1.CID() - cID2, set2 := t2.CID() - - if set1 != set2 || cID1 != cID2 || + if t1.GetCID() != t2.GetCID() || !t1.Version().Equal(t2.Version()) { return false } diff --git a/eacl/table_internal_test.go b/eacl/table_internal_test.go index 4bc97246..6093a516 100644 --- a/eacl/table_internal_test.go +++ b/eacl/table_internal_test.go @@ -54,24 +54,10 @@ func TestTable_CopyTo(t *testing.T) { t.Run("change cid", func(t *testing.T) { var dst Table table.CopyTo(&dst) - - cid1, isSet1 := table.CID() - require.True(t, isSet1) - - cid2, isSet2 := dst.CID() - require.True(t, isSet2) - - require.True(t, cid1 == cid2) + require.Equal(t, table.GetCID(), dst.GetCID()) dst.SetCID(cidtest.OtherID(id)) - - cid1, isSet1 = table.CID() - require.True(t, isSet1) - - cid2, isSet2 = dst.CID() - require.True(t, isSet2) - - require.False(t, cid1 == cid2) + require.NotEqual(t, table.GetCID(), dst.GetCID()) }) t.Run("change record", func(t *testing.T) { diff --git a/eacl/table_test.go b/eacl/table_test.go index 8447fc28..2f8640d5 100644 --- a/eacl/table_test.go +++ b/eacl/table_test.go @@ -43,9 +43,7 @@ func TestTable(t *testing.T) { id := cidtest.ID() table := eacl.CreateTable(id) - cID, set := table.CID() - require.True(t, set) - require.Equal(t, id, cID) + require.Equal(t, id, table.GetCID()) require.Equal(t, version.Current(), table.Version()) }) } @@ -100,8 +98,7 @@ func TestTable_ToV2(t *testing.T) { // check initial values require.Equal(t, version.Current(), table.Version()) require.Nil(t, table.Records()) - _, set := table.CID() - require.False(t, set) + require.Zero(t, table.GetCID()) // convert to v2 message tableV2 := table.ToV2() @@ -113,3 +110,29 @@ func TestTable_ToV2(t *testing.T) { require.Nil(t, tableV2.GetContainerID()) }) } + +func TestTable_LimitToContainer(t *testing.T) { + cnr := cidtest.ID() + var tbl eacl.Table + require.Zero(t, tbl.GetCID()) + tbl.SetCID(cnr) + require.Equal(t, cnr, tbl.GetCID()) +} + +func TestTable_CID(t *testing.T) { + cnr := cidtest.ID() + var tbl eacl.Table + _, ok := tbl.CID() + require.False(t, ok) + tbl.SetCID(cnr) + res, ok := tbl.CID() + require.True(t, ok) + require.Equal(t, cnr, res) +} + +func TestTable_SetCID(t *testing.T) { + cnr := cidtest.ID() + var tbl eacl.Table + tbl.SetCID(cnr) + require.Equal(t, cnr, tbl.GetCID()) +} diff --git a/eacl/types.go b/eacl/types.go index ef045821..cd3c0620 100644 --- a/eacl/types.go +++ b/eacl/types.go @@ -37,6 +37,7 @@ type ValidationUnit struct { } // WithContainerID configures ValidationUnit to use v as request's container ID. +// ID value must not be zero. func (u *ValidationUnit) WithContainerID(v *cid.ID) *ValidationUnit { if u != nil { u.cid = v diff --git a/object/fmt.go b/object/fmt.go index c1fee7d9..ead434ba 100644 --- a/object/fmt.go +++ b/object/fmt.go @@ -79,9 +79,9 @@ func (o *Object) VerifyID() error { return err } - oID, set := o.ID() - if !set { - return errOIDNotSet + oID := o.GetID() + if oID.IsZero() { + return oid.ErrZero } if id != oID { @@ -95,9 +95,9 @@ func (o *Object) VerifyID() error { // // See also [oid.ID.CalculateIDSignature]. func (o *Object) Sign(signer neofscrypto.Signer) error { - oID, set := o.ID() - if !set { - return errOIDNotSet + oID := o.GetID() + if oID.IsZero() { + return oid.ErrZero } sig, err := oID.CalculateIDSignature(signer) @@ -114,8 +114,7 @@ func (o *Object) Sign(signer neofscrypto.Signer) error { // // See also [Object.Sign]. func (o *Object) SignedData() []byte { - oID, _ := o.ID() - return oID.Marshal() + return o.GetID().Marshal() } // VerifySignature verifies object ID signature. diff --git a/object/id/address.go b/object/id/address.go index 76eb7143..9dfe754b 100644 --- a/object/id/address.go +++ b/object/id/address.go @@ -11,6 +11,7 @@ import ( // Address represents global object identifier in NeoFS network. Each object // belongs to exactly one container and is uniquely addressed within the container. +// Zero Address is usually prohibited, see docs for details. // // ID implements built-in comparable interface. // @@ -22,6 +23,9 @@ type Address struct { obj ID } +// ErrZeroAddress is an error returned on zero [Address] encounter. +var ErrZeroAddress = errors.New("zero object address") + // NewAddress constructs new Address. func NewAddress(cnr cid.ID, obj ID) Address { return Address{cnr, obj} } diff --git a/object/id/address_test.go b/object/id/address_test.go index 194af9f6..4217a833 100644 --- a/object/id/address_test.go +++ b/object/id/address_test.go @@ -100,15 +100,10 @@ func testAddressIDField[T ~[32]byte]( t.Run("encoding", func(t *testing.T) { t.Run("api", func(t *testing.T) { - var src, dst oid.Address + src := oidtest.Address() + var dst oid.Address var msg refs.Address - set(&dst, val) - src.WriteToV2(&msg) - require.Equal(t, make([]byte, len(val)), getAPI(&msg)) - require.NoError(t, dst.ReadFromV2(msg)) - require.Zero(t, get(dst)) - set(&src, val) src.WriteToV2(&msg) require.EqualValues(t, val[:], getAPI(&msg)) @@ -116,16 +111,11 @@ func testAddressIDField[T ~[32]byte]( require.Equal(t, val, get(dst)) }) t.Run("json", func(t *testing.T) { - var src, dst oid.Address - - set(&dst, val) - b, err := src.MarshalJSON() - require.NoError(t, err) - require.NoError(t, dst.UnmarshalJSON(b)) - require.Zero(t, get(dst)) + src := oidtest.Address() + var dst oid.Address set(&src, val) - b, err = src.MarshalJSON() + b, err := src.MarshalJSON() require.NoError(t, err) require.NoError(t, dst.UnmarshalJSON(b)) require.Equal(t, val, get(dst)) diff --git a/object/id/id.go b/object/id/id.go index 9af017aa..2b24c69f 100644 --- a/object/id/id.go +++ b/object/id/id.go @@ -2,6 +2,7 @@ package oid import ( "crypto/sha256" + "errors" "fmt" "github.com/mr-tron/base58" @@ -12,7 +13,8 @@ import ( // Size is the size of an [ID] in bytes. const Size = sha256.Size -// ID represents NeoFS object identifier in a container. +// ID represents NeoFS object identifier in a container. Zero ID is usually +// prohibited, see docs for details. // // ID implements built-in comparable interface. // @@ -20,6 +22,9 @@ const Size = sha256.Size // message. See ReadFromV2 / WriteToV2 methods. type ID [Size]byte +// ErrZero is an error returned on zero [ID] encounter. +var ErrZero = errors.New("zero object ID") + // NewFromObjectHeaderBinary returns new ID calculated from the given NeoFS // object header encoded into Protocol Buffers V3 with ascending order of fields // by number. It's callers responsibility to ensure the format of b. @@ -42,7 +47,11 @@ func DecodeString(s string) (ID, error) { // // See also WriteToV2. func (id *ID) ReadFromV2(m refs.ObjectID) error { - return id.Decode(m.GetValue()) + err := id.Decode(m.GetValue()) + if err == nil && id.IsZero() { + err = ErrZero + } + return err } // WriteToV2 writes ID to the refs.ObjectID message. @@ -173,3 +182,13 @@ func (id *ID) UnmarshalJSON(data []byte) error { return id.ReadFromV2(v2) } + +// IsZero checks whether ID is zero. +func (id ID) IsZero() bool { + for i := range id { + if id[i] != 0 { + return false + } + } + return true +} diff --git a/object/id/id_test.go b/object/id/id_test.go index ffb40951..edddcf22 100644 --- a/object/id/id_test.go +++ b/object/id/id_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/nspcc-dev/neofs-api-go/v2/refs" + cid "github.com/nspcc-dev/neofs-sdk-go/container/id" neofscrypto "github.com/nspcc-dev/neofs-sdk-go/crypto" neofscryptotest "github.com/nspcc-dev/neofs-sdk-go/crypto/test" oid "github.com/nspcc-dev/neofs-sdk-go/object/id" @@ -28,11 +29,13 @@ const validIDString = "GbckSBPEdM2P41Gkb9cVapFYb5HmRPDTZZp9JExGnsCF" // corresponds to validIDBytes. var validIDJSON = `{"value":"5715B62G/qU/ujxZIV8uZ9k5pFdSzPviAWQgSPsAB6w="}` -var invalidValueTestcases = []struct { +type invalidValueTestCase struct { name string err string val []byte -}{ +} + +var invalidValueTestcases = []invalidValueTestCase{ {name: "nil value", err: "invalid length 0", val: nil}, {name: "empty value", err: "invalid length 0", val: []byte{}}, {name: "undersized value", err: "invalid length 31", val: make([]byte, 31)}, @@ -59,7 +62,9 @@ func TestID_ReadFromV2(t *testing.T) { require.EqualValues(t, validIDBytes, id) t.Run("invalid", func(t *testing.T) { - for _, tc := range invalidValueTestcases { + for _, tc := range append(invalidValueTestcases, invalidValueTestCase{ + name: "zero value", err: "zero object ID", val: make([]byte, cid.Size), + }) { t.Run(tc.name, func(t *testing.T) { var m refs.ObjectID m.SetValue(tc.val) @@ -258,3 +263,13 @@ func TestID_CalculateIDSignature(t *testing.T) { require.True(t, s.Public().Verify(id.Marshal(), sig.Value())) } } + +func TestID_IsZero(t *testing.T) { + var id oid.ID + require.True(t, id.IsZero()) + for i := 0; i < oid.Size; i++ { + var id2 oid.ID + id2[i]++ + require.False(t, id2.IsZero()) + } +} diff --git a/object/object.go b/object/object.go index dbbe7cf2..46f5c9cc 100644 --- a/object/object.go +++ b/object/object.go @@ -109,19 +109,28 @@ func (o *Object) setSplitFields(setter func(*object.SplitHeader)) { // ID returns object identifier. // // See also [Object.SetID]. -func (o *Object) ID() (v oid.ID, isSet bool) { - v2 := (*object.Object)(o) - if id := v2.GetObjectID(); id != nil { - err := v.ReadFromV2(*v2.GetObjectID()) - isSet = (err == nil) +// Deprecated: use [Object.GetID] instead. +func (o *Object) ID() (oid.ID, bool) { + id := o.GetID() + return id, !id.IsZero() +} + +// GetID returns identifier of the object. Zero return means unset ID. +// +// See also [Object.SetID]. +func (o *Object) GetID() oid.ID { + var res oid.ID + m := (*object.Object)(o) + if id := m.GetObjectID(); id != nil { + _ = res.ReadFromV2(*m.GetObjectID()) } - return + return res } // SetID sets object identifier. // -// See also [Object.ID]. +// See also [Object.GetID]. func (o *Object) SetID(v oid.ID) { var v2 refs.ObjectID v.WriteToV2(&v2) @@ -233,21 +242,15 @@ func (o *Object) SetPayloadSize(v uint64) { // ContainerID returns identifier of the related container. // // See also [Object.SetContainerID]. +// Deprecated: use [Object.GetContainerID] instead. func (o *Object) ContainerID() (v cid.ID, isSet bool) { - v2 := (*object.Object)(o) - - cidV2 := v2.GetHeader().GetContainerID() - if cidV2 != nil { - err := v.ReadFromV2(*cidV2) - isSet = (err == nil) - } - - return + cnr := o.GetContainerID() + return cnr, !cnr.IsZero() } // SetContainerID sets identifier of the related container. // -// See also [Object.ContainerID]. +// See also [Object.GetContainerID]. func (o *Object) SetContainerID(v cid.ID) { var cidV2 refs.ContainerID v.WriteToV2(&cidV2) @@ -257,6 +260,18 @@ func (o *Object) SetContainerID(v cid.ID) { }) } +// GetContainerID returns identifier of the related container. Zero means unset +// binding. +// +// See also [Object.SetContainerID]. +func (o *Object) GetContainerID() cid.ID { + var cnr cid.ID + if m := (*object.Object)(o).GetHeader().GetContainerID(); m != nil { + _ = cnr.ReadFromV2(*m) + } + return cnr +} + // OwnerID returns identifier of the object owner. // // See also [Object.SetOwnerID]. @@ -419,21 +434,27 @@ func (o *Object) SetAttributes(v ...Attribute) { // PreviousID returns identifier of the previous sibling object. // // See also [Object.SetPreviousID]. -func (o *Object) PreviousID() (v oid.ID, isSet bool) { - v2 := (*object.Object)(o) +// Deprecated: use [Object.GetPreviousID] instead. +func (o *Object) PreviousID() (oid.ID, bool) { + id := o.GetPreviousID() + return id, !id.IsZero() +} - v2Prev := v2.GetHeader().GetSplit().GetPrevious() - if v2Prev != nil { - err := v.ReadFromV2(*v2Prev) - isSet = (err == nil) +// GetPreviousID returns identifier of the previous sibling object. Zero return +// means unset ID. +// +// See also [Object.SetPreviousID]. +func (o *Object) GetPreviousID() oid.ID { + var id oid.ID + if m := (*object.Object)(o).GetHeader().GetSplit().GetPrevious(); m != nil { + _ = id.ReadFromV2(*m) } - - return + return id } // ResetPreviousID resets identifier of the previous sibling object. // -// See also [Object.SetPreviousID], [Object.PreviousID]. +// See also [Object.SetPreviousID], [Object.GetPreviousID]. func (o *Object) ResetPreviousID() { o.setSplitFields(func(split *object.SplitHeader) { split.SetPrevious(nil) @@ -442,7 +463,7 @@ func (o *Object) ResetPreviousID() { // SetPreviousID sets identifier of the previous sibling object. // -// See also [Object.PreviousID]. +// See also [Object.GetPreviousID]. func (o *Object) SetPreviousID(v oid.ID) { var v2 refs.ObjectID v.WriteToV2(&v2) @@ -494,7 +515,7 @@ func (o *Object) SetChildren(v ...oid.ID) { // SetFirstID sets the first part's ID of the object's // split chain. // -// See also [Object.FirstID]. +// See also [Object.GetFirstID]. func (o *Object) SetFirstID(id oid.ID) { var v2 refs.ObjectID id.WriteToV2(&v2) @@ -507,16 +528,20 @@ func (o *Object) SetFirstID(id oid.ID) { // FirstID returns the first part of the object's split chain. // // See also [Object.SetFirstID]. -func (o *Object) FirstID() (v oid.ID, isSet bool) { - v2 := (*object.Object)(o) +func (o *Object) FirstID() (oid.ID, bool) { + id := o.GetFirstID() + return id, !id.IsZero() +} - v2First := v2.GetHeader().GetSplit().GetFirst() - if v2First != nil { - err := v.ReadFromV2(*v2First) - isSet = (err == nil) +// GetFirstID returns the first part of the object's split chain. Zero return means unset ID. +// +// See also [Object.SetFirstID]. +func (o *Object) GetFirstID() oid.ID { + var id oid.ID + if m := (*object.Object)(o).GetHeader().GetSplit().GetFirst(); m != nil { + _ = id.ReadFromV2(*m) } - - return + return id } // SplitID return split identity of split object. If object is not split returns nil. @@ -543,21 +568,27 @@ func (o *Object) SetSplitID(id *SplitID) { // ParentID returns identifier of the parent object. // // See also [Object.SetParentID]. -func (o *Object) ParentID() (v oid.ID, isSet bool) { - v2 := (*object.Object)(o) +// Deprecated: use [Object.GetParentID] instead. +func (o *Object) ParentID() (oid.ID, bool) { + id := o.GetParentID() + return id, !id.IsZero() +} - v2Par := v2.GetHeader().GetSplit().GetParent() - if v2Par != nil { - err := v.ReadFromV2(*v2Par) - isSet = (err == nil) +// GetParentID returns identifier of the parent object. Zero return means unset +// ID. +// +// See also [Object.SetParentID]. +func (o *Object) GetParentID() oid.ID { + var id oid.ID + if m := (*object.Object)(o).GetHeader().GetSplit().GetParent(); m != nil { + _ = id.ReadFromV2(*m) } - - return + return id } // SetParentID sets identifier of the parent object. // -// See also [Object.ParentID]. +// See also [Object.GetParentID]. func (o *Object) SetParentID(v oid.ID) { var v2 refs.ObjectID v.WriteToV2(&v2) @@ -741,7 +772,6 @@ func (o *Object) UnmarshalJSON(data []byte) error { return formatCheck((*object.Object)(o)) } -var errOIDNotSet = errors.New("object ID is not set") var errCIDNotSet = errors.New("container ID is not set") func formatCheck(v2 *object.Object) error { @@ -752,7 +782,7 @@ func formatCheck(v2 *object.Object) error { oidV2 := v2.GetObjectID() if oidV2 == nil { - return errOIDNotSet + return oid.ErrZero } err := oID.ReadFromV2(*oidV2) diff --git a/object/object_internal_test.go b/object/object_internal_test.go index 707ba398..a8d99491 100644 --- a/object/object_internal_test.go +++ b/object/object_internal_test.go @@ -92,29 +92,22 @@ func TestObject_CopyTo(t *testing.T) { t.Run("overwrite id", func(t *testing.T) { var local Object - _, isSet := local.ID() - require.False(t, isSet) + require.True(t, local.GetID().IsZero()) var dst Object require.NoError(t, dst.CalculateAndSetID()) - _, isSet = dst.ID() - require.True(t, isSet) + require.False(t, dst.GetID().IsZero()) local.CopyTo(&dst) - _, isSet = local.ID() - require.False(t, isSet) - _, isSet = dst.ID() - require.False(t, isSet) + require.True(t, local.GetID().IsZero()) + require.True(t, dst.GetID().IsZero()) checkObjectEquals(t, local, dst) require.NoError(t, dst.CalculateAndSetID()) - _, isSet = dst.ID() - require.True(t, isSet) - - _, isSet = local.ID() - require.False(t, isSet) + require.False(t, dst.GetID().IsZero()) + require.True(t, local.GetID().IsZero()) }) t.Run("change payload", func(t *testing.T) { diff --git a/object/object_test.go b/object/object_test.go index 7196edec..03503711 100644 --- a/object/object_test.go +++ b/object/object_test.go @@ -20,8 +20,7 @@ func TestInitCreation(t *testing.T) { Owner: own, }) - cID, set := o.ContainerID() - require.True(t, set) + cID := o.GetContainerID() require.Equal(t, cnr, cID) require.Equal(t, &own, o.OwnerID()) } diff --git a/object/relations/relations.go b/object/relations/relations.go index b54ec25e..90ffbf03 100644 --- a/object/relations/relations.go +++ b/object/relations/relations.go @@ -60,7 +60,7 @@ func Get(ctx context.Context, executor Executor, containerID cid.ID, rootObjectI // collect split chain by the descending ease of operations (ease is evaluated heuristically). // If any approach fails, we don't try the next since we assume that it will fail too. - if _, ok := splitInfo.Link(); !ok { + if splitInfo.GetLink().IsZero() { // the list is expected to contain last part and (probably) split info list, err := findSiblingByParentID(ctx, executor, containerID, rootObjectID, tokens, signer) if err != nil { @@ -75,17 +75,17 @@ func Get(ctx context.Context, executor Executor, containerID cid.ID, rootObjectI } return nil, nil, fmt.Errorf("split info: %w", err) } - if link, ok := split.Link(); ok { + if link := split.GetLink(); !link.IsZero() { splitInfo.SetLink(link) break } - if last, ok := split.LastPart(); ok { + if last := split.GetLastPart(); !last.IsZero() { splitInfo.SetLastPart(last) } } } - if idLinking, ok := splitInfo.Link(); ok { + if idLinking := splitInfo.GetLink(); !idLinking.IsZero() { children, err := listChildrenByLinker(ctx, executor, containerID, idLinking, tokens, signer) if err != nil { return nil, nil, fmt.Errorf("linking object's header: %w", err) @@ -94,8 +94,8 @@ func Get(ctx context.Context, executor Executor, containerID cid.ID, rootObjectI return children, &idLinking, nil } - idMember, ok := splitInfo.LastPart() - if !ok { + idMember := splitInfo.GetLastPart() + if idMember.IsZero() { return nil, nil, errors.New("missing any data in received object split information") } @@ -111,7 +111,7 @@ func Get(ctx context.Context, executor Executor, containerID cid.ID, rootObjectI return nil, nil, fmt.Errorf("split chain member's header: %w", err) } - if _, ok = chainSet[idMember]; ok { + if _, ok := chainSet[idMember]; ok { return nil, nil, fmt.Errorf("duplicated member in the split chain %s", idMember) } @@ -223,8 +223,8 @@ func getLeftSibling(ctx context.Context, header HeadExecutor, cnrID cid.ID, objI return oid.ID{}, fmt.Errorf("split chain member's header: %w", err) } - idMember, ok := hdr.PreviousID() - if !ok { + idMember := hdr.GetPreviousID() + if idMember.IsZero() { return oid.ID{}, ErrNoLeftSibling } diff --git a/object/slicer/slicer.go b/object/slicer/slicer.go index 1f9a3b0a..4e2cc3dd 100644 --- a/object/slicer/slicer.go +++ b/object/slicer/slicer.go @@ -231,8 +231,8 @@ func slice(ctx context.Context, ow ObjectWriter, header object.Object, data io.R // headerData extract required fields from header, otherwise throw the error. func headerData(header object.Object) (cid.ID, user.ID, error) { - containerID, isSet := header.ContainerID() - if !isSet { + containerID := header.GetContainerID() + if containerID.IsZero() { return cid.ID{}, user.ID{}, fmt.Errorf("container-id: %w", ErrIncompleteHeader) } @@ -627,13 +627,12 @@ func flushObjectMetadata(signer neofscrypto.Signer, meta dynamicObjectMetadata, func writeInMemObject(ctx context.Context, signer user.Signer, w ObjectWriter, header object.Object, payloadBuffers [][]byte, meta dynamicObjectMetadata, prm client.PrmObjectPutInit) (oid.ID, error) { var ( - id oid.ID - err error - isSet bool + id oid.ID + err error ) - id, isSet = header.ID() - if !isSet || header.Signature() == nil { + id = header.GetID() + if id.IsZero() || header.Signature() == nil { id, err = flushObjectMetadata(signer, meta, &header) if err != nil { diff --git a/object/slicer/slicer_test.go b/object/slicer/slicer_test.go index 60be4874..3714d86d 100644 --- a/object/slicer/slicer_test.go +++ b/object/slicer/slicer_test.go @@ -563,8 +563,8 @@ func newChainCollector(tb testing.TB) *chainCollector { } func checkStaticMetadata(tb testing.TB, header object.Object, in input) { - cnr, ok := header.ContainerID() - require.True(tb, ok, "all objects must be bound to some container") + cnr := header.GetContainerID() + require.False(tb, cnr.IsZero(), "all objects must be bound to some container") require.True(tb, cnr == in.container, "the container must be set to the configured one") owner := header.OwnerID() @@ -587,7 +587,7 @@ func checkStaticMetadata(tb testing.TB, header object.Object, in input) { require.NoError(tb, header.CheckHeaderVerificationFields(), "verification fields must be correctly set in header") - _, ok = header.PayloadHomomorphicHash() + _, ok := header.PayloadHomomorphicHash() require.Equal(tb, in.withHomo, ok) } @@ -598,15 +598,15 @@ func (x *chainCollector) handleOutgoingObject(headerOriginal object.Object, payl var header object.Object headerOriginal.CopyTo(&header) - id, ok := header.ID() - require.True(x.tb, ok, "all objects must have an ID") + id := header.GetID() + require.False(x.tb, id.IsZero(), "all objects must have an ID") idCalc, err := header.CalculateID() require.NoError(x.tb, err) require.True(x.tb, idCalc == id) - _, ok = x.mProcessed[id] + _, ok := x.mProcessed[id] require.False(x.tb, ok, "object must be written exactly once") x.mProcessed[id] = struct{}{} @@ -620,8 +620,7 @@ func (x *chainCollector) handleOutgoingObject(headerOriginal object.Object, payl if x.shortParentHeader == nil { // parent in the first part - _, set := parent.ID() - require.False(x.tb, set, "first object's parent cannot have ID") + require.True(x.tb, parent.GetID().IsZero(), "first object's parent cannot have ID") require.Nil(x.tb, parent.Signature(), "first object's parent cannot have signature") @@ -634,7 +633,7 @@ func (x *chainCollector) handleOutgoingObject(headerOriginal object.Object, payl var parentNoPayloadInfo object.Object - cID, _ := x.parentHeader.ContainerID() + cID := x.parentHeader.GetContainerID() parentNoPayloadInfo.SetVersion(x.parentHeader.Version()) parentNoPayloadInfo.SetContainerID(cID) parentNoPayloadInfo.SetCreationEpoch(x.parentHeader.CreationEpoch()) @@ -647,8 +646,8 @@ func (x *chainCollector) handleOutgoingObject(headerOriginal object.Object, payl } } - prev, ok := header.PreviousID() - if ok { + prev := header.GetPreviousID() + if !prev.IsZero() { _, ok := x.mNext[prev] require.False(x.tb, ok, "split-chain must not be forked") @@ -773,8 +772,8 @@ func (x *chainCollector) verify(in input, rootID oid.ID) { require.Equal(x.tb, x.children, restoredChain) } - id, ok := rootObj.ID() - require.True(x.tb, ok, "root object must have an ID") + id := rootObj.GetID() + require.False(x.tb, id.IsZero(), "root object must have an ID") require.True(x.tb, id == rootID, "root ID in root object must be returned in the result") checkStaticMetadata(x.tb, rootObj, in) @@ -802,8 +801,8 @@ func (w *memoryWriter) ObjectPutInit(_ context.Context, hdr object.Object, _ use w.headers = append(w.headers, objectCopy) if w.firstObject == nil { - first, set := hdr.FirstID() - if set { + first := hdr.GetFirstID() + if !first.IsZero() { w.firstObject = &first } } @@ -868,12 +867,12 @@ func TestSlicedObjectsHaveSplitID(t *testing.T) { require.Equal(t, overheadAmount+1, uint64(len(writer.headers))) for i, h := range writer.headers { - first, set := h.FirstID() + first := h.GetFirstID() if i == 0 { - require.False(t, set) + require.True(t, first.IsZero()) } else { - require.True(t, set) + require.False(t, first.IsZero()) require.Equal(t, *writer.firstObject, first) } @@ -907,12 +906,12 @@ func TestSlicedObjectsHaveSplitID(t *testing.T) { require.Equal(t, overheadAmount+1, uint64(len(writer.headers))) for i, h := range writer.headers { - first, set := h.FirstID() + first := h.GetFirstID() if i == 0 { - require.False(t, set) + require.True(t, first.IsZero()) } else { - require.True(t, set) + require.False(t, first.IsZero()) require.Equal(t, *writer.firstObject, first) } diff --git a/object/splitinfo.go b/object/splitinfo.go index 86f6a5a7..0a73404a 100644 --- a/object/splitinfo.go +++ b/object/splitinfo.go @@ -40,7 +40,7 @@ func (s *SplitInfo) ToV2() *object.SplitInfo { } // SplitID returns [SplitID] if it has been set. New objects may miss it, -// use [SplitInfo.FirstPart] as a split chain identifier. +// use [SplitInfo.GetFirstPart] as a split chain identifier. // // The value returned shares memory with the structure itself, so changing it can lead to data corruption. // Make a copy if you need to change it. @@ -65,21 +65,28 @@ func (s *SplitInfo) SetSplitID(v *SplitID) { // The second return value is a flag, indicating if the last part is present. // // See also [SplitInfo.SetLastPart]. -func (s SplitInfo) LastPart() (v oid.ID, isSet bool) { - v2 := (object.SplitInfo)(s) +// Deprecated: use [SplitInfo.GetLastPart] instead. +func (s SplitInfo) LastPart() (oid.ID, bool) { + id := s.GetLastPart() + return id, !id.IsZero() +} - lpV2 := v2.GetLastPart() - if lpV2 != nil { - _ = v.ReadFromV2(*lpV2) - isSet = true +// GetLastPart returns last object ID, can be used to retrieve original object. +// Zero return means unset ID. +// +// See also [SplitInfo.SetLastPart]. +func (s SplitInfo) GetLastPart() oid.ID { + var id oid.ID + m := (*object.SplitInfo)(&s).GetLastPart() + if m != nil { + _ = id.ReadFromV2(*m) } - - return + return id } // SetLastPart sets the last object ID. // -// See also [SplitInfo.LastPart]. +// See also [SplitInfo.GetLastPart]. func (s *SplitInfo) SetLastPart(v oid.ID) { var idV2 refs.ObjectID v.WriteToV2(&idV2) @@ -91,21 +98,26 @@ func (s *SplitInfo) SetLastPart(v oid.ID) { // The second return value is a flag, indicating if the last part is present. // // See also [SplitInfo.SetLink]. -func (s SplitInfo) Link() (v oid.ID, isSet bool) { - v2 := (object.SplitInfo)(s) +// Deprecated: use [SplitInfo.GetLink] instead. +func (s SplitInfo) Link() (oid.ID, bool) { + id := s.GetLink() + return id, !id.IsZero() +} - linkV2 := v2.GetLink() - if linkV2 != nil { - _ = v.ReadFromV2(*linkV2) - isSet = true +// GetLink returns a linker object ID. Zero return means unset ID. +// +// See also [SplitInfo.SetLink]. +func (s SplitInfo) GetLink() oid.ID { + var id oid.ID + if m := (*object.SplitInfo)(&s).GetLink(); m != nil { + _ = id.ReadFromV2(*m) } - - return + return id } // SetLink sets linker object ID. // -// See also [SplitInfo.Link]. +// See also [SplitInfo.GetLink]. func (s *SplitInfo) SetLink(v oid.ID) { var idV2 refs.ObjectID v.WriteToV2(&idV2) @@ -116,21 +128,26 @@ func (s *SplitInfo) SetLink(v oid.ID) { // FirstPart returns the first part of the split chain. // // See also [SplitInfo.SetFirstPart]. -func (s SplitInfo) FirstPart() (v oid.ID, isSet bool) { - v2 := (object.SplitInfo)(s) +// Deprecated: use [SplitInfo.GetFirstPart] instead. +func (s SplitInfo) FirstPart() (oid.ID, bool) { + id := s.GetFirstPart() + return id, !id.IsZero() +} - firstV2 := v2.GetFirstPart() - if firstV2 != nil { - _ = v.ReadFromV2(*firstV2) - isSet = true +// GetFirstPart returns the first part of the split chain. Zero means unset ID. +// +// See also [SplitInfo.SetFirstPart]. +func (s SplitInfo) GetFirstPart() oid.ID { + var id oid.ID + if m := (*object.SplitInfo)(&s).GetFirstPart(); m != nil { + _ = id.ReadFromV2(*m) } - - return + return id } // SetFirstPart sets the first part of the split chain. // -// See also [SplitInfo.FirstPart]. +// See also [SplitInfo.GetFirstPart]. func (s *SplitInfo) SetFirstPart(v oid.ID) { var idV2 refs.ObjectID v.WriteToV2(&idV2) diff --git a/object/splitinfo_test.go b/object/splitinfo_test.go index 33cea156..c9294b9c 100644 --- a/object/splitinfo_test.go +++ b/object/splitinfo_test.go @@ -21,19 +21,13 @@ func TestSplitInfo(t *testing.T) { require.Equal(t, splitID, s.SplitID()) s.SetLastPart(lastPart) - lp, set := s.LastPart() - require.True(t, set) - require.Equal(t, lastPart, lp) + require.Equal(t, lastPart, s.GetLastPart()) s.SetLink(link) - l, set := s.Link() - require.True(t, set) - require.Equal(t, link, l) + require.Equal(t, link, s.GetLink()) s.SetFirstPart(firstPart) - ip, set := s.FirstPart() - require.True(t, set) - require.Equal(t, firstPart, ip) + require.Equal(t, firstPart, s.GetFirstPart()) } func TestSplitInfoMarshal(t *testing.T) { @@ -107,12 +101,9 @@ func TestNewSplitInfo(t *testing.T) { // check initial values require.Nil(t, si.SplitID()) - _, set := si.LastPart() - require.False(t, set) - _, set = si.Link() - require.False(t, set) - _, set = si.FirstPart() - require.False(t, set) + require.True(t, si.GetLastPart().IsZero()) + require.True(t, si.GetLink().IsZero()) + require.True(t, si.GetFirstPart().IsZero()) // convert to v2 message siV2 := si.ToV2() diff --git a/pool/object.go b/pool/object.go index 66c28ac0..baf73917 100644 --- a/pool/object.go +++ b/pool/object.go @@ -37,9 +37,9 @@ func (p *Pool) ObjectPutInit(ctx context.Context, hdr object.Object, signer user return nil, err } - cnr, isSet := hdr.ContainerID() - if !isSet { - return nil, errContainerRequired + cnr := hdr.GetContainerID() + if cnr.IsZero() { + return nil, cid.ErrZero } if err = p.withinContainerSession( diff --git a/pool/session.go b/pool/session.go index 70b087b9..365efaa3 100644 --- a/pool/session.go +++ b/pool/session.go @@ -14,10 +14,6 @@ import ( "github.com/nspcc-dev/neofs-sdk-go/user" ) -var ( - errContainerRequired = errors.New("container required") -) - func initSession(ctx context.Context, c *sdkClientWrapper, dur uint64, signer user.Signer) (session.Object, error) { tok := c.nodeSession.GetNodeSession(signer.Public()) if tok != nil { diff --git a/session/common.go b/session/common.go index 89bf9322..98049be5 100644 --- a/session/common.go +++ b/session/common.go @@ -16,8 +16,7 @@ type commonData struct { idSet bool id uuid.UUID - issuerSet bool - issuer user.ID + issuer user.ID lifetimeSet bool iat, nbf, exp uint64 @@ -34,9 +33,7 @@ func (x commonData) copyTo(dst *commonData) { dst.idSet = x.idSet dst.id = x.id - dst.issuerSet = x.issuerSet - iss := x.issuer - dst.issuer = iss + dst.issuer = x.issuer dst.lifetimeSet = x.lifetimeSet dst.iat = x.iat @@ -74,13 +71,15 @@ func (x *commonData) readFromV2(m session.Token, checkFieldPresence bool, r cont } issuer := body.GetOwnerID() - if x.issuerSet = issuer != nil; x.issuerSet { + if issuer != nil { err = x.issuer.ReadFromV2(*issuer) if err != nil { return fmt.Errorf("invalid session issuer: %w", err) } } else if checkFieldPresence { return errors.New("missing session issuer") + } else { + x.issuer = user.ID{} } lifetime := body.GetLifetime() @@ -131,7 +130,7 @@ func (x commonData) fillBody(w contextWriter) *session.TokenBody { body.SetID(binID) } - if x.issuerSet { + if !x.issuer.IsZero() { var issuer refs.OwnerID x.issuer.WriteToV2(&issuer) @@ -325,7 +324,6 @@ func (x *commonData) SetAuthKey(key neofscrypto.PublicKey) { // When using it please ensure that the token is signed with the same signer as the issuer passed here. func (x *commonData) SetIssuer(id user.ID) { x.issuer = id - x.issuerSet = true } // AssertAuthKey asserts public key bound to the session. @@ -344,11 +342,7 @@ func (x commonData) AssertAuthKey(key neofscrypto.PublicKey) bool { // // See also Sign. func (x commonData) Issuer() user.ID { - if x.issuerSet { - return x.issuer - } - - return user.ID{} + return x.issuer } // IssuerPublicKeyBytes returns binary-encoded public key of the session issuer. diff --git a/session/common_test.go b/session/common_test.go index 1d001fb5..fc6fb78b 100644 --- a/session/common_test.go +++ b/session/common_test.go @@ -23,7 +23,6 @@ func Test_commonData_copyTo(t *testing.T) { data := commonData{ idSet: true, id: uuid.New(), - issuerSet: true, issuer: usr.UserID(), lifetimeSet: true, iat: 1, @@ -45,8 +44,7 @@ func Test_commonData_copyTo(t *testing.T) { require.Equal(t, data, dst) require.True(t, bytes.Equal(data.marshal(emptyWriter), dst.marshal(emptyWriter))) - require.Equal(t, data.issuerSet, dst.issuerSet) - require.Equal(t, data.issuer.String(), dst.issuer.String()) + require.Equal(t, data.issuer, dst.issuer) }) t.Run("change id", func(t *testing.T) { @@ -95,38 +93,35 @@ func Test_commonData_copyTo(t *testing.T) { var dst commonData data.copyTo(&dst) - require.Equal(t, data.issuerSet, dst.issuerSet) require.True(t, data.issuer == dst.issuer) dst.SetIssuer(usertest.OtherID(usr.ID)) - require.Equal(t, data.issuerSet, dst.issuerSet) require.False(t, data.issuer == dst.issuer) }) t.Run("overwrite issuer", func(t *testing.T) { var local commonData - require.False(t, local.issuerSet) + require.Zero(t, local.issuer) var dst commonData dst.SetIssuer(usertest.OtherID(usr.ID)) - require.True(t, dst.issuerSet) + require.NotZero(t, dst.issuer) local.copyTo(&dst) - require.False(t, local.issuerSet) - require.False(t, dst.issuerSet) + require.Zero(t, local.issuer) + require.Zero(t, dst.issuer) emptyWriter := func() session.TokenContext { return &session.ContainerSessionContext{} } require.True(t, bytes.Equal(local.marshal(emptyWriter), dst.marshal(emptyWriter))) - require.Equal(t, local.issuerSet, dst.issuerSet) require.True(t, local.issuer == dst.issuer) dst.SetIssuer(usertest.OtherID(usr.ID)) - require.False(t, local.issuerSet) - require.True(t, dst.issuerSet) + require.Zero(t, local.issuer) + require.NotZero(t, dst.issuer) require.False(t, local.issuer == dst.issuer) }) @@ -203,7 +198,7 @@ func Test_commonData_copyTo(t *testing.T) { dst.sig.SetScheme(100) dst.sig.SetSign([]byte{10, 11, 12}) - require.Equal(t, data.issuerSet, dst.issuerSet) + require.Equal(t, data.issuer, dst.issuer) require.NotEqual(t, data.sig.GetScheme(), dst.sig.GetScheme()) require.False(t, bytes.Equal(data.sig.GetKey(), dst.sig.GetKey())) require.False(t, bytes.Equal(data.sig.GetSign(), dst.sig.GetSign())) diff --git a/session/container.go b/session/container.go index b8a3c837..9bf36174 100644 --- a/session/container.go +++ b/session/container.go @@ -26,8 +26,7 @@ type Container struct { verb ContainerVerb - cnrSet bool - cnr cid.ID + cnr cid.ID } // CopyTo writes deep copy of the [Container] to dst. @@ -35,10 +34,7 @@ func (x Container) CopyTo(dst *Container) { x.commonData.copyTo(&dst.commonData) dst.verb = x.verb - - dst.cnrSet = x.cnrSet - contID := x.cnr - dst.cnr = contID + dst.cnr = x.cnr } // readContext is a contextReader needed for commonData methods. @@ -48,10 +44,10 @@ func (x *Container) readContext(c session.TokenContext, checkFieldPresence bool) return fmt.Errorf("invalid context %T", c) } - x.cnrSet = !cCnr.Wildcard() + wildcard := cCnr.Wildcard() cnr := cCnr.ContainerID() - if x.cnrSet { + if !wildcard { if cnr != nil { err := x.cnr.ReadFromV2(*cnr) if err != nil { @@ -59,6 +55,8 @@ func (x *Container) readContext(c session.TokenContext, checkFieldPresence bool) } } else if checkFieldPresence { return errors.New("missing container or wildcard flag") + } else { + x.cnr = cid.ID{} } } else if cnr != nil { return errors.New("container conflicts with wildcard flag") @@ -82,11 +80,12 @@ func (x *Container) ReadFromV2(m session.Token) error { } func (x Container) writeContext() session.TokenContext { + wildcard := x.cnr.IsZero() var c session.ContainerSessionContext - c.SetWildcard(!x.cnrSet) + c.SetWildcard(wildcard) c.SetVerb(session.ContainerSessionVerb(x.verb)) - if x.cnrSet { + if !wildcard { var cnr refs.ContainerID x.cnr.WriteToV2(&cnr) @@ -148,7 +147,9 @@ func (x *Container) UnmarshalJSON(data []byte) error { // See also [Container.VerifySignature], [Container.SignedData]. func (x *Container) Sign(signer user.Signer) error { x.issuer = signer.UserID() - x.issuerSet = true + if x.issuer.IsZero() { + return user.ErrZeroID + } return x.SetSignature(signer) } @@ -194,7 +195,6 @@ func (x Container) VerifySignature() bool { // See also AppliedTo. func (x *Container) ApplyOnlyTo(cnr cid.ID) { x.cnr = cnr - x.cnrSet = true } // AppliedTo checks if the session is propagated to the given container. @@ -203,7 +203,7 @@ func (x *Container) ApplyOnlyTo(cnr cid.ID) { // // See also ApplyOnlyTo. func (x Container) AppliedTo(cnr cid.ID) bool { - return !x.cnrSet || x.cnr == cnr + return x.cnr.IsZero() || x.cnr == cnr } // ContainerVerb enumerates container operations. diff --git a/session/container_internal_test.go b/session/container_internal_test.go index 3cbfbe80..0d32735f 100644 --- a/session/container_internal_test.go +++ b/session/container_internal_test.go @@ -34,23 +34,23 @@ func TestContainer_CopyTo(t *testing.T) { container.CopyTo(&dst) require.Equal(t, container.verb, dst.verb) - require.True(t, container.cnrSet) - require.True(t, dst.cnrSet) + require.NotZero(t, container.cnr) + require.NotZero(t, dst.cnr) container.ForVerb(VerbContainerSetEACL) require.NotEqual(t, container.verb, dst.verb) - require.True(t, container.cnrSet) - require.True(t, dst.cnrSet) + require.NotZero(t, container.cnr) + require.NotZero(t, dst.cnr) }) t.Run("overwrite container id", func(t *testing.T) { var local Container - require.False(t, local.cnrSet) + require.Zero(t, local.cnr) var dst Container dst.ApplyOnlyTo(containerID) - require.True(t, dst.cnrSet) + require.NotZero(t, dst.cnr) local.CopyTo(&dst) emptyWriter := func() session.TokenContext { @@ -60,11 +60,11 @@ func TestContainer_CopyTo(t *testing.T) { require.Equal(t, local, dst) require.True(t, bytes.Equal(local.marshal(emptyWriter), dst.marshal(emptyWriter))) - require.False(t, local.cnrSet) - require.False(t, dst.cnrSet) + require.Zero(t, local.cnr) + require.Zero(t, dst.cnr) dst.ApplyOnlyTo(containerID) - require.True(t, dst.cnrSet) - require.False(t, local.cnrSet) + require.NotZero(t, dst.cnr) + require.Zero(t, local.cnr) }) } diff --git a/session/object.go b/session/object.go index 1ccc54b1..fc19dd46 100644 --- a/session/object.go +++ b/session/object.go @@ -27,8 +27,7 @@ type Object struct { verb ObjectVerb - cnrSet bool - cnr cid.ID + cnr cid.ID objs []oid.ID } @@ -38,10 +37,7 @@ func (x Object) CopyTo(dst *Object) { x.commonData.copyTo(&dst.commonData) dst.verb = x.verb - - dst.cnrSet = x.cnrSet - contID := x.cnr - dst.cnr = contID + dst.cnr = x.cnr if objs := x.objs; objs != nil { dst.objs = make([]oid.ID, len(x.objs)) @@ -60,13 +56,15 @@ func (x *Object) readContext(c session.TokenContext, checkFieldPresence bool) er var err error cnr := cObj.GetContainer() - if x.cnrSet = cnr != nil; x.cnrSet { + if cnr != nil { err := x.cnr.ReadFromV2(*cnr) if err != nil { return fmt.Errorf("invalid container ID: %w", err) } } else if checkFieldPresence { return errors.New("missing target container") + } else { + x.cnr = cid.ID{} } objs := cObj.GetObjects() @@ -104,10 +102,10 @@ func (x Object) writeContext() session.TokenContext { var c session.ObjectSessionContext c.SetVerb(session.ObjectSessionVerb(x.verb)) - if x.cnrSet || len(x.objs) > 0 { + if !x.cnr.IsZero() || len(x.objs) > 0 { var cnr *refs.ContainerID - if x.cnrSet { + if !x.cnr.IsZero() { cnr = new(refs.ContainerID) x.cnr.WriteToV2(cnr) } @@ -183,7 +181,9 @@ func (x *Object) UnmarshalJSON(data []byte) error { // See also [Object.VerifySignature], [Object.SignedData]. func (x *Object) Sign(signer user.Signer) error { x.issuer = signer.UserID() - x.issuerSet = true + if x.issuer.IsZero() { + return user.ErrZeroID + } return x.SetSignature(signer) } @@ -229,7 +229,6 @@ func (x Object) VerifySignature() bool { // See also AssertContainer. func (x *Object) BindContainer(cnr cid.ID) { x.cnr = cnr - x.cnrSet = true } // AssertContainer checks if Object session bound to a given container. diff --git a/session/object_internal_test.go b/session/object_internal_test.go index da58fffe..a3a44329 100644 --- a/session/object_internal_test.go +++ b/session/object_internal_test.go @@ -36,14 +36,14 @@ func TestObject_CopyTo(t *testing.T) { container.CopyTo(&dst) require.Equal(t, container.verb, dst.verb) - require.True(t, container.cnrSet) - require.True(t, dst.cnrSet) + require.NotZero(t, container.cnr) + require.NotZero(t, dst.cnr) container.ForVerb(VerbObjectHead) require.NotEqual(t, container.verb, dst.verb) - require.True(t, container.cnrSet) - require.True(t, dst.cnrSet) + require.NotZero(t, container.cnr) + require.NotZero(t, dst.cnr) }) t.Run("change ids", func(t *testing.T) { @@ -66,11 +66,11 @@ func TestObject_CopyTo(t *testing.T) { t.Run("overwrite container id", func(t *testing.T) { var local Object - require.False(t, local.cnrSet) + require.Zero(t, local.cnr) var dst Object dst.BindContainer(containerID) - require.True(t, dst.cnrSet) + require.NotZero(t, dst.cnr) local.CopyTo(&dst) emptyWriter := func() session.TokenContext { @@ -80,11 +80,11 @@ func TestObject_CopyTo(t *testing.T) { require.Equal(t, local, dst) require.True(t, bytes.Equal(local.marshal(emptyWriter), dst.marshal(emptyWriter))) - require.False(t, local.cnrSet) - require.False(t, dst.cnrSet) + require.Zero(t, local.cnr) + require.Zero(t, dst.cnr) dst.BindContainer(containerID) - require.True(t, dst.cnrSet) - require.False(t, local.cnrSet) + require.NotZero(t, dst.cnr) + require.Zero(t, local.cnr) }) } diff --git a/user/id.go b/user/id.go index e730622d..a71aa6e7 100644 --- a/user/id.go +++ b/user/id.go @@ -18,6 +18,7 @@ import ( const IDSize = 25 // ID identifies users of the NeoFS system and represents Neo3 account address. +// Zero ID is usually prohibited, see docs for details. // // ID implements built-in comparable interface. // @@ -27,6 +28,9 @@ const IDSize = 25 // Zero ID is not valid. type ID [IDSize]byte +// ErrZeroID is an error returned on zero [ID] encounter. +var ErrZeroID = errors.New("zero user ID") + // NewFromScriptHash creates new ID and makes [ID.SetScriptHash]. func NewFromScriptHash(scriptHash util.Uint160) ID { var x ID @@ -67,7 +71,11 @@ func (x *ID) decodeBytes(b []byte) error { // // See also WriteToV2. func (x *ID) ReadFromV2(m refs.OwnerID) error { - return x.decodeBytes(m.GetValue()) + err := x.decodeBytes(m.GetValue()) + if err == nil && x.IsZero() { + err = ErrZeroID + } + return err } // WriteToV2 writes ID to the refs.OwnerID message. @@ -129,3 +137,13 @@ func (x ID) String() string { func (x ID) Equals(x2 ID) bool { return x == x2 } + +// IsZero checks whether ID is zero. +func (x ID) IsZero() bool { + for i := range x { + if x[i] != 0 { + return false + } + } + return true +} diff --git a/user/id_test.go b/user/id_test.go index 1cca5d50..e3ddf5d4 100644 --- a/user/id_test.go +++ b/user/id_test.go @@ -48,6 +48,7 @@ func TestID_ReadFromV2(t *testing.T) { {name: "empty value", err: "invalid length 0, expected 25", val: []byte{}}, {name: "undersized value", err: "invalid length 24, expected 25", val: validBytes[:24]}, {name: "oversized value", err: "invalid length 26, expected 25", val: append(validBytes[:], 1)}, + {name: "zero value", err: "invalid prefix byte 0x0, expected 0x35", val: make([]byte, user.IDSize)}, } { t.Run(tc.name, func(t *testing.T) { var m refs.OwnerID @@ -150,3 +151,13 @@ func TestID_String(t *testing.T) { require.Equal(t, id.String(), id.String()) require.NotEqual(t, id.String(), usertest.OtherID(id).String()) } + +func TestID_IsZero(t *testing.T) { + var id user.ID + require.True(t, id.IsZero()) + for i := 0; i < user.IDSize; i++ { + var id2 user.ID + id2[i]++ + require.False(t, id2.IsZero()) + } +} diff --git a/waiter/container_eacl.go b/waiter/container_eacl.go index 09c01225..dd03a39c 100644 --- a/waiter/container_eacl.go +++ b/waiter/container_eacl.go @@ -45,8 +45,8 @@ func (w ContainerSetEACLWaiter) ContainerSetEACL(ctx context.Context, table eacl return fmt.Errorf("container setEacl: %w", err) } - contID, ok := table.CID() - if !ok { + contID := table.GetCID() + if contID.IsZero() { return client.ErrMissingEACLContainer }