Skip to content

Commit f3cbf84

Browse files
morambrocopybara-github
authored andcommittedFeb 13, 2025
Automated rollback of commit 28001de.
PiperOrigin-RevId: 726361336 Change-Id: I2c62133c1129527250a9089713fd460b31160946
1 parent ac534b5 commit f3cbf84

22 files changed

+238
-436
lines changed
 

‎aead/aead_factory.go

+8-16
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import (
3030

3131
// New returns an AEAD primitive from the given keyset handle.
3232
func New(handle *keyset.Handle) (tink.AEAD, error) {
33-
ps, err := handle.Primitives(internalapi.Token{})
33+
ps, err := keyset.Primitives[tink.AEAD](handle, internalapi.Token{})
3434
if err != nil {
3535
return nil, fmt.Errorf("aead_factory: cannot obtain primitive set: %s", err)
3636
}
@@ -40,7 +40,7 @@ func New(handle *keyset.Handle) (tink.AEAD, error) {
4040
// NewWithConfig creates an AEAD primitive from the given [keyset.Handle] using
4141
// the provided [Config].
4242
func NewWithConfig(handle *keyset.Handle, config keyset.Config) (tink.AEAD, error) {
43-
ps, err := handle.Primitives(internalapi.Token{}, keyset.WithConfig(config))
43+
ps, err := keyset.Primitives[tink.AEAD](handle, internalapi.Token{}, keyset.WithConfig(config))
4444
if err != nil {
4545
return nil, fmt.Errorf("aead_factory: cannot obtain primitive set with config: %s", err)
4646
}
@@ -90,26 +90,18 @@ func (a *fullAEADPrimitiveAdapter) Decrypt(ciphertext, associatedData []byte) ([
9090
}
9191

9292
// extractFullAEAD returns a full aeadAndKeyID primitive from the given
93-
// [primitiveset.Entry].
94-
func extractFullAEAD(entry *primitiveset.Entry) (*aeadAndKeyID, error) {
93+
// [primitiveset.Entry[tink.AEAD]].
94+
func extractFullAEAD(entry *primitiveset.Entry[tink.AEAD]) (*aeadAndKeyID, error) {
9595
if entry.FullPrimitive != nil {
96-
a, ok := (entry.FullPrimitive).(tink.AEAD)
97-
if !ok {
98-
return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
99-
}
100-
return &aeadAndKeyID{primitive: a, keyID: entry.KeyID}, nil
101-
}
102-
a, ok := (entry.Primitive).(tink.AEAD)
103-
if !ok {
104-
return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
96+
return &aeadAndKeyID{primitive: entry.FullPrimitive, keyID: entry.KeyID}, nil
10597
}
10698
return &aeadAndKeyID{
107-
primitive: &fullAEADPrimitiveAdapter{primitive: a, prefix: []byte(entry.Prefix)},
99+
primitive: &fullAEADPrimitiveAdapter{primitive: entry.Primitive, prefix: []byte(entry.Prefix)},
108100
keyID: entry.KeyID,
109101
}, nil
110102
}
111103

112-
func newWrappedAead(ps *primitiveset.PrimitiveSet) (*wrappedAead, error) {
104+
func newWrappedAead(ps *primitiveset.PrimitiveSet[tink.AEAD]) (*wrappedAead, error) {
113105
primary, err := extractFullAEAD(ps.Primary)
114106
if err != nil {
115107
return nil, err
@@ -136,7 +128,7 @@ func newWrappedAead(ps *primitiveset.PrimitiveSet) (*wrappedAead, error) {
136128
}, nil
137129
}
138130

139-
func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring.Logger, error) {
131+
func createLoggers(ps *primitiveset.PrimitiveSet[tink.AEAD]) (monitoring.Logger, monitoring.Logger, error) {
140132
if len(ps.Annotations) == 0 {
141133
return &monitoringutil.DoNothingLogger{}, &monitoringutil.DoNothingLogger{}, nil
142134
}

‎daead/daead_factory.go

+6-27
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import (
2929

3030
// New returns a DeterministicAEAD primitive from the given keyset handle.
3131
func New(handle *keyset.Handle) (tink.DeterministicAEAD, error) {
32-
ps, err := handle.Primitives(internalapi.Token{})
32+
ps, err := keyset.Primitives[tink.DeterministicAEAD](handle, internalapi.Token{})
3333
if err != nil {
3434
return nil, fmt.Errorf("daead_factory: cannot obtain primitive set: %s", err)
3535
}
@@ -39,26 +39,15 @@ func New(handle *keyset.Handle) (tink.DeterministicAEAD, error) {
3939
// wrappedDeterministicAEAD is a DeterministicAEAD implementation that uses an underlying primitive set
4040
// for deterministic encryption and decryption.
4141
type wrappedDeterministicAEAD struct {
42-
ps *primitiveset.PrimitiveSet
42+
ps *primitiveset.PrimitiveSet[tink.DeterministicAEAD]
4343
encLogger monitoring.Logger
4444
decLogger monitoring.Logger
4545
}
4646

4747
// Asserts that wrappedDeterministicAEAD implements the DeterministicAEAD interface.
4848
var _ tink.DeterministicAEAD = (*wrappedDeterministicAEAD)(nil)
4949

50-
func newWrappedDeterministicAEAD(ps *primitiveset.PrimitiveSet) (*wrappedDeterministicAEAD, error) {
51-
if _, ok := (ps.Primary.Primitive).(tink.DeterministicAEAD); !ok {
52-
return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
53-
}
54-
55-
for _, primitives := range ps.Entries {
56-
for _, p := range primitives {
57-
if _, ok := (p.Primitive).(tink.DeterministicAEAD); !ok {
58-
return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
59-
}
60-
}
61-
}
50+
func newWrappedDeterministicAEAD(ps *primitiveset.PrimitiveSet[tink.DeterministicAEAD]) (*wrappedDeterministicAEAD, error) {
6251
encLogger, decLogger, err := createLoggers(ps)
6352
if err != nil {
6453
return nil, err
@@ -70,7 +59,7 @@ func newWrappedDeterministicAEAD(ps *primitiveset.PrimitiveSet) (*wrappedDetermi
7059
}, nil
7160
}
7261

73-
func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring.Logger, error) {
62+
func createLoggers(ps *primitiveset.PrimitiveSet[tink.DeterministicAEAD]) (monitoring.Logger, monitoring.Logger, error) {
7463
if len(ps.Annotations) == 0 {
7564
return &monitoringutil.DoNothingLogger{}, &monitoringutil.DoNothingLogger{}, nil
7665
}
@@ -134,12 +123,7 @@ func (d *wrappedDeterministicAEAD) DecryptDeterministically(ct, aad []byte) ([]b
134123
entries, err := d.ps.EntriesForPrefix(string(prefix))
135124
if err == nil {
136125
for i := 0; i < len(entries); i++ {
137-
p, ok := (entries[i].Primitive).(tink.DeterministicAEAD)
138-
if !ok {
139-
return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
140-
}
141-
142-
pt, err := p.DecryptDeterministically(ctNoPrefix, aad)
126+
pt, err := entries[i].Primitive.DecryptDeterministically(ctNoPrefix, aad)
143127
if err == nil {
144128
d.decLogger.Log(entries[i].KeyID, len(ctNoPrefix))
145129
return pt, nil
@@ -152,12 +136,7 @@ func (d *wrappedDeterministicAEAD) DecryptDeterministically(ct, aad []byte) ([]b
152136
entries, err := d.ps.RawEntries()
153137
if err == nil {
154138
for i := 0; i < len(entries); i++ {
155-
p, ok := (entries[i].Primitive).(tink.DeterministicAEAD)
156-
if !ok {
157-
return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
158-
}
159-
160-
pt, err := p.DecryptDeterministically(ct, aad)
139+
pt, err := entries[i].Primitive.DecryptDeterministically(ct, aad)
161140
if err == nil {
162141
d.decLogger.Log(entries[i].KeyID, len(ct))
163142
return pt, nil

‎hybrid/hybrid_decrypt_factory.go

+11-26
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import (
2929

3030
// NewHybridDecrypt returns an HybridDecrypt primitive from the given keyset handle.
3131
func NewHybridDecrypt(handle *keyset.Handle) (tink.HybridDecrypt, error) {
32-
ps, err := handle.Primitives(internalapi.Token{})
32+
ps, err := keyset.Primitives[tink.HybridDecrypt](handle, internalapi.Token{})
3333
if err != nil {
3434
return nil, fmt.Errorf("hybrid_factory: cannot obtain primitive set: %s", err)
3535
}
@@ -39,22 +39,22 @@ func NewHybridDecrypt(handle *keyset.Handle) (tink.HybridDecrypt, error) {
3939
// wrappedHybridDecrypt is an HybridDecrypt implementation that uses the underlying primitive set
4040
// for decryption.
4141
type wrappedHybridDecrypt struct {
42-
ps *primitiveset.PrimitiveSet
42+
ps *primitiveset.PrimitiveSet[tink.HybridDecrypt]
4343
logger monitoring.Logger
4444
}
4545

4646
// compile time assertion that wrappedHybridDecrypt implements the HybridDecrypt interface.
4747
var _ tink.HybridDecrypt = (*wrappedHybridDecrypt)(nil)
4848

49-
func newWrappedHybridDecrypt(ps *primitiveset.PrimitiveSet) (*wrappedHybridDecrypt, error) {
50-
if err := isHybridDecrypt(ps.Primary.Primitive); err != nil {
51-
return nil, err
49+
func newWrappedHybridDecrypt(ps *primitiveset.PrimitiveSet[tink.HybridDecrypt]) (*wrappedHybridDecrypt, error) {
50+
// Make sure the primitives do not implement tink.AEAD.
51+
if isAEAD(ps.Primary.Primitive) || isAEAD(ps.Primary.FullPrimitive) {
52+
return nil, fmt.Errorf("hybrid_factory: primary primitive must NOT implement tink.AEAD")
5253
}
53-
5454
for _, primitives := range ps.Entries {
5555
for _, p := range primitives {
56-
if err := isHybridDecrypt(p.Primitive); err != nil {
57-
return nil, err
56+
if isAEAD(p.Primitive) || isAEAD(p.FullPrimitive) {
57+
return nil, fmt.Errorf("hybrid_factory: primitive must NOT implement tink.AEAD")
5858
}
5959
}
6060
}
@@ -68,7 +68,7 @@ func newWrappedHybridDecrypt(ps *primitiveset.PrimitiveSet) (*wrappedHybridDecry
6868
}, nil
6969
}
7070

71-
func createDecryptLogger(ps *primitiveset.PrimitiveSet) (monitoring.Logger, error) {
71+
func createDecryptLogger(ps *primitiveset.PrimitiveSet[tink.HybridDecrypt]) (monitoring.Logger, error) {
7272
if len(ps.Annotations) == 0 {
7373
return &monitoringutil.DoNothingLogger{}, nil
7474
}
@@ -94,22 +94,19 @@ func (a *wrappedHybridDecrypt) Decrypt(ciphertext, contextInfo []byte) ([]byte,
9494
entries, err := a.ps.EntriesForPrefix(string(prefix))
9595
if err == nil {
9696
for i := 0; i < len(entries); i++ {
97-
p := entries[i].Primitive.(tink.HybridDecrypt) // verified in newWrappedHybridDecrypt
98-
pt, err := p.Decrypt(ctNoPrefix, contextInfo)
97+
pt, err := entries[i].Primitive.Decrypt(ctNoPrefix, contextInfo)
9998
if err == nil {
10099
a.logger.Log(entries[i].KeyID, len(ctNoPrefix))
101100
return pt, nil
102101
}
103102
}
104103
}
105104
}
106-
107105
// try raw keys
108106
entries, err := a.ps.RawEntries()
109107
if err == nil {
110108
for i := 0; i < len(entries); i++ {
111-
p := entries[i].Primitive.(tink.HybridDecrypt) // verified in newWrappedHybridDecrypt
112-
pt, err := p.Decrypt(ciphertext, contextInfo)
109+
pt, err := entries[i].Primitive.Decrypt(ciphertext, contextInfo)
113110
if err == nil {
114111
a.logger.Log(entries[i].KeyID, len(ciphertext))
115112
return pt, nil
@@ -121,15 +118,3 @@ func (a *wrappedHybridDecrypt) Decrypt(ciphertext, contextInfo []byte) ([]byte,
121118
a.logger.LogFailure()
122119
return nil, fmt.Errorf("hybrid_factory: decryption failed")
123120
}
124-
125-
// Asserts `p` implements tink.HybridDecrypt and not tink.AEAD. The latter check
126-
// is required as implementations of tink.AEAD also satisfy tink.HybridDecrypt.
127-
func isHybridDecrypt(p any) error {
128-
if _, ok := p.(tink.AEAD); ok {
129-
return fmt.Errorf("hybrid_factory: tink.AEAD is not tink.HybridDecrypt")
130-
}
131-
if _, ok := p.(tink.HybridDecrypt); !ok {
132-
return fmt.Errorf("hybrid_factory: not tink.HybridDecrypt")
133-
}
134-
return nil
135-
}

‎hybrid/hybrid_encrypt_factory.go

+17-23
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828

2929
// NewHybridEncrypt returns an HybridEncrypt primitive from the given keyset handle.
3030
func NewHybridEncrypt(handle *keyset.Handle) (tink.HybridEncrypt, error) {
31-
ps, err := handle.Primitives(internalapi.Token{})
31+
ps, err := keyset.Primitives[tink.HybridEncrypt](handle, internalapi.Token{})
3232
if err != nil {
3333
return nil, fmt.Errorf("hybrid_factory: cannot obtain primitive set: %s", err)
3434
}
@@ -37,22 +37,30 @@ func NewHybridEncrypt(handle *keyset.Handle) (tink.HybridEncrypt, error) {
3737

3838
// encryptPrimitiveSet is an HybridEncrypt implementation that uses the underlying primitive set for encryption.
3939
type wrappedHybridEncrypt struct {
40-
ps *primitiveset.PrimitiveSet
40+
ps *primitiveset.PrimitiveSet[tink.HybridEncrypt]
4141
logger monitoring.Logger
4242
}
4343

4444
// compile time assertion that wrappedHybridEncrypt implements the HybridEncrypt interface.
4545
var _ tink.HybridEncrypt = (*wrappedHybridEncrypt)(nil)
4646

47-
func newEncryptPrimitiveSet(ps *primitiveset.PrimitiveSet) (*wrappedHybridEncrypt, error) {
48-
if err := isHybridEncrypt(ps.Primary.Primitive); err != nil {
49-
return nil, err
47+
func isAEAD(p any) bool {
48+
if p == nil {
49+
return false
5050
}
51+
_, ok := p.(tink.AEAD)
52+
return ok
53+
}
5154

55+
func newEncryptPrimitiveSet(ps *primitiveset.PrimitiveSet[tink.HybridEncrypt]) (*wrappedHybridEncrypt, error) {
56+
// Make sure the primitives do not implement tink.AEAD.
57+
if isAEAD(ps.Primary.Primitive) || isAEAD(ps.Primary.FullPrimitive) {
58+
return nil, fmt.Errorf("hybrid_factory: primary primitive must NOT implement tink.AEAD")
59+
}
5260
for _, primitives := range ps.Entries {
5361
for _, p := range primitives {
54-
if err := isHybridEncrypt(p.Primitive); err != nil {
55-
return nil, err
62+
if isAEAD(p.Primitive) || isAEAD(p.FullPrimitive) {
63+
return nil, fmt.Errorf("hybrid_factory: primitive must NOT implement tink.AEAD")
5664
}
5765
}
5866
}
@@ -66,7 +74,7 @@ func newEncryptPrimitiveSet(ps *primitiveset.PrimitiveSet) (*wrappedHybridEncryp
6674
}, nil
6775
}
6876

69-
func createEncryptLogger(ps *primitiveset.PrimitiveSet) (monitoring.Logger, error) {
77+
func createEncryptLogger(ps *primitiveset.PrimitiveSet[tink.HybridEncrypt]) (monitoring.Logger, error) {
7078
if len(ps.Annotations) == 0 {
7179
return &monitoringutil.DoNothingLogger{}, nil
7280
}
@@ -85,9 +93,7 @@ func createEncryptLogger(ps *primitiveset.PrimitiveSet) (monitoring.Logger, erro
8593
// It returns the concatenation of the primary's identifier and the ciphertext.
8694
func (a *wrappedHybridEncrypt) Encrypt(plaintext, contextInfo []byte) ([]byte, error) {
8795
primary := a.ps.Primary
88-
p := primary.Primitive.(tink.HybridEncrypt) // verified in newEncryptPrimitiveSet
89-
90-
ct, err := p.Encrypt(plaintext, contextInfo)
96+
ct, err := primary.Primitive.Encrypt(plaintext, contextInfo)
9197
if err != nil {
9298
a.logger.LogFailure()
9399
return nil, err
@@ -101,15 +107,3 @@ func (a *wrappedHybridEncrypt) Encrypt(plaintext, contextInfo []byte) ([]byte, e
101107
output = append(output, ct...)
102108
return output, nil
103109
}
104-
105-
// Asserts `p` implements tink.HybridEncrypt and not tink.AEAD. The latter check
106-
// is required as implementations of tink.AEAD also satisfy tink.HybridEncrypt.
107-
func isHybridEncrypt(p any) error {
108-
if _, ok := p.(tink.AEAD); ok {
109-
return fmt.Errorf("hybrid_factory: tink.AEAD is not tink.HybridEncrypt")
110-
}
111-
if _, ok := p.(tink.HybridEncrypt); !ok {
112-
return fmt.Errorf("hybrid_factory: not tink.HybridEncrypt")
113-
}
114-
return nil
115-
}

‎internal/monitoringutil/monitoring_util.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func parseKeyTypeURL(ktu string) string {
5959

6060
// KeysetInfoFromPrimitiveSet creates a `KeysetInfo` from a `PrimitiveSet`.
6161
// This function doesn't guarantee to preserve the ordering of the keys in the keyset.
62-
func KeysetInfoFromPrimitiveSet(ps *primitiveset.PrimitiveSet) (*monitoring.KeysetInfo, error) {
62+
func KeysetInfoFromPrimitiveSet[T any](ps *primitiveset.PrimitiveSet[T]) (*monitoring.KeysetInfo, error) {
6363
if ps == nil {
6464
return nil, fmt.Errorf("primitive set is nil")
6565
}

‎internal/monitoringutil/monitoring_util_test.go

+17-16
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,24 @@ import (
1818
"testing"
1919

2020
"github.com/google/go-cmp/cmp"
21-
"github.com/tink-crypto/tink-go/v2/internal/primitiveset"
2221
"github.com/tink-crypto/tink-go/v2/internal/monitoringutil"
22+
"github.com/tink-crypto/tink-go/v2/internal/primitiveset"
2323
"github.com/tink-crypto/tink-go/v2/monitoring"
24+
"github.com/tink-crypto/tink-go/v2/tink"
2425
tpb "github.com/tink-crypto/tink-go/v2/proto/tink_go_proto"
2526
)
2627

2728
func TestKeysetInfoFromPrimitiveSetWithNilPrimitiveSetFails(t *testing.T) {
28-
if _, err := monitoringutil.KeysetInfoFromPrimitiveSet(nil); err == nil {
29-
t.Errorf("KeysetInfoFromPrimitiveSet(nil) err = nil, want error")
29+
if _, err := monitoringutil.KeysetInfoFromPrimitiveSet[any](nil); err == nil {
30+
t.Errorf("monitoringutil.KeysetInfoFromPrimitiveSet[any](nil) err = nil, want error")
3031
}
3132
}
3233

33-
func validPrimitiveSet() *primitiveset.PrimitiveSet {
34-
return &primitiveset.PrimitiveSet{
35-
Primary: &primitiveset.Entry{},
36-
Entries: map[string][]*primitiveset.Entry{
37-
"one": []*primitiveset.Entry{
34+
func validPrimitiveSet() *primitiveset.PrimitiveSet[tink.AEAD] {
35+
return &primitiveset.PrimitiveSet[tink.AEAD]{
36+
Primary: &primitiveset.Entry[tink.AEAD]{},
37+
Entries: map[string][]*primitiveset.Entry[tink.AEAD]{
38+
"one": []*primitiveset.Entry[tink.AEAD]{
3839
{
3940
Status: tpb.KeyStatusType_ENABLED,
4041
TypeURL: "type.googleapis.com/google.crypto.tink.AesGcmKey",
@@ -68,7 +69,7 @@ func TestKeysetInfoFromPrimitiveSetWithNoPrimaryFails(t *testing.T) {
6869

6970
func TestKeysetInfoFromPrimitiveSetWithInvalidKeyStatusFails(t *testing.T) {
7071
ps := validPrimitiveSet()
71-
ps.Entries["invalid_key_status"] = []*primitiveset.Entry{
72+
ps.Entries["invalid_key_status"] = []*primitiveset.Entry[tink.AEAD]{
7273
{
7374
KeyID: 123,
7475
Status: tpb.KeyStatusType_UNKNOWN_STATUS,
@@ -80,30 +81,30 @@ func TestKeysetInfoFromPrimitiveSetWithInvalidKeyStatusFails(t *testing.T) {
8081
}
8182

8283
func TestKeysetInfoFromPrimitiveSet(t *testing.T) {
83-
ps := &primitiveset.PrimitiveSet{
84-
Primary: &primitiveset.Entry{
84+
ps := &primitiveset.PrimitiveSet[tink.AEAD]{
85+
Primary: &primitiveset.Entry[tink.AEAD]{
8586
KeyID: 1,
8687
},
8788
Annotations: map[string]string{
8889
"foo": "bar",
8990
"zoo": "far",
9091
},
91-
Entries: map[string][]*primitiveset.Entry{
92+
Entries: map[string][]*primitiveset.Entry[tink.AEAD]{
9293
// Adding all entries under the same prefix to get deterministic output.
93-
"one": []*primitiveset.Entry{
94-
&primitiveset.Entry{
94+
"one": []*primitiveset.Entry[tink.AEAD]{
95+
&primitiveset.Entry[tink.AEAD]{
9596
KeyID: 1,
9697
Status: tpb.KeyStatusType_ENABLED,
9798
TypeURL: "type.googleapis.com/google.crypto.tink.AesSivKey",
9899
PrefixType: tpb.OutputPrefixType_TINK,
99100
},
100-
&primitiveset.Entry{
101+
&primitiveset.Entry[tink.AEAD]{
101102
KeyID: 2,
102103
Status: tpb.KeyStatusType_DISABLED,
103104
TypeURL: "type.googleapis.com/google.crypto.tink.AesGcmKey",
104105
PrefixType: tpb.OutputPrefixType_TINK,
105106
},
106-
&primitiveset.Entry{
107+
&primitiveset.Entry[tink.AEAD]{
107108
KeyID: 3,
108109
Status: tpb.KeyStatusType_DESTROYED,
109110
TypeURL: "type.googleapis.com/google.crypto.tink.AesCtrHmacKey",

0 commit comments

Comments
 (0)