/* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ package main import ( "bytes" "crypto" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/json" "encoding/pem" "errors" "flag" "fmt" "io/ioutil" "math/big" "net/http" "os" "strings" "time" "github.com/cloudflare/cfssl/log" cfocsp "github.com/cloudflare/cfssl/ocsp" "github.com/hashicorp/vault/api" "golang.org/x/crypto/ocsp" ) func main() { var pkiMount = flag.String("pkimount", "pki", "vault PKI mount to use") var serverAddr = flag.String("serverAddr", ":8080", "Server IP and Port to use") var responderCertFile = flag.String("responderCert", "", "OCSP responder signing certificate file") var responderKeyFile = flag.String("responderKey", "", "OCSP responder signing private key file") flag.Parse() if *responderKeyFile == "" || *responderCertFile == "" { log.Critical("You have to specify a responder key and certificate") flag.Usage() os.Exit(1) } responderCert, err := parseResponderCertificate(*responderCertFile) if err != nil { log.Criticalf("Error, no responder certificate: %v", err) os.Exit(1) } responderKey, err := parseResponderKey(*responderKeyFile) if err != nil { log.Criticalf("Error, no responder key: %v", err) os.Exit(1) } vaultSource, err := NewVaultSource(*pkiMount, responderCert, &responderKey, nil) if err != nil { log.Criticalf("vault source initialization failed: %v", err) os.Exit(1) } http.Handle("/", cfocsp.NewResponder(vaultSource, nil)) server := &http.Server{ Addr: *serverAddr, } if err := server.ListenAndServe(); err != nil { log.Criticalf("ListenAndServe failed: %v", err) } } func parseResponderKey(responderKeyFile string) (responderKey crypto.Signer, err error) { pemBytes, err := ioutil.ReadFile(responderKeyFile) if err != nil { err = fmt.Errorf("could not read responder key data: %v", err) return } pemBlock, _ := pem.Decode(pemBytes) if pemBlock == nil { err = errors.New("could not decode PEM data for responder key") return } responderKey, err = x509.ParsePKCS1PrivateKey(pemBlock.Bytes) if err != nil { err = fmt.Errorf("could not parse PKCS1 formatted RSA key: %v", err) return } return } func parseResponderCertificate(responderCertFile string) (responderCert *x509.Certificate, err error) { pemBytes, err := ioutil.ReadFile(responderCertFile) if err != nil { err = fmt.Errorf("could not read responder certificate data: %v", err) return } pemBlock, _ := pem.Decode(pemBytes) if pemBlock == nil { err = errors.New("could not decode PEM data for responder certificate") return } responderCert, err = x509.ParseCertificate(pemBlock.Bytes) if err != nil { err = fmt.Errorf("could not parse responder certificate: %v", err) return } return } type VaultSource struct { pkiMount string cached map[string][]byte vaultClient *api.Client caCertificate *x509.Certificate responderCertificate *x509.Certificate responderKey *crypto.Signer } func NewVaultSource(pkiMount string, responderCertificate *x509.Certificate, responderKey *crypto.Signer, config *api.Config) (*VaultSource, error) { client, err := api.NewClient(config) if err != nil { return nil, fmt.Errorf("error initializing vault client: %v", err) } vaultRequest := client.NewRequest(http.MethodGet, fmt.Sprintf("/v1/%s/ca", pkiMount)) vaultResponse, err := client.RawRequest(vaultRequest) if err != nil { return nil, fmt.Errorf("error getting CA certificate from vault: %v", err) } caCertificateBytes, err := ioutil.ReadAll(vaultResponse.Body) if err != nil { return nil, fmt.Errorf("could not read CA certificate data from vault: %v", err) } caCertificate, err := x509.ParseCertificate(caCertificateBytes) if err != nil { return nil, fmt.Errorf("could not parse CA certificate data from vault: %v", err) } log.Infof("Found CA certificate %v", caCertificate.Subject.CommonName) vaultSource := &VaultSource{ pkiMount: pkiMount, vaultClient: client, caCertificate: caCertificate, responderCertificate: responderCertificate, responderKey: responderKey, cached: make(map[string][]byte), } return vaultSource, nil } func (source VaultSource) buildCAHash(algorithm crypto.Hash) (issuerHash []byte, err error) { h := algorithm.New() var publicKeyInfo struct { Algorithm pkix.AlgorithmIdentifier PublicKey asn1.BitString } if _, err := asn1.Unmarshal(source.caCertificate.RawSubjectPublicKeyInfo, &publicKeyInfo); err != nil { log.Errorf("Error parsing CA certificate public key info: %v", err) return nil, err } h.Write(publicKeyInfo.PublicKey.RightAlign()) issuerHash = h.Sum(nil) return issuerHash, nil } func (source VaultSource) Response(request *ocsp.Request) ([]byte, http.Header, error) { caHash, err := source.buildCAHash(request.HashAlgorithm) if err != nil { return nil, nil, fmt.Errorf("error building CA certificate hash with algorithm %d: %v", request.HashAlgorithm, err) } if bytes.Compare(request.IssuerKeyHash, caHash) != 0 { return nil, nil, errors.New("request issuer key has does not match CA subject key hash") } cacheKey := request.SerialNumber.String() response, present := source.cached[cacheKey] if present { return response, nil, nil } vaultSerial := toVaultSerial(request.SerialNumber) log.Infof("OCSP request for serial %s\n", vaultSerial) vaultResponse, err := source.vaultClient.Logical().Read( fmt.Sprintf("%s/cert/%s", source.pkiMount, vaultSerial)) if err != nil { return nil, nil, fmt.Errorf("error reading certificate information for %s from vault", vaultSerial) } revocationTime, found := vaultResponse.Data["revocation_time"] if !found { // no revocation time in data return response, nil, nil } switch revocationTime.(type) { case json.Number: revTime, err := revocationTime.(json.Number).Int64() if err != nil { return nil, nil, errors.New("could not convert revocation time to int64 value") } if revTime != 0 { log.Infof("Certificate with serial number %s is revoked", vaultSerial) response, err = source.buildRevokedResponse(request.SerialNumber, time.Unix(revTime, 0)) if err != nil { return nil, nil, fmt.Errorf("could not build response %v", err) } source.cached[cacheKey] = response return response, nil, nil } certificateString, found := vaultResponse.Data["certificate"] if !found { // no certificate in data return response, nil, nil } certificateBytes, err := ioutil.ReadAll(strings.NewReader(certificateString.(string))) if err != nil { return nil, nil, fmt.Errorf("could not read certificate %v", err) } block, _ := pem.Decode(certificateBytes) if block == nil { return nil, nil, errors.New("could not decode PEM data") } certificate, err := x509.ParseCertificate(block.Bytes) if err != nil { return nil, nil, fmt.Errorf("could not parse certificate: %v", err) } if certificate.NotAfter.Before(time.Now()) { // certificate is expired, store unauthorized response in cache log.Infof("Certificate with serial %s expired at %s, returning unauthorized", vaultSerial, certificate.NotAfter) response = ocsp.UnauthorizedErrorResponse source.cached[cacheKey] = response } else { log.Infof("Certificate with serial %s is valid", vaultSerial) response, err = source.buildOkResponse(request.SerialNumber) if err != nil { return nil, nil, fmt.Errorf("could not build response %v", err) } } present = true } return response, nil, nil } func (source VaultSource) buildRevokedResponse(serialNumber *big.Int, revocationTime time.Time) ([]byte, error) { template := ocsp.Response{ SerialNumber: serialNumber, Status: ocsp.Revoked, ThisUpdate: time.Now(), Certificate: source.responderCertificate, } template.RevokedAt = revocationTime template.RevocationReason = ocsp.Unspecified return source.buildResponse(template) } func (source VaultSource) buildOkResponse(serialNumber *big.Int) (ocspResponse []byte, err error) { template := ocsp.Response{ SerialNumber: serialNumber, Status: ocsp.Good, ThisUpdate: time.Now(), NextUpdate: time.Now().Add(time.Hour), Certificate: source.responderCertificate, } return source.buildResponse(template) } func (source VaultSource) buildResponse(template ocsp.Response) (ocspResponse []byte, err error) { ocspResponse, err = ocsp.CreateResponse( source.caCertificate, source.responderCertificate, template, *source.responderKey) return } func toVaultSerial(serial *big.Int) string { vaultSerial := serial.Text(16) if len(vaultSerial)%2 != 0 { vaultSerial = "0" + vaultSerial } serialParts := make([]string, len(vaultSerial)/2) for i := 0; i < len(vaultSerial)/2; i++ { serialParts[i] = vaultSerial[i*2 : i*2+2] } vaultSerial = strings.Join(serialParts, "-") return vaultSerial }