|
1 | 1 | package aws_signing_helper
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "crypto" |
| 5 | + "crypto/ecdsa" |
| 6 | + "crypto/sha1" |
| 7 | + "crypto/sha256" |
| 8 | + "crypto/sha512" |
4 | 9 | "crypto/x509"
|
5 | 10 | "encoding/asn1"
|
6 | 11 | "encoding/pem"
|
| 12 | + "errors" |
| 13 | + "fmt" |
| 14 | + "io" |
| 15 | + "math/big" |
| 16 | + "strconv" |
| 17 | + "strings" |
7 | 18 |
|
8 | 19 | tpm2 "github.com/google/go-tpm/tpm2"
|
9 | 20 | tpmutil "github.com/google/go-tpm/tpmutil"
|
@@ -76,3 +87,358 @@ type GetTPMv2SignerOpts struct {
|
76 | 87 | emptyAuth bool
|
77 | 88 | handle string
|
78 | 89 | }
|
| 90 | + |
| 91 | +// Returns the public key associated with this TPMv2Signer |
| 92 | +func (tpmv2Signer *TPMv2Signer) Public() crypto.PublicKey { |
| 93 | + ret, _ := tpmv2Signer.public.Key() |
| 94 | + return ret |
| 95 | +} |
| 96 | + |
| 97 | +// Closes this TPMv2Signer |
| 98 | +func (tpmv2Signer *TPMv2Signer) Close() { |
| 99 | + tpmv2Signer.password = "" |
| 100 | +} |
| 101 | + |
| 102 | +func checkCapability(rw io.ReadWriter, algo tpm2.Algorithm) error { |
| 103 | + descs, _, err := tpm2.GetCapability(rw, tpm2.CapabilityAlgs, 1, uint32(algo)) |
| 104 | + if err != nil { |
| 105 | + errMsg := fmt.Sprintf("error trying to get capability from TPM for the algorithm (%s)", algo) |
| 106 | + return errors.New(errMsg) |
| 107 | + } |
| 108 | + if tpm2.Algorithm(descs[0].(tpm2.AlgorithmDescription).ID) != algo { |
| 109 | + errMsg := fmt.Sprintf("unsupported algorithm (%s) for TPM", algo) |
| 110 | + return errors.New(errMsg) |
| 111 | + } |
| 112 | + |
| 113 | + return nil |
| 114 | +} |
| 115 | + |
| 116 | +// Implements the crypto.Signer interface and signs the passed in digest |
| 117 | +func (tpmv2Signer *TPMv2Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { |
| 118 | + var ( |
| 119 | + keyHandle tpmutil.Handle |
| 120 | + ) |
| 121 | + |
| 122 | + rw, err := openTPM() |
| 123 | + if err != nil { |
| 124 | + return nil, err |
| 125 | + } |
| 126 | + defer rw.Close() |
| 127 | + |
| 128 | + if tpmv2Signer.handle != 0 { |
| 129 | + keyHandle = tpmv2Signer.handle |
| 130 | + } else { |
| 131 | + parentHandle := tpmutil.Handle(tpmv2Signer.tpmData.Parent) |
| 132 | + if !handleIsPersistent(tpmv2Signer.tpmData.Parent) { |
| 133 | + // Parent and owner passwords aren't supported currently when creating a primary given a persistent handle for the parent |
| 134 | + parentHandle, _, err = tpm2.CreatePrimary(rw, tpmutil.Handle(tpmv2Signer.tpmData.Parent), tpm2.PCRSelection{}, "", "", primaryParams) |
| 135 | + if err != nil { |
| 136 | + return nil, err |
| 137 | + } |
| 138 | + defer tpm2.FlushContext(rw, parentHandle) |
| 139 | + } |
| 140 | + |
| 141 | + keyHandle, _, err = tpm2.Load(rw, parentHandle, "", tpmv2Signer.tpmData.Pubkey[2:], tpmv2Signer.tpmData.Privkey[2:]) |
| 142 | + if err != nil { |
| 143 | + return nil, err |
| 144 | + } |
| 145 | + defer tpm2.FlushContext(rw, keyHandle) |
| 146 | + } |
| 147 | + |
| 148 | + var algo tpm2.Algorithm |
| 149 | + var shadigest []byte |
| 150 | + |
| 151 | + switch opts.HashFunc() { |
| 152 | + case crypto.SHA256: |
| 153 | + sha256digest := sha256.Sum256(digest) |
| 154 | + shadigest = sha256digest[:] |
| 155 | + algo = tpm2.AlgSHA256 |
| 156 | + case crypto.SHA384: |
| 157 | + sha384digest := sha512.Sum384(digest) |
| 158 | + shadigest = sha384digest[:] |
| 159 | + algo = tpm2.AlgSHA384 |
| 160 | + case crypto.SHA512: |
| 161 | + sha512digest := sha512.Sum512(digest) |
| 162 | + shadigest = sha512digest[:] |
| 163 | + algo = tpm2.AlgSHA512 |
| 164 | + } |
| 165 | + |
| 166 | + if tpmv2Signer.public.Type == tpm2.AlgECC { |
| 167 | + // Check to see that ECDSA is supported for signing |
| 168 | + err = checkCapability(rw, tpm2.AlgECC) |
| 169 | + if err != nil { |
| 170 | + return nil, err |
| 171 | + } |
| 172 | + |
| 173 | + // For an EC key we lie to the TPM about what the hash is. |
| 174 | + // It doesn't actually matter what the original digest was; |
| 175 | + // the algo we feed to the TPM is *purely* based on the |
| 176 | + // size of the curve itself. We truncate the actual digest, |
| 177 | + // or pad with zeroes, to the byte size of the key. |
| 178 | + pubKey, err := tpmv2Signer.public.Key() |
| 179 | + if err != nil { |
| 180 | + return nil, err |
| 181 | + } |
| 182 | + ecPubKey, ok := pubKey.(*ecdsa.PublicKey) |
| 183 | + if !ok { |
| 184 | + return nil, errors.New("failed to obtain ecdsa.PublicKey") |
| 185 | + } |
| 186 | + bitSize := ecPubKey.Curve.Params().BitSize |
| 187 | + byteSize := (bitSize + 7) / 8 |
| 188 | + if byteSize > sha512.Size { |
| 189 | + byteSize = sha512.Size |
| 190 | + } |
| 191 | + switch byteSize { |
| 192 | + case sha512.Size: |
| 193 | + algo = tpm2.AlgSHA512 |
| 194 | + case sha512.Size384: |
| 195 | + algo = tpm2.AlgSHA384 |
| 196 | + case sha512.Size256: |
| 197 | + algo = tpm2.AlgSHA256 |
| 198 | + case sha1.Size: |
| 199 | + algo = tpm2.AlgSHA1 |
| 200 | + default: |
| 201 | + return nil, errors.New("unsupported curve") |
| 202 | + } |
| 203 | + |
| 204 | + if len(shadigest) > byteSize { |
| 205 | + shadigest = shadigest[:byteSize] |
| 206 | + } |
| 207 | + |
| 208 | + for len(shadigest) < byteSize { |
| 209 | + shadigest = append([]byte{0}, shadigest...) |
| 210 | + } |
| 211 | + |
| 212 | + sig, err := tpmv2Signer.signHelper(rw, keyHandle, shadigest, |
| 213 | + &tpm2.SigScheme{Alg: tpm2.AlgECDSA, Hash: algo}) |
| 214 | + if err != nil { |
| 215 | + return nil, err |
| 216 | + } |
| 217 | + signature, err = asn1.Marshal(struct { |
| 218 | + R *big.Int |
| 219 | + S *big.Int |
| 220 | + }{sig.ECC.R, sig.ECC.S}) |
| 221 | + if err != nil { |
| 222 | + return nil, err |
| 223 | + } |
| 224 | + } else { |
| 225 | + // Check to see that the requested hash function is supported |
| 226 | + err = checkCapability(rw, algo) |
| 227 | + if err != nil { |
| 228 | + return nil, err |
| 229 | + } |
| 230 | + |
| 231 | + // Check to see that RSASSA is supported for signing |
| 232 | + err = checkCapability(rw, tpm2.AlgRSASSA) |
| 233 | + if err != nil { |
| 234 | + return nil, err |
| 235 | + } |
| 236 | + |
| 237 | + sig, err := tpmv2Signer.signHelper(rw, keyHandle, shadigest, |
| 238 | + &tpm2.SigScheme{Alg: tpm2.AlgRSASSA, Hash: algo}) |
| 239 | + if err != nil { |
| 240 | + return nil, err |
| 241 | + } |
| 242 | + signature = sig.RSA.Signature |
| 243 | + } |
| 244 | + return signature, nil |
| 245 | +} |
| 246 | + |
| 247 | +func (tpmv2Signer *TPMv2Signer) signHelper(rw io.ReadWriter, keyHandle tpmutil.Handle, digest tpmutil.U16Bytes, sigScheme *tpm2.SigScheme) (*tpm2.Signature, error) { |
| 248 | + passwordPromptInput := PasswordPromptProps{ |
| 249 | + InitialPassword: tpmv2Signer.password, |
| 250 | + NoPassword: tpmv2Signer.emptyAuth, |
| 251 | + CheckPassword: func(password string) (interface{}, error) { |
| 252 | + return tpm2.Sign(rw, keyHandle, password, digest, nil, sigScheme) |
| 253 | + }, |
| 254 | + IncorrectPasswordMsg: "incorrect TPM key password", |
| 255 | + Prompt: "Please enter your TPM key password:", |
| 256 | + Reprompt: "Incorrect TPM key password. Please try again:", |
| 257 | + ParseErrMsg: "unable to read your TPM key password", |
| 258 | + CheckPasswordAuthorizationErrorMsg: TPM_RC_AUTH_FAIL, |
| 259 | + } |
| 260 | + |
| 261 | + password, sig, err := PasswordPrompt(passwordPromptInput) |
| 262 | + if err != nil { |
| 263 | + return nil, err |
| 264 | + } |
| 265 | + |
| 266 | + tpmv2Signer.password = password |
| 267 | + return sig.(*tpm2.Signature), err |
| 268 | +} |
| 269 | + |
| 270 | +// Gets the x509.Certificate associated with this TPMv2Signer |
| 271 | +func (tpmv2Signer *TPMv2Signer) Certificate() (*x509.Certificate, error) { |
| 272 | + return tpmv2Signer.cert, nil |
| 273 | +} |
| 274 | + |
| 275 | +// Gets the certificate chain associated with this TPMv2Signer |
| 276 | +func (tpmv2Signer *TPMv2Signer) CertificateChain() (chain []*x509.Certificate, err error) { |
| 277 | + return tpmv2Signer.certChain, nil |
| 278 | +} |
| 279 | + |
| 280 | +/* |
| 281 | + * DER forbids storing a BOOLEAN as anything but 0x00 or 0xFF, |
| 282 | + * 0x01, and the Go asn1 parser cannot be relaxed. But both |
| 283 | + * OpenSSL ENGINEs which produce these keys have at least in |
| 284 | + * the past emitted 0x01 as the value, leading to an Unmarshal |
| 285 | + * failure with 'asn1: syntax error: invalid boolean'. So... |
| 286 | + */ |
| 287 | +func fixupEmptyAuth(tpmData *[]byte) { |
| 288 | + var pos int = 0 |
| 289 | + |
| 290 | + // Skip the SEQUENCE tag and length |
| 291 | + if len(*tpmData) < 2 || (*tpmData)[0] != 0x30 { |
| 292 | + return |
| 293 | + } |
| 294 | + |
| 295 | + // Don't care what the SEQUENCE length is, just skip it |
| 296 | + pos = 1 |
| 297 | + lenByte := (*tpmData)[pos] |
| 298 | + if lenByte < 0x80 { |
| 299 | + pos = pos + 1 |
| 300 | + } else if lenByte < 0x85 { |
| 301 | + pos = pos + 1 + int(lenByte) - 0x80 |
| 302 | + } else { |
| 303 | + return |
| 304 | + } |
| 305 | + |
| 306 | + if len(*tpmData) <= pos { |
| 307 | + return |
| 308 | + } |
| 309 | + |
| 310 | + // Use asn1.Unmarshal to eat the OID; we care about 'rest' |
| 311 | + var oid asn1.ObjectIdentifier |
| 312 | + rest, err := asn1.Unmarshal((*tpmData)[pos:], &oid) |
| 313 | + if err != nil || rest == nil || !oid.Equal(oidLoadableKey) || len(rest) < 5 { |
| 314 | + return |
| 315 | + } |
| 316 | + |
| 317 | + // If the OPTIONAL EXPLICIT BOOLEAN [0] exists, it'll be here |
| 318 | + pos = len(*tpmData) - len(rest) |
| 319 | + |
| 320 | + if (*tpmData)[pos] == 0xa0 && // Tag |
| 321 | + (*tpmData)[pos+1] == 0x03 && // length |
| 322 | + (*tpmData)[pos+2] == 0x01 && |
| 323 | + (*tpmData)[pos+3] == 0x01 && |
| 324 | + (*tpmData)[pos+4] == 0x01 { |
| 325 | + (*tpmData)[pos+4] = 0xff |
| 326 | + } |
| 327 | +} |
| 328 | + |
| 329 | +// Returns a TPMv2Signer, that can be used to sign a payload through a TPMv2-compatible |
| 330 | +// cryptographic device |
| 331 | +func GetTPMv2Signer(opts GetTPMv2SignerOpts) (signer Signer, signingAlgorithm string, err error) { |
| 332 | + var ( |
| 333 | + certificate *x509.Certificate |
| 334 | + certificateChain []*x509.Certificate |
| 335 | + keyPem *pem.Block |
| 336 | + password string |
| 337 | + emptyAuth bool |
| 338 | + tpmData tpm2_TPMKey |
| 339 | + handle tpmutil.Handle |
| 340 | + public tpm2.Public |
| 341 | + private []byte |
| 342 | + ) |
| 343 | + |
| 344 | + certificate = opts.certificate |
| 345 | + certificateChain = opts.certificateChain |
| 346 | + keyPem = opts.keyPem |
| 347 | + password = opts.password |
| 348 | + emptyAuth = opts.emptyAuth |
| 349 | + |
| 350 | + // If a handle is provided instead of a TPM key file |
| 351 | + if opts.handle != "" { |
| 352 | + handleParts := strings.Split(opts.handle, ":") |
| 353 | + if len(handleParts) != 2 { |
| 354 | + return nil, "", errors.New("invalid TPM handle format") |
| 355 | + } |
| 356 | + hexHandleStr := handleParts[1] |
| 357 | + if strings.HasPrefix(hexHandleStr, "0x") { |
| 358 | + hexHandleStr = hexHandleStr[2:] |
| 359 | + } |
| 360 | + handleValue, err := strconv.ParseUint(hexHandleStr, 16, 32) |
| 361 | + if err != nil { |
| 362 | + return nil, "", errors.New("invalid hex TPM handle value") |
| 363 | + } |
| 364 | + handle = tpmutil.Handle(handleValue) |
| 365 | + |
| 366 | + // Read the public key from the loaded key within the TPM |
| 367 | + rw, err := openTPM() |
| 368 | + if err != nil { |
| 369 | + return nil, "", err |
| 370 | + } |
| 371 | + defer rw.Close() |
| 372 | + |
| 373 | + public, _, _, err = tpm2.ReadPublic(rw, handle) |
| 374 | + if err != nil { |
| 375 | + return nil, "", err |
| 376 | + } |
| 377 | + } else { |
| 378 | + fixupEmptyAuth(&keyPem.Bytes) |
| 379 | + _, err = asn1.Unmarshal(keyPem.Bytes, &tpmData) |
| 380 | + if err != nil { |
| 381 | + return nil, "", err |
| 382 | + } |
| 383 | + |
| 384 | + emptyAuth = tpmData.EmptyAuth |
| 385 | + if emptyAuth && password != "" { |
| 386 | + return nil, "", errors.New("password is provided but TPM key file indicates that one isn't required") |
| 387 | + } |
| 388 | + |
| 389 | + if !tpmData.Oid.Equal(oidLoadableKey) { |
| 390 | + return nil, "", errors.New("invalid OID for TPMv2 key:" + tpmData.Oid.String()) |
| 391 | + } |
| 392 | + |
| 393 | + if tpmData.Policy != nil || tpmData.AuthPolicy != nil { |
| 394 | + return nil, "", errors.New("TPMv2 policy not implemented yet") |
| 395 | + } |
| 396 | + if tpmData.Secret != nil { |
| 397 | + return nil, "", errors.New("TPMv2 key has 'secret' field which should not be set") |
| 398 | + } |
| 399 | + |
| 400 | + if !handleIsPersistent(tpmData.Parent) && |
| 401 | + tpmData.Parent != int(tpm2.HandleOwner) && |
| 402 | + tpmData.Parent != int(tpm2.HandleNull) && |
| 403 | + tpmData.Parent != int(tpm2.HandleEndorsement) && |
| 404 | + tpmData.Parent != int(tpm2.HandlePlatform) { |
| 405 | + return nil, "", errors.New("invalid parent for TPMv2 key") |
| 406 | + } |
| 407 | + if len(tpmData.Pubkey) < 2 || |
| 408 | + len(tpmData.Pubkey)-2 != (int(tpmData.Pubkey[0])<<8)+int(tpmData.Pubkey[1]) { |
| 409 | + return nil, "", errors.New("invalid length for TPMv2 PUBLIC blob") |
| 410 | + } |
| 411 | + |
| 412 | + public, err = tpm2.DecodePublic(tpmData.Pubkey[2:]) |
| 413 | + if err != nil { |
| 414 | + return nil, "", err |
| 415 | + } |
| 416 | + |
| 417 | + if len(tpmData.Privkey) < 2 || |
| 418 | + len(tpmData.Privkey)-2 != (int(tpmData.Privkey[0])<<8)+int(tpmData.Privkey[1]) { |
| 419 | + return nil, "", errors.New("invalid length for TPMv2 PRIVATE blob") |
| 420 | + } |
| 421 | + private = tpmData.Privkey[2:] |
| 422 | + } |
| 423 | + |
| 424 | + switch public.Type { |
| 425 | + case tpm2.AlgRSA: |
| 426 | + signingAlgorithm = aws4_x509_rsa_sha256 |
| 427 | + case tpm2.AlgECC: |
| 428 | + signingAlgorithm = aws4_x509_ecdsa_sha256 |
| 429 | + default: |
| 430 | + return nil, "", errors.New("unsupported TPMv2 key type") |
| 431 | + } |
| 432 | + |
| 433 | + return &TPMv2Signer{ |
| 434 | + certificate, |
| 435 | + certificateChain, |
| 436 | + tpmData, |
| 437 | + public, |
| 438 | + private, |
| 439 | + password, |
| 440 | + emptyAuth, |
| 441 | + handle, |
| 442 | + }, |
| 443 | + signingAlgorithm, nil |
| 444 | +} |
0 commit comments