@@ -29,7 +29,7 @@ import (
29
29
30
30
// NewHybridDecrypt returns an HybridDecrypt primitive from the given keyset handle.
31
31
func NewHybridDecrypt (handle * keyset.Handle ) (tink.HybridDecrypt , error ) {
32
- ps , err := handle .Primitives ( internalapi.Token {})
32
+ ps , err := keyset .Primitives [tink. HybridDecrypt ]( handle , internalapi.Token {})
33
33
if err != nil {
34
34
return nil , fmt .Errorf ("hybrid_factory: cannot obtain primitive set: %s" , err )
35
35
}
@@ -39,22 +39,22 @@ func NewHybridDecrypt(handle *keyset.Handle) (tink.HybridDecrypt, error) {
39
39
// wrappedHybridDecrypt is an HybridDecrypt implementation that uses the underlying primitive set
40
40
// for decryption.
41
41
type wrappedHybridDecrypt struct {
42
- ps * primitiveset.PrimitiveSet
42
+ ps * primitiveset.PrimitiveSet [tink. HybridDecrypt ]
43
43
logger monitoring.Logger
44
44
}
45
45
46
46
// compile time assertion that wrappedHybridDecrypt implements the HybridDecrypt interface.
47
47
var _ tink.HybridDecrypt = (* wrappedHybridDecrypt )(nil )
48
48
49
- func newWrappedHybridDecrypt (ps * primitiveset.PrimitiveSet ) (* wrappedHybridDecrypt , error ) {
50
- if err := isHybridDecrypt (ps .Primary .Primitive ); err != nil {
51
- return nil , err
49
+ func newWrappedHybridDecrypt (ps * primitiveset.PrimitiveSet [tink.HybridDecrypt ]) (* wrappedHybridDecrypt , error ) {
50
+ // Make sure the primitives do not implement tink.AEAD.
51
+ if isAEAD (ps .Primary .Primitive ) || isAEAD (ps .Primary .FullPrimitive ) {
52
+ return nil , fmt .Errorf ("hybrid_factory: primary primitive must NOT implement tink.AEAD" )
52
53
}
53
-
54
54
for _ , primitives := range ps .Entries {
55
55
for _ , p := range primitives {
56
- if err := isHybridDecrypt (p .Primitive ); err != nil {
57
- return nil , err
56
+ if isAEAD (p .Primitive ) || isAEAD ( p . FullPrimitive ) {
57
+ return nil , fmt . Errorf ( "hybrid_factory: primitive must NOT implement tink.AEAD" )
58
58
}
59
59
}
60
60
}
@@ -68,7 +68,7 @@ func newWrappedHybridDecrypt(ps *primitiveset.PrimitiveSet) (*wrappedHybridDecry
68
68
}, nil
69
69
}
70
70
71
- func createDecryptLogger (ps * primitiveset.PrimitiveSet ) (monitoring.Logger , error ) {
71
+ func createDecryptLogger (ps * primitiveset.PrimitiveSet [tink. HybridDecrypt ] ) (monitoring.Logger , error ) {
72
72
if len (ps .Annotations ) == 0 {
73
73
return & monitoringutil.DoNothingLogger {}, nil
74
74
}
@@ -94,22 +94,19 @@ func (a *wrappedHybridDecrypt) Decrypt(ciphertext, contextInfo []byte) ([]byte,
94
94
entries , err := a .ps .EntriesForPrefix (string (prefix ))
95
95
if err == nil {
96
96
for i := 0 ; i < len (entries ); i ++ {
97
- p := entries [i ].Primitive .(tink.HybridDecrypt ) // verified in newWrappedHybridDecrypt
98
- pt , err := p .Decrypt (ctNoPrefix , contextInfo )
97
+ pt , err := entries [i ].Primitive .Decrypt (ctNoPrefix , contextInfo )
99
98
if err == nil {
100
99
a .logger .Log (entries [i ].KeyID , len (ctNoPrefix ))
101
100
return pt , nil
102
101
}
103
102
}
104
103
}
105
104
}
106
-
107
105
// try raw keys
108
106
entries , err := a .ps .RawEntries ()
109
107
if err == nil {
110
108
for i := 0 ; i < len (entries ); i ++ {
111
- p := entries [i ].Primitive .(tink.HybridDecrypt ) // verified in newWrappedHybridDecrypt
112
- pt , err := p .Decrypt (ciphertext , contextInfo )
109
+ pt , err := entries [i ].Primitive .Decrypt (ciphertext , contextInfo )
113
110
if err == nil {
114
111
a .logger .Log (entries [i ].KeyID , len (ciphertext ))
115
112
return pt , nil
@@ -121,15 +118,3 @@ func (a *wrappedHybridDecrypt) Decrypt(ciphertext, contextInfo []byte) ([]byte,
121
118
a .logger .LogFailure ()
122
119
return nil , fmt .Errorf ("hybrid_factory: decryption failed" )
123
120
}
124
-
125
- // Asserts `p` implements tink.HybridDecrypt and not tink.AEAD. The latter check
126
- // is required as implementations of tink.AEAD also satisfy tink.HybridDecrypt.
127
- func isHybridDecrypt (p any ) error {
128
- if _ , ok := p .(tink.AEAD ); ok {
129
- return fmt .Errorf ("hybrid_factory: tink.AEAD is not tink.HybridDecrypt" )
130
- }
131
- if _ , ok := p .(tink.HybridDecrypt ); ! ok {
132
- return fmt .Errorf ("hybrid_factory: not tink.HybridDecrypt" )
133
- }
134
- return nil
135
- }
0 commit comments