diff --git a/charmcraft.yaml b/charmcraft.yaml index 7e9695f8..806b309d 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -41,27 +41,42 @@ resources: upstream-source: grafana/tempo:2.4.0 provides: - profiling-endpoint: - interface: parca_scrape grafana-dashboard: interface: grafana_dashboard + description: | + Forwards the built-in grafana dashboard(s) for monitoring Tempo. grafana-source: interface: grafana_datasource + description: | + Configures Grafana to be able to use this Tempo instance as a datasource. metrics-endpoint: interface: prometheus_scrape + description: | + Exposes the Prometheus metrics endpoint providing telemetry about the + Tempo instance. tracing: interface: tracing + description: | + Integration to offer other charms the possibility to send traces to Tempo. + requires: logging: interface: loki_push_api + description: | + Integration with Loki to push Tempo logs to the observability stack. ingress: interface: traefik_route - limit: 1 description: | Ingress integration for Tempo server and Tempo receiver endpoints, so that cross-model workloads can send their traces to Tempo through the ingress. Uses `traefik_route` to open ports on Traefik host for tracing ingesters. + certificates: + interface: tls-certificates + limit: 1 + description: | + Certificate and key files for securing Tempo internal and external + communications with TLS. storage: data: @@ -86,5 +101,7 @@ parts: charm: charm-binary-python-packages: - "pydantic>=2" + - "cryptography" + - "jsonschema" - "opentelemetry-exporter-otlp-proto-http==1.21.0" diff --git a/lib/charms/observability_libs/v1/cert_handler.py b/lib/charms/observability_libs/v1/cert_handler.py new file mode 100644 index 00000000..c482662f --- /dev/null +++ b/lib/charms/observability_libs/v1/cert_handler.py @@ -0,0 +1,581 @@ +# Copyright 2023 Canonical Ltd. +# See LICENSE file for licensing details. +"""## Overview. 
+ +This document explains how to use the `CertHandler` class to +create and manage TLS certificates through the `tls_certificates` interface. + +The goal of the CertHandler is to provide a wrapper to the `tls_certificates` +library functions to make the charm integration smoother. + +## Library Usage + +This library should be used to create a `CertHandler` object, as per the +following example: + +```python +self.cert_handler = CertHandler( + charm=self, + key="my-app-cert-manager", + cert_subject="unit_name", # Optional +) +``` + +You can then observe the library's custom event and make use of the key and cert: +```python +self.framework.observe(self.cert_handler.on.cert_changed, self._on_server_cert_changed) + +container.push(keypath, self.cert_handler.private_key) +container.push(certpath, self.cert_handler.server_cert) +``` + +Since this library uses [Juju Secrets](https://juju.is/docs/juju/secret) it requires Juju >= 3.0.3. +""" +import abc +import ipaddress +import json +import socket +from itertools import filterfalse +from typing import Dict, List, Optional, Union + +try: + from charms.tls_certificates_interface.v3.tls_certificates import ( # type: ignore + AllCertificatesInvalidatedEvent, + CertificateAvailableEvent, + CertificateExpiringEvent, + CertificateInvalidatedEvent, + ProviderCertificate, + TLSCertificatesRequiresV3, + generate_csr, + generate_private_key, + ) +except ImportError as e: + raise ImportError( + "failed to import charms.tls_certificates_interface.v3.tls_certificates; " + "Either the library itself is missing (please get it through charmcraft fetch-lib) " + "or one of its dependencies is unmet." 
+ ) from e + +import logging + +from ops.charm import CharmBase, RelationBrokenEvent +from ops.framework import EventBase, EventSource, Object, ObjectEvents +from ops.jujuversion import JujuVersion +from ops.model import Relation, Secret, SecretNotFoundError + +logger = logging.getLogger(__name__) + +LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" +LIBAPI = 1 +LIBPATCH = 8 + +VAULT_SECRET_LABEL = "cert-handler-private-vault" + + +def is_ip_address(value: str) -> bool: + """Return True if the input value is a valid IPv4 address; False otherwise.""" + try: + ipaddress.IPv4Address(value) + return True + except ipaddress.AddressValueError: + return False + + +class CertChanged(EventBase): + """Event raised when a cert is changed (becomes available or revoked).""" + + +class CertHandlerEvents(ObjectEvents): + """Events for CertHandler.""" + + cert_changed = EventSource(CertChanged) + + +class _VaultBackend(abc.ABC): + """Base class for a single secret manager. + + Assumptions: + - A single secret (label) is managed by a single instance. + - Secret is per-unit (not per-app, i.e. may differ from unit to unit). + """ + + def store(self, contents: Dict[str, str], clear: bool = False): ... + + def get_value(self, key: str) -> Optional[str]: ... + + def retrieve(self) -> Dict[str, str]: ... + + def clear(self): ... + + +class _RelationVaultBackend(_VaultBackend): + """Relation backend for Vault. + + Use it to store data in a relation databag. + Assumes that a single relation exists and its data is readable. + If not, it will raise RuntimeErrors as soon as you try to read/write. + It will store the data, in plaintext (json-dumped) nested under a configurable + key in the **unit databag** of this relation. + + Typically, you'll use this with peer relations. + + Note: it is assumed that this object has exclusive access to the data, even though in practice it does not. + Modifying relation data yourself would go unnoticed and disrupt consistency. 
+ """ + + _NEST_UNDER = "lib.charms.observability_libs.v1.cert_handler::vault" + # This key needs to be relation-unique. If someone ever creates multiple Vault(_RelationVaultBackend) + # instances backed by the same (peer) relation, they'll need to set different _NEST_UNDERs + # for each _RelationVaultBackend instance or they'll be fighting over it. + + def __init__(self, charm: CharmBase, relation_name: str): + self.charm = charm + self.relation_name = relation_name + + def _check_ready(self): + relation = self.charm.model.get_relation(self.relation_name) + if not relation or not relation.data: + # if something goes wrong here, the peer-backed vault is not ready to operate + # it can be because you are trying to use it too soon, i.e. before the (peer) + # relation has been created (or has joined). + raise RuntimeError("Relation backend not ready.") + + @property + def _relation(self) -> Optional[Relation]: + self._check_ready() + return self.charm.model.get_relation(self.relation_name) + + @property + def _databag(self): + self._check_ready() + # _check_ready verifies that there is a relation + return self._relation.data[self.charm.unit] # type: ignore + + def _read(self) -> Dict[str, str]: + value = self._databag.get(self._NEST_UNDER) + if value: + return json.loads(value) + return {} + + def _write(self, value: Dict[str, str]): + if not all(isinstance(x, str) for x in value.values()): + # the caller has to take care of encoding + raise TypeError("You can only store strings in Vault.") + + self._databag[self._NEST_UNDER] = json.dumps(value) + + def store(self, contents: Dict[str, str], clear: bool = False): + """Create a new revision by updating the previous one with ``contents``.""" + current = self._read() + + if clear: + current.clear() + + current.update(contents) + self._write(current) + + def get_value(self, key: str) -> Optional[str]: + """Like retrieve, but single-value.""" + return self._read().get(key) + + def retrieve(self): + """Return the full vault 
content.""" + return self._read() + + def clear(self): + del self._databag[self._NEST_UNDER] + + +class _SecretVaultBackend(_VaultBackend): + """Secret backend for Vault. + + Use it to store data in a Juju secret. + Assumes that Juju supports secrets. + If not, it will raise some exception as soon as you try to read/write. + + Note: it is assumed that this object has exclusive access to the data, even though in practice it does not. + Modifying secret's data yourself would go unnoticed and disrupt consistency. + """ + + _uninitialized_key = "uninitialized-secret-key" + + def __init__(self, charm: CharmBase, secret_label: str): + self.charm = charm + self.secret_label = secret_label # needs to be charm-unique. + + @property + def _secret(self) -> Secret: + try: + # we are owners, so we don't need to grant it to ourselves + return self.charm.model.get_secret(label=self.secret_label) + except SecretNotFoundError: + # we need to set SOME contents when we're creating the secret, so we do it. + return self.charm.unit.add_secret( + {self._uninitialized_key: "42"}, label=self.secret_label + ) + + def store(self, contents: Dict[str, str], clear: bool = False): + """Create a new revision by updating the previous one with ``contents``.""" + secret = self._secret + current = secret.get_content(refresh=True) + + if clear: + current.clear() + elif current.get(self._uninitialized_key): + # is this the first revision? clean up the mock contents we created instants ago. 
+ del current[self._uninitialized_key] + + current.update(contents) + secret.set_content(current) + + def get_value(self, key: str) -> Optional[str]: + """Like retrieve, but single-value.""" + return self._secret.get_content(refresh=True).get(key) + + def retrieve(self): + """Return the full vault content.""" + return self._secret.get_content(refresh=True) + + def clear(self): + self._secret.remove_all_revisions() + + +class Vault: + """Simple application secret wrapper for local usage.""" + + def __init__(self, backend: _VaultBackend): + self._backend = backend + + def store(self, contents: Dict[str, str], clear: bool = False): + """Store these contents in the vault overriding whatever is there.""" + self._backend.store(contents, clear=clear) + + def get_value(self, key: str): + """Like retrieve, but single-value.""" + return self._backend.get_value(key) + + def retrieve(self) -> Dict[str, str]: + """Return the full vault content.""" + return self._backend.retrieve() + + def clear(self): + """Clear the vault.""" + self._backend.clear() + + +class CertHandler(Object): + """A wrapper for the requirer side of the TLS Certificates charm library.""" + + on = CertHandlerEvents() # pyright: ignore + + def __init__( + self, + charm: CharmBase, + *, + key: str, + certificates_relation_name: str = "certificates", + cert_subject: Optional[str] = None, + sans: Optional[List[str]] = None, + ): + """CertHandler is used to wrap TLS Certificates management operations for charms. + + CertHandler manages one single cert. + + Args: + charm: The owning charm. + key: A manually-crafted, static, unique identifier used by ops to identify events. + It shouldn't change from one event to another. + certificates_relation_name: Must match metadata.yaml. + cert_subject: Custom subject. Name collisions are under the caller's responsibility. + sans: DNS names. If none are given, use FQDN. 
+ """ + super().__init__(charm, key) + self.charm = charm + + # We need to sanitize the unit name, otherwise route53 complains: + # "urn:ietf:params:acme:error:malformed" :: Domain name contains an invalid character + self.cert_subject = charm.unit.name.replace("/", "-") if not cert_subject else cert_subject + + # Use fqdn only if no SANs were given, and drop empty/duplicate SANs + sans = list(set(filter(None, (sans or [socket.getfqdn()])))) + self.sans_ip = list(filter(is_ip_address, sans)) + self.sans_dns = list(filterfalse(is_ip_address, sans)) + + if self._check_juju_supports_secrets(): + vault_backend = _SecretVaultBackend(charm, secret_label=VAULT_SECRET_LABEL) + + # TODO: gracefully handle situations where the + # secret is gone because the admin has removed it manually + # self.framework.observe(self.charm.on.secret_remove, self._rotate_csr) + + else: + vault_backend = _RelationVaultBackend(charm, relation_name="peers") + self.vault = Vault(vault_backend) + + self.certificates_relation_name = certificates_relation_name + self.certificates = TLSCertificatesRequiresV3(self.charm, self.certificates_relation_name) + + self.framework.observe( + self.charm.on.config_changed, + self._on_config_changed, + ) + self.framework.observe( + self.charm.on[self.certificates_relation_name].relation_joined, # pyright: ignore + self._on_certificates_relation_joined, + ) + self.framework.observe( + self.certificates.on.certificate_available, # pyright: ignore + self._on_certificate_available, + ) + self.framework.observe( + self.certificates.on.certificate_expiring, # pyright: ignore + self._on_certificate_expiring, + ) + self.framework.observe( + self.certificates.on.certificate_invalidated, # pyright: ignore + self._on_certificate_invalidated, + ) + self.framework.observe( + self.certificates.on.all_certificates_invalidated, # pyright: ignore + self._on_all_certificates_invalidated, + ) + self.framework.observe( + 
self.charm.on[self.certificates_relation_name].relation_broken, # pyright: ignore + self._on_certificates_relation_broken, + ) + self.framework.observe( + self.charm.on.upgrade_charm, # pyright: ignore + self._on_upgrade_charm, + ) + + def _on_upgrade_charm(self, _): + has_privkey = self.vault.get_value("private-key") + + self._migrate_vault() + + # If we already have a csr, but the pre-migration vault has no privkey stored, + # the csr must have been signed with a privkey that is now outdated and utterly lost. + # So we throw away the csr and generate a new one (and a new privkey along with it). + if not has_privkey and self._csr: + logger.debug("CSR and privkey out of sync after charm upgrade. Renewing CSR.") + # this will call `self.private_key` which will generate a new privkey. + self._generate_csr(renew=True) + + def _migrate_vault(self): + peer_backend = _RelationVaultBackend(self.charm, relation_name="peers") + + if self._check_juju_supports_secrets(): + # we are on recent juju + if self.vault.retrieve(): + # we already were on recent juju: nothing to migrate + logger.debug( + "Private key is already stored as a juju secret. Skipping private key migration." + ) + return + + # we used to be on old juju: our secret stuff is in peer data + if contents := peer_backend.retrieve(): + logger.debug( + "Private key found in relation data. " + "Migrating private key to a juju secret." + ) + # move over to secret-backed storage + self.vault.store(contents) + + # clear the peer storage + peer_backend.clear() + return + + # if we are downgrading, i.e. from juju with secrets to juju without, + # we have lost all that was in the secrets backend. + + @property + def enabled(self) -> bool: + """Boolean indicating whether the charm has a tls_certificates relation.""" + # We need to check for units as a temporary workaround because of https://bugs.launchpad.net/juju/+bug/2024583 + # This could in theory not work correctly on scale down to 0 but it is necessary for the moment. 
+ + if not self.charm.model.get_relation(self.certificates_relation_name): + return False + + if not self.charm.model.get_relation( + self.certificates_relation_name + ).units: # pyright: ignore + return False + + if not self.charm.model.get_relation( + self.certificates_relation_name + ).app: # pyright: ignore + return False + + if not self.charm.model.get_relation( + self.certificates_relation_name + ).data: # pyright: ignore + return False + + return True + + def _on_certificates_relation_joined(self, _) -> None: + # this will only generate a csr if we don't have one already + self._generate_csr() + + def _on_config_changed(self, _): + # this will only generate a csr if we don't have one already + self._generate_csr() + + @property + def relation(self): + """The "certificates" relation.""" + return self.charm.model.get_relation(self.certificates_relation_name) + + def _generate_csr( + self, overwrite: bool = False, renew: bool = False, clear_cert: bool = False + ): + """Request a CSR "creation" if renew is False, otherwise request a renewal. + + Without overwrite=True, the CSR would be created only once, even if calling the method + multiple times. This is needed because the order of peer-created and + certificates-joined is not predictable. + + This method intentionally does not emit any events; that is the caller's responsibility. + """ + # if we are in a relation-broken hook, we might not have a relation to publish the csr to. + if not self.relation: + logger.warning( + f"No {self.certificates_relation_name!r} relation found. " f"Cannot generate csr." + ) + return + + # In case we already have a csr, do not overwrite it by default. 
+ if overwrite or renew or not self._csr: + private_key = self.private_key + csr = generate_csr( + private_key=private_key.encode(), + subject=self.cert_subject, + sans_dns=self.sans_dns, + sans_ip=self.sans_ip, + ) + + if renew and self._csr: + self.certificates.request_certificate_renewal( + old_certificate_signing_request=self._csr.encode(), + new_certificate_signing_request=csr, + ) + else: + logger.info( + "Creating CSR for %s with DNS %s and IPs %s", + self.cert_subject, + self.sans_dns, + self.sans_ip, + ) + self.certificates.request_certificate_creation(certificate_signing_request=csr) + + if clear_cert: + self.vault.clear() + + def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: + """Emit cert-changed.""" + self.on.cert_changed.emit() # pyright: ignore + + @property + def private_key(self) -> str: + """Private key. + + BEWARE: if the vault misbehaves, the backing secret is removed, the peer relation dies + or whatever, we might be calling generate_private_key() again and cause a desync + with the CSR because it's going to be signed with an outdated key we have no way of retrieving. + The caller needs to ensure that if the vault backend gets reset, then so does the csr. + + TODO: we could consider adding a way to verify if the csr was signed by our privkey, + and do that on collect_unit_status as a consistency check + """ + private_key = self.vault.get_value("private-key") + if private_key is None: + private_key = generate_private_key().decode() + self.vault.store({"private-key": private_key}) + return private_key + + @property + def _csr(self) -> Optional[str]: + csrs = self.certificates.get_requirer_csrs() + if not csrs: + return None + + # in principle we only ever need one cert. + # we might want to complicate this a bit once we get into cert rotations: during the rotation, we may need to + # keep the old one around for a little while. If there's multiple certs, at the moment we're + # ignoring all but the last one. 
+ if len(csrs) > 1: + logger.warning( + "Multiple CSRs found in `certificates` relation. " + "cert_handler is not ready to expect it." + ) + + return csrs[-1].csr + + def get_cert(self) -> Optional[ProviderCertificate]: + """Get the certificate from relation data.""" + all_certs = self.certificates.get_provider_certificates() + # search for the cert matching our csr. + matching_cert = [c for c in all_certs if c.csr == self._csr] + return matching_cert[0] if matching_cert else None + + @property + def ca_cert(self) -> Optional[str]: + """CA Certificate.""" + cert = self.get_cert() + return cert.ca if cert else None + + @property + def server_cert(self) -> Optional[str]: + """Server Certificate.""" + cert = self.get_cert() + return cert.certificate if cert else None + + @property + def chain(self) -> Optional[str]: + """Return the ca chain bundled as a single PEM string.""" + cert = self.get_cert() + return cert.chain_as_pem() if cert else None + + def _on_certificate_expiring( + self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] + ) -> None: + """Generate a new CSR and request certificate renewal.""" + if event.certificate == self.server_cert: + self._generate_csr(renew=True) + # FIXME why are we not emitting cert_changed here? + + def _certificate_revoked(self, event) -> None: + """Remove the certificate and generate a new CSR.""" + # Note: assuming "limit: 1" in metadata + if event.certificate == self.server_cert: + self._generate_csr(overwrite=True, clear_cert=True) + self.on.cert_changed.emit() # pyright: ignore + + def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: + """Deal with certificate revocation and expiration.""" + if event.certificate == self.server_cert: + # if event.reason in ("revoked", "expired"): + # Currently, the reason does not matter to us because the action is the same. 
+ self._generate_csr(overwrite=True, clear_cert=True) + self.on.cert_changed.emit() # pyright: ignore + + def _on_all_certificates_invalidated(self, _: AllCertificatesInvalidatedEvent) -> None: + # Do what you want with this information, probably remove all certificates + # Note: assuming "limit: 1" in metadata + self._generate_csr(overwrite=True, clear_cert=True) + self.on.cert_changed.emit() # pyright: ignore + + def _on_certificates_relation_broken(self, _: RelationBrokenEvent) -> None: + """Clear all secrets data when removing the relation.""" + self.vault.clear() + self.on.cert_changed.emit() # pyright: ignore + + def _check_juju_supports_secrets(self) -> bool: + version = JujuVersion.from_environ() + if not JujuVersion(version=str(version)).has_secrets: + msg = f"Juju version {version} does not supports Secrets. Juju >= 3.0.3 is needed" + logger.debug(msg) + return False + return True diff --git a/lib/charms/tempo_k8s/v1/charm_tracing.py b/lib/charms/tempo_k8s/v1/charm_tracing.py index 39ebcd46..7c118856 100644 --- a/lib/charms/tempo_k8s/v1/charm_tracing.py +++ b/lib/charms/tempo_k8s/v1/charm_tracing.py @@ -126,14 +126,15 @@ def my_tracing_endpoint(self) -> Optional[str]: from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import Span, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import INVALID_SPAN, Tracer -from opentelemetry.trace import get_current_span as otlp_get_current_span from opentelemetry.trace import ( + INVALID_SPAN, + Tracer, get_tracer, get_tracer_provider, set_span_in_context, set_tracer_provider, ) +from opentelemetry.trace import get_current_span as otlp_get_current_span from ops.charm import CharmBase from ops.framework import Framework @@ -146,7 +147,7 @@ def my_tracing_endpoint(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 5 +LIBPATCH = 6 
PYDEPS = ["opentelemetry-exporter-otlp-proto-http==1.21.0"] diff --git a/lib/charms/tempo_k8s/v1/tracing.py b/lib/charms/tempo_k8s/v1/tracing.py index 2b09ee75..8a038528 100644 --- a/lib/charms/tempo_k8s/v1/tracing.py +++ b/lib/charms/tempo_k8s/v1/tracing.py @@ -93,7 +93,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 5 +LIBPATCH = 6 PYDEPS = ["pydantic>=2"] @@ -361,6 +361,8 @@ def __init__( Args: charm: a `CharmBase` instance that manages this instance of the Tempo service. + host: hostname. + ingesters: list of ingester protocols that are enabled on this endpoint. relation_name: an optional string name of the relation between `charm` and the Tempo charmed service. The default is "tracing". diff --git a/lib/charms/tempo_k8s/v2/tracing.py b/lib/charms/tempo_k8s/v2/tracing.py index 8a125fe8..9e5defc8 100644 --- a/lib/charms/tempo_k8s/v2/tracing.py +++ b/lib/charms/tempo_k8s/v2/tracing.py @@ -12,7 +12,7 @@ object from this charm library. For the simplest use cases, using the `TracingEndpointRequirer` object only requires instantiating it, typically in the constructor of your charm. The `TracingEndpointRequirer` constructor requires the name of the relation over which a tracing endpoint - is exposed by the Tempo charm, and a list of protocols it intends to send traces with. + is exposed by the Tempo charm, and a list of protocols it intends to send traces with. This relation must use the `tracing` interface. The `TracingEndpointRequirer` object may be instantiated as follows @@ -21,7 +21,7 @@ def __init__(self, *args): super().__init__(*args) # ... - self.tracing = TracingEndpointRequirer(self, + self.tracing = TracingEndpointRequirer(self, protocols=['otlp_grpc', 'otlp_http', 'jaeger_http_thrift'] ) # ... 
@@ -29,20 +29,20 @@ def __init__(self, *args): Note that the first argument (`self`) to `TracingEndpointRequirer` is always a reference to the parent charm. -Alternatively to providing the list of requested protocols at init time, the charm can do it at -any point in time by calling the -`TracingEndpointRequirer.request_protocols(*protocol:str, relation:Optional[Relation])` method. +Alternatively to providing the list of requested protocols at init time, the charm can do it at +any point in time by calling the +`TracingEndpointRequirer.request_protocols(*protocol:str, relation:Optional[Relation])` method. Using this method also allows you to use per-relation protocols. -Units of provider charms obtain the tempo endpoint to which they will push their traces by calling +Units of provider charms obtain the tempo endpoint to which they will push their traces by calling `TracingEndpointRequirer.get_endpoint(protocol: str)`, where `protocol` is, for example: - `otlp_grpc` - `otlp_http` - `zipkin` - `tempo` -If the `protocol` is not in the list of protocols that the charm requested at endpoint set-up time, -the library will raise an error. +If the `protocol` is not in the list of protocols that the charm requested at endpoint set-up time, +the library will raise an error. ## Requirer Library Usage @@ -104,7 +104,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 4 +LIBPATCH = 5 PYDEPS = ["pydantic"] @@ -509,6 +509,7 @@ def __init__( if an ingress is present. relation_name: an optional string name of the relation between `charm` and the Tempo charmed service. The default is "tracing". + internal_scheme: scheme to use with internal urls. 
Raises: RelationNotFoundError: If there is no relation in the charm's metadata.yaml @@ -595,7 +596,7 @@ def publish_receivers(self, receivers: Sequence[RawReceiver]): try: TracingProviderAppData( host=self._host, - external_url=f"http://{self._external_url}" if self._external_url else None, + external_url=self._external_url or None, receivers=[ Receiver(port=port, protocol=protocol) for protocol, port in receivers ], @@ -830,11 +831,8 @@ def _get_endpoint( if app_data.external_url: url = f"{app_data.external_url}:{receiver.port}" else: - if app_data.internal_scheme: - url = f"{app_data.internal_scheme}://{app_data.host}:{receiver.port}" - else: - # if we didn't receive a scheme (old provider), we assume HTTP is used - url = f"http://{app_data.host}:{receiver.port}" + # if we didn't receive a scheme (old provider), we assume HTTP is used + url = f"{app_data.internal_scheme or 'http'}://{app_data.host}:{receiver.port}" if receiver.protocol.endswith("grpc"): # TCP protocols don't want an http/https scheme prefix diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py new file mode 100644 index 00000000..6fa26397 --- /dev/null +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -0,0 +1,1999 @@ +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + + +"""Library for the tls-certificates relation. + +This library contains the Requires and Provides classes for handling the tls-certificates +interface. 
+ +Pre-requisites: + - Juju >= 3.0 + +## Getting Started +From a charm directory, fetch the library using `charmcraft`: + +```shell +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates +``` + +Add the following libraries to the charm's `requirements.txt` file: +- jsonschema +- cryptography >= 42.0.0 + +Add the following section to the charm's `charmcraft.yaml` file: +```yaml +parts: + charm: + build-packages: + - libffi-dev + - libssl-dev + - rustc + - cargo +``` + +### Provider charm +The provider charm is the charm providing certificates to another charm that requires them. In +this example, the provider charm is storing its private key using a peer relation interface called +`replicas`. + +Example: +```python +from charms.tls_certificates_interface.v3.tls_certificates import ( + CertificateCreationRequestEvent, + CertificateRevocationRequestEvent, + TLSCertificatesProvidesV3, + generate_private_key, +) +from ops.charm import CharmBase, InstallEvent +from ops.main import main +from ops.model import ActiveStatus, WaitingStatus + + +def generate_ca(private_key: bytes, subject: str) -> str: + return "whatever ca content" + + +def generate_certificate(ca: str, private_key: str, csr: str) -> str: + return "Whatever certificate" + + +class ExampleProviderCharm(CharmBase): + + def __init__(self, *args): + super().__init__(*args) + self.certificates = TLSCertificatesProvidesV3(self, "certificates") + self.framework.observe( + self.certificates.on.certificate_request, + self._on_certificate_request + ) + self.framework.observe( + self.certificates.on.certificate_revocation_request, + self._on_certificate_revocation_request + ) + self.framework.observe(self.on.install, self._on_install) + + def _on_install(self, event: InstallEvent) -> None: + private_key_password = b"banana" + private_key = generate_private_key(password=private_key_password) + ca_certificate = generate_ca(private_key=private_key, subject="whatever") + replicas_relation = 
self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + replicas_relation.data[self.app].update( + { + "private_key_password": "banana", + "private_key": private_key, + "ca_certificate": ca_certificate, + } + ) + self.unit.status = ActiveStatus() + + def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + ca_certificate = replicas_relation.data[self.app].get("ca_certificate") + private_key = replicas_relation.data[self.app].get("private_key") + certificate = generate_certificate( + ca=ca_certificate, + private_key=private_key, + csr=event.certificate_signing_request, + ) + + self.certificates.set_relation_certificate( + certificate=certificate, + certificate_signing_request=event.certificate_signing_request, + ca=ca_certificate, + chain=[ca_certificate, certificate], + relation_id=event.relation_id, + recommended_expiry_notification_time=720, + ) + + def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: + # Do what you want to do with this information + pass + + +if __name__ == "__main__": + main(ExampleProviderCharm) +``` + +### Requirer charm +The requirer charm is the charm requiring certificates from another charm that provides them. In +this example, the requirer charm is storing its certificates using a peer relation interface called +`replicas`. 
+ +Example: +```python +from charms.tls_certificates_interface.v3.tls_certificates import ( + CertificateAvailableEvent, + CertificateExpiringEvent, + CertificateRevokedEvent, + TLSCertificatesRequiresV3, + generate_csr, + generate_private_key, +) +from ops.charm import CharmBase, RelationCreatedEvent +from ops.main import main +from ops.model import ActiveStatus, WaitingStatus +from typing import Union + + +class ExampleRequirerCharm(CharmBase): + + def __init__(self, *args): + super().__init__(*args) + self.cert_subject = "whatever" + self.certificates = TLSCertificatesRequiresV3(self, "certificates") + self.framework.observe(self.on.install, self._on_install) + self.framework.observe( + self.on.certificates_relation_created, self._on_certificates_relation_created + ) + self.framework.observe( + self.certificates.on.certificate_available, self._on_certificate_available + ) + self.framework.observe( + self.certificates.on.certificate_expiring, self._on_certificate_expiring + ) + self.framework.observe( + self.certificates.on.certificate_invalidated, self._on_certificate_invalidated + ) + self.framework.observe( + self.certificates.on.all_certificates_invalidated, + self._on_all_certificates_invalidated + ) + + def _on_install(self, event) -> None: + private_key_password = b"banana" + private_key = generate_private_key(password=private_key_password) + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + replicas_relation.data[self.app].update( + {"private_key_password": "banana", "private_key": private_key.decode()} + ) + + def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + private_key_password = 
replicas_relation.data[self.app].get("private_key_password") + private_key = replicas_relation.data[self.app].get("private_key") + csr = generate_csr( + private_key=private_key.encode(), + private_key_password=private_key_password.encode(), + subject=self.cert_subject, + ) + replicas_relation.data[self.app].update({"csr": csr.decode()}) + self.certificates.request_certificate_creation(certificate_signing_request=csr) + + def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + replicas_relation.data[self.app].update({"certificate": event.certificate}) + replicas_relation.data[self.app].update({"ca": event.ca}) + replicas_relation.data[self.app].update({"chain": event.chain}) + self.unit.status = ActiveStatus() + + def _on_certificate_expiring( + self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] + ) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + old_csr = replicas_relation.data[self.app].get("csr") + private_key_password = replicas_relation.data[self.app].get("private_key_password") + private_key = replicas_relation.data[self.app].get("private_key") + new_csr = generate_csr( + private_key=private_key.encode(), + private_key_password=private_key_password.encode(), + subject=self.cert_subject, + ) + self.certificates.request_certificate_renewal( + old_certificate_signing_request=old_csr, + new_certificate_signing_request=new_csr, + ) + replicas_relation.data[self.app].update({"csr": new_csr.decode()}) + + def _certificate_revoked(self) -> None: + replicas_relation = self.model.get_relation("replicas") + old_csr = replicas_relation.data[self.app].get("csr") + private_key_password = replicas_relation.data[self.app].get("private_key_password") + 
private_key = replicas_relation.data[self.app].get("private_key") + new_csr = generate_csr( + private_key=private_key.encode(), + private_key_password=private_key_password.encode(), + subject=self.cert_subject, + ) + self.certificates.request_certificate_renewal( + old_certificate_signing_request=old_csr, + new_certificate_signing_request=new_csr, + ) + replicas_relation.data[self.app].update({"csr": new_csr.decode()}) + replicas_relation.data[self.app].pop("certificate") + replicas_relation.data[self.app].pop("ca") + replicas_relation.data[self.app].pop("chain") + self.unit.status = WaitingStatus("Waiting for new certificate") + + def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + if event.reason == "revoked": + self._certificate_revoked() + if event.reason == "expired": + self._on_certificate_expiring(event) + + def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: + # Do what you want with this information, probably remove all certificates. 
+ pass + + +if __name__ == "__main__": + main(ExampleRequirerCharm) +``` + +You can relate both charms by running: + +```bash +juju relate +``` + +""" # noqa: D405, D410, D411, D214, D416 + +import copy +import json +import logging +import uuid +from contextlib import suppress +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from ipaddress import IPv4Address +from typing import List, Literal, Optional, Union + +from cryptography import x509 +from cryptography.hazmat._oid import ExtensionOID +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from jsonschema import exceptions, validate +from ops.charm import ( + CharmBase, + CharmEvents, + RelationBrokenEvent, + RelationChangedEvent, + SecretExpiredEvent, +) +from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import ( + Application, + ModelError, + Relation, + RelationDataContent, + SecretNotFoundError, + Unit, +) + +# The unique Charmhub library identifier, never change it +LIBID = "afd8c2bccf834997afce12c2706d2ede" + +# Increment this major API version when introducing breaking changes +LIBAPI = 3 + +# Increment this PATCH version before using `charmcraft publish-lib` or reset +# to 0 if you are raising the major API version +LIBPATCH = 13 + +PYDEPS = ["cryptography", "jsonschema"] + +REQUIRER_JSON_SCHEMA = { + "$schema": "http://json-schema.org/draft-04/schema#", + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", + "type": "object", + "title": "`tls_certificates` requirer root schema", + "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 + "examples": [ + { + "certificate_signing_requests": [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE 
REQUEST-----\\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\\nAQEBBQADggEPADCCAQoCggEBANWlx9wE6cW7Jkb4DZZDOZoEjk1eDBMJ+8R4pyKp\\nFBeHMl1SQSDt6rAWsrfL3KOGiIHqrRY0B5H6c51L8LDuVrJG0bPmyQ6rsBo3gVke\\nDSivfSLtGvHtp8lwYnIunF8r858uYmblAR0tdXQNmnQvm+6GERvURQ6sxpgZ7iLC\\npPKDoPt+4GKWL10FWf0i82FgxWC2KqRZUtNbgKETQuARLig7etBmCnh20zmynorA\\ncY7vrpTPAaeQpGLNqqYvKV9W6yWVY08V+nqARrFrjk3vSioZSu8ZJUdZ4d9++SGl\\nbH7A6e77YDkX9i/dQ3Pa/iDtWO3tXS2MvgoxX1iSWlGNOHcCAwEAAaAAMA0GCSqG\\nSIb3DQEBCwUAA4IBAQCW1fKcHessy/ZhnIwAtSLznZeZNH8LTVOzkhVd4HA7EJW+\\nKVLBx8DnN7L3V2/uPJfHiOg4Rx7fi7LkJPegl3SCqJZ0N5bQS/KvDTCyLG+9E8Y+\\n7wqCmWiXaH1devimXZvazilu4IC2dSks2D8DPWHgsOdVks9bme8J3KjdNMQudegc\\newWZZ1Dtbd+Rn7cpKU3jURMwm4fRwGxbJ7iT5fkLlPBlyM/yFEik4SmQxFYrZCQg\\n0f3v4kBefTh5yclPy5tEH+8G0LMsbbo3dJ5mPKpAShi0QEKDLd7eR1R/712lYTK4\\ndi4XaEfqERgy68O4rvb4PGlJeRGS7AmL7Ss8wfAq\\n-----END CERTIFICATE REQUEST-----\\n" # noqa: E501 + }, + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----\\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\\nAQEBBQADggEPADCCAQoCggEBAMk3raaX803cHvzlBF9LC7KORT46z4VjyU5PIaMb\\nQLIDgYKFYI0n5hf2Ra4FAHvOvEmW7bjNlHORFEmvnpcU5kPMNUyKFMTaC8LGmN8z\\nUBH3aK+0+FRvY4afn9tgj5435WqOG9QdoDJ0TJkjJbJI9M70UOgL711oU7ql6HxU\\n4d2ydFK9xAHrBwziNHgNZ72L95s4gLTXf0fAHYf15mDA9U5yc+YDubCKgTXzVySQ\\nUx73VCJLfC/XkZIh559IrnRv5G9fu6BMLEuBwAz6QAO4+/XidbKWN4r2XSq5qX4n\\n6EPQQWP8/nd4myq1kbg6Q8w68L/0YdfjCmbyf2TuoWeImdUCAwEAAaAAMA0GCSqG\\nSIb3DQEBCwUAA4IBAQBIdwraBvpYo/rl5MH1+1Um6HRg4gOdQPY5WcJy9B9tgzJz\\nittRSlRGTnhyIo6fHgq9KHrmUthNe8mMTDailKFeaqkVNVvk7l0d1/B90Kz6OfmD\\nxN0qjW53oP7y3QB5FFBM8DjqjmUnz5UePKoX4AKkDyrKWxMwGX5RoET8c/y0y9jp\\nvSq3Wh5UpaZdWbe1oVY8CqMVUEVQL2DPjtopxXFz2qACwsXkQZxWmjvZnRiP8nP8\\nbdFaEuh9Q6rZ2QdZDEtrU4AodPU3NaukFr5KlTUQt3w/cl+5//zils6G5zUWJ2pN\\ng7+t9PTvXHRkH+LnwaVnmsBFU2e05qADQbfIn7JA\\n-----END CERTIFICATE REQUEST-----\\n" # noqa: E501 + }, + ] + } + ], + "properties": { + "certificate_signing_requests": { + "type": "array", + "items": { + 
"type": "object", + "properties": { + "certificate_signing_request": {"type": "string"}, + "ca": {"type": "boolean"}, + }, + "required": ["certificate_signing_request"], + }, + } + }, + "required": ["certificate_signing_requests"], + "additionalProperties": True, +} + +PROVIDER_JSON_SCHEMA = { + "$schema": "http://json-schema.org/draft-04/schema#", + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", + "type": "object", + "title": "`tls_certificates` provider root schema", + "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 + "examples": [ + { + "certificates": [ + { + "ca": "-----BEGIN CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n", # noqa: E501 + "chain": [ + "-----BEGIN 
CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n" # noqa: E501, W505 + ], + "certificate_signing_request": "-----BEGIN CERTIFICATE 
REQUEST-----\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\nAQEBBQADggEPADCCAQoCggEBANWlx9wE6cW7Jkb4DZZDOZoEjk1eDBMJ+8R4pyKp\nFBeHMl1SQSDt6rAWsrfL3KOGiIHqrRY0B5H6c51L8LDuVrJG0bPmyQ6rsBo3gVke\nDSivfSLtGvHtp8lwYnIunF8r858uYmblAR0tdXQNmnQvm+6GERvURQ6sxpgZ7iLC\npPKDoPt+4GKWL10FWf0i82FgxWC2KqRZUtNbgKETQuARLig7etBmCnh20zmynorA\ncY7vrpTPAaeQpGLNqqYvKV9W6yWVY08V+nqARrFrjk3vSioZSu8ZJUdZ4d9++SGl\nbH7A6e77YDkX9i/dQ3Pa/iDtWO3tXS2MvgoxX1iSWlGNOHcCAwEAAaAAMA0GCSqG\nSIb3DQEBCwUAA4IBAQCW1fKcHessy/ZhnIwAtSLznZeZNH8LTVOzkhVd4HA7EJW+\nKVLBx8DnN7L3V2/uPJfHiOg4Rx7fi7LkJPegl3SCqJZ0N5bQS/KvDTCyLG+9E8Y+\n7wqCmWiXaH1devimXZvazilu4IC2dSks2D8DPWHgsOdVks9bme8J3KjdNMQudegc\newWZZ1Dtbd+Rn7cpKU3jURMwm4fRwGxbJ7iT5fkLlPBlyM/yFEik4SmQxFYrZCQg\n0f3v4kBefTh5yclPy5tEH+8G0LMsbbo3dJ5mPKpAShi0QEKDLd7eR1R/712lYTK4\ndi4XaEfqERgy68O4rvb4PGlJeRGS7AmL7Ss8wfAq\n-----END CERTIFICATE REQUEST-----\n", # noqa: E501 + "certificate": "-----BEGIN CERTIFICATE-----\nMIICvDCCAaQCFFPAOD7utDTsgFrm0vS4We18OcnKMA0GCSqGSIb3DQEBCwUAMCAx\nCzAJBgNVBAYTAlVTMREwDwYDVQQDDAh3aGF0ZXZlcjAeFw0yMjA3MjkyMTE5Mzha\nFw0yMzA3MjkyMTE5MzhaMBUxEzARBgNVBAMMCmJhbmFuYS5jb20wggEiMA0GCSqG\nSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDVpcfcBOnFuyZG+A2WQzmaBI5NXgwTCfvE\neKciqRQXhzJdUkEg7eqwFrK3y9yjhoiB6q0WNAeR+nOdS/Cw7layRtGz5skOq7Aa\nN4FZHg0or30i7Rrx7afJcGJyLpxfK/OfLmJm5QEdLXV0DZp0L5vuhhEb1EUOrMaY\nGe4iwqTyg6D7fuBili9dBVn9IvNhYMVgtiqkWVLTW4ChE0LgES4oO3rQZgp4dtM5\nsp6KwHGO766UzwGnkKRizaqmLylfVusllWNPFfp6gEaxa45N70oqGUrvGSVHWeHf\nfvkhpWx+wOnu+2A5F/Yv3UNz2v4g7Vjt7V0tjL4KMV9YklpRjTh3AgMBAAEwDQYJ\nKoZIhvcNAQELBQADggEBAChjRzuba8zjQ7NYBVas89Oy7u++MlS8xWxh++yiUsV6\nWMk3ZemsPtXc1YmXorIQohtxLxzUPm2JhyzFzU/sOLmJQ1E/l+gtZHyRCwsb20fX\nmphuJsMVd7qv/GwEk9PBsk2uDqg4/Wix0Rx5lf95juJP7CPXQJl5FQauf3+LSz0y\nwF/j+4GqvrwsWr9hKOLmPdkyKkR6bHKtzzsxL9PM8GnElk2OpaPMMnzbL/vt2IAt\nxK01ZzPxCQCzVwHo5IJO5NR/fIyFbEPhxzG17QsRDOBR9fl9cOIvDeSO04vyZ+nz\n+kA2c3fNrZFAtpIlOOmFh8Q12rVL4sAjI5mVWnNEgvI=\n-----END CERTIFICATE-----\n", # noqa: E501 + } + ] + }, + { + "certificates": [ + { 
+ "ca": "-----BEGIN CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n", # noqa: E501 + "chain": [ + "-----BEGIN 
CERTIFICATE-----\\nMIIDJTCCAg2gAwIBAgIUMsSK+4FGCjW6sL/EXMSxColmKw8wDQYJKoZIhvcNAQEL\\nBQAwIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdoYXRldmVyMB4XDTIyMDcyOTIx\\nMTgyN1oXDTIzMDcyOTIxMTgyN1owIDELMAkGA1UEBhMCVVMxETAPBgNVBAMMCHdo\\nYXRldmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA55N9DkgFWbJ/\\naqcdQhso7n1kFvt6j/fL1tJBvRubkiFMQJnZFtekfalN6FfRtA3jq+nx8o49e+7t\\nLCKT0xQ+wufXfOnxv6/if6HMhHTiCNPOCeztUgQ2+dfNwRhYYgB1P93wkUVjwudK\\n13qHTTZ6NtEF6EzOqhOCe6zxq6wrr422+ZqCvcggeQ5tW9xSd/8O1vNID/0MTKpy\\nET3drDtBfHmiUEIBR3T3tcy6QsIe4Rz/2sDinAcM3j7sG8uY6drh8jY3PWar9til\\nv2l4qDYSU8Qm5856AB1FVZRLRJkLxZYZNgreShAIYgEd0mcyI2EO/UvKxsIcxsXc\\nd45GhGpKkwIDAQABo1cwVTAfBgNVHQ4EGAQWBBRXBrXKh3p/aFdQjUcT/UcvICBL\\nODAhBgNVHSMEGjAYgBYEFFcGtcqHen9oV1CNRxP9Ry8gIEs4MA8GA1UdEwEB/wQF\\nMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAGmCEvcoFUrT9e133SHkgF/ZAgzeIziO\\nBjfAdU4fvAVTVfzaPm0yBnGqzcHyacCzbZjKQpaKVgc5e6IaqAQtf6cZJSCiJGhS\\nJYeosWrj3dahLOUAMrXRr8G/Ybcacoqc+osKaRa2p71cC3V6u2VvcHRV7HDFGJU7\\noijbdB+WhqET6Txe67rxZCJG9Ez3EOejBJBl2PJPpy7m1Ml4RR+E8YHNzB0lcBzc\\nEoiJKlDfKSO14E2CPDonnUoWBJWjEvJys3tbvKzsRj2fnLilytPFU0gH3cEjCopi\\nzFoWRdaRuNHYCqlBmso1JFDl8h4fMmglxGNKnKRar0WeGyxb4xXBGpI=\\n-----END CERTIFICATE-----\\n" # noqa: E501, W505 + ], + "certificate_signing_request": "-----BEGIN CERTIFICATE 
REQUEST-----\nMIICWjCCAUICAQAwFTETMBEGA1UEAwwKYmFuYW5hLmNvbTCCASIwDQYJKoZIhvcN\nAQEBBQADggEPADCCAQoCggEBANWlx9wE6cW7Jkb4DZZDOZoEjk1eDBMJ+8R4pyKp\nFBeHMl1SQSDt6rAWsrfL3KOGiIHqrRY0B5H6c51L8LDuVrJG0bPmyQ6rsBo3gVke\nDSivfSLtGvHtp8lwYnIunF8r858uYmblAR0tdXQNmnQvm+6GERvURQ6sxpgZ7iLC\npPKDoPt+4GKWL10FWf0i82FgxWC2KqRZUtNbgKETQuARLig7etBmCnh20zmynorA\ncY7vrpTPAaeQpGLNqqYvKV9W6yWVY08V+nqARrFrjk3vSioZSu8ZJUdZ4d9++SGl\nbH7A6e77YDkX9i/dQ3Pa/iDtWO3tXS2MvgoxX1iSWlGNOHcCAwEAAaAAMA0GCSqG\nSIb3DQEBCwUAA4IBAQCW1fKcHessy/ZhnIwAtSLznZeZNH8LTVOzkhVd4HA7EJW+\nKVLBx8DnN7L3V2/uPJfHiOg4Rx7fi7LkJPegl3SCqJZ0N5bQS/KvDTCyLG+9E8Y+\n7wqCmWiXaH1devimXZvazilu4IC2dSks2D8DPWHgsOdVks9bme8J3KjdNMQudegc\newWZZ1Dtbd+Rn7cpKU3jURMwm4fRwGxbJ7iT5fkLlPBlyM/yFEik4SmQxFYrZCQg\n0f3v4kBefTh5yclPy5tEH+8G0LMsbbo3dJ5mPKpAShi0QEKDLd7eR1R/712lYTK4\ndi4XaEfqERgy68O4rvb4PGlJeRGS7AmL7Ss8wfAq\n-----END CERTIFICATE REQUEST-----\n", # noqa: E501 + "certificate": "-----BEGIN CERTIFICATE-----\nMIICvDCCAaQCFFPAOD7utDTsgFrm0vS4We18OcnKMA0GCSqGSIb3DQEBCwUAMCAx\nCzAJBgNVBAYTAlVTMREwDwYDVQQDDAh3aGF0ZXZlcjAeFw0yMjA3MjkyMTE5Mzha\nFw0yMzA3MjkyMTE5MzhaMBUxEzARBgNVBAMMCmJhbmFuYS5jb20wggEiMA0GCSqG\nSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDVpcfcBOnFuyZG+A2WQzmaBI5NXgwTCfvE\neKciqRQXhzJdUkEg7eqwFrK3y9yjhoiB6q0WNAeR+nOdS/Cw7layRtGz5skOq7Aa\nN4FZHg0or30i7Rrx7afJcGJyLpxfK/OfLmJm5QEdLXV0DZp0L5vuhhEb1EUOrMaY\nGe4iwqTyg6D7fuBili9dBVn9IvNhYMVgtiqkWVLTW4ChE0LgES4oO3rQZgp4dtM5\nsp6KwHGO766UzwGnkKRizaqmLylfVusllWNPFfp6gEaxa45N70oqGUrvGSVHWeHf\nfvkhpWx+wOnu+2A5F/Yv3UNz2v4g7Vjt7V0tjL4KMV9YklpRjTh3AgMBAAEwDQYJ\nKoZIhvcNAQELBQADggEBAChjRzuba8zjQ7NYBVas89Oy7u++MlS8xWxh++yiUsV6\nWMk3ZemsPtXc1YmXorIQohtxLxzUPm2JhyzFzU/sOLmJQ1E/l+gtZHyRCwsb20fX\nmphuJsMVd7qv/GwEk9PBsk2uDqg4/Wix0Rx5lf95juJP7CPXQJl5FQauf3+LSz0y\nwF/j+4GqvrwsWr9hKOLmPdkyKkR6bHKtzzsxL9PM8GnElk2OpaPMMnzbL/vt2IAt\nxK01ZzPxCQCzVwHo5IJO5NR/fIyFbEPhxzG17QsRDOBR9fl9cOIvDeSO04vyZ+nz\n+kA2c3fNrZFAtpIlOOmFh8Q12rVL4sAjI5mVWnNEgvI=\n-----END CERTIFICATE-----\n", # noqa: E501 + "revoked": True, + } + ] + }, + ], + 
@dataclass
class ProviderCertificate:
    """Certificate record published by the providing side of the interface."""

    relation_id: int
    application_name: str
    csr: str
    certificate: str
    ca: str
    chain: List[str]
    revoked: bool
    expiry_time: datetime
    expiry_notification_time: Optional[datetime] = None

    def chain_as_pem(self) -> str:
        """Return the chain entries, in reverse order, joined by blank lines."""
        return "\n\n".join(self.chain[::-1])

    def to_json(self) -> str:
        """Serialize every field of this record to a JSON string.

        Datetimes are written in ISO 8601 form; an unset notification time
        is written as ``null``.

        Returns:
            str: JSON representation of the object
        """
        notify_at = self.expiry_notification_time
        payload = {
            "relation_id": self.relation_id,
            "application_name": self.application_name,
            "csr": self.csr,
            "certificate": self.certificate,
            "ca": self.ca,
            "chain": self.chain,
            "revoked": self.revoked,
            "expiry_time": self.expiry_time.isoformat(),
            "expiry_notification_time": notify_at.isoformat() if notify_at else None,
        }
        return json.dumps(payload)
class CertificateInvalidatedEvent(EventBase):
    """Event carrying a TLS certificate that is no longer valid and the reason why."""

    def __init__(
        self,
        handle: Handle,
        reason: Literal["expired", "revoked"],
        certificate: str,
        certificate_signing_request: str,
        ca: str,
        chain: List[str],
    ):
        """Store the invalidation reason and the certificate material on the event."""
        super().__init__(handle)
        self.reason = reason
        self.certificate = certificate
        self.certificate_signing_request = certificate_signing_request
        self.ca = ca
        self.chain = chain

    def snapshot(self) -> dict:
        """Return the event attributes as a dict so the event can be persisted."""
        persisted = ("reason", "certificate_signing_request", "certificate", "ca", "chain")
        return {name: getattr(self, name) for name in persisted}

    def restore(self, snapshot: dict):
        """Re-populate the event attributes from a dict produced by ``snapshot``."""
        for name in ("reason", "certificate_signing_request", "certificate", "ca", "chain"):
            setattr(self, name, snapshot[name])
class CertificateRevocationRequestEvent(EventBase):
    """Charm Event triggered when a TLS certificate needs to be revoked."""

    def __init__(
        self,
        handle: Handle,
        certificate: str,
        certificate_signing_request: str,
        ca: str,
        # Annotated as a list to agree with the other certificate events and
        # ProviderCertificate in this module, which all type ``chain`` as List[str].
        chain: List[str],
    ):
        """CertificateRevocationRequestEvent.

        Args:
            handle (Handle): Juju framework handle
            certificate (str): TLS certificate to revoke
            certificate_signing_request (str): Certificate signing request
            ca (str): CA certificate
            chain (list): Certificate chain
        """
        super().__init__(handle)
        self.certificate = certificate
        self.certificate_signing_request = certificate_signing_request
        self.ca = ca
        self.chain = chain

    def snapshot(self) -> dict:
        """Return snapshot."""
        return {
            "certificate": self.certificate,
            "certificate_signing_request": self.certificate_signing_request,
            "ca": self.ca,
            "chain": self.chain,
        }

    def restore(self, snapshot: dict):
        """Restore snapshot."""
        self.certificate = snapshot["certificate"]
        self.certificate_signing_request = snapshot["certificate_signing_request"]
        self.ca = snapshot["ca"]
        self.chain = snapshot["chain"]
def calculate_expiry_notification_time(
    validity_start_time: datetime,
    expiry_time: datetime,
    provider_recommended_notification_time: Optional[int],
    requirer_recommended_notification_time: Optional[int],
) -> datetime:
    """Pick the moment at which the user should be warned about certificate expiry.

    Recommendations are tried in order of precedence: the provider's first,
    then the requirer's. A recommendation is used only if the resulting time
    falls strictly after the start of the certificate validity. When neither
    qualifies, the notification is placed one third of the validity period
    before expiry.

    Args:
        validity_start_time: Certificate validity time
        expiry_time: Certificate expiry time
        provider_recommended_notification_time:
            Time in hours prior to expiry to notify the user.
            Recommended by the provider.
        requirer_recommended_notification_time:
            Time in hours prior to expiry to notify the user.
            Recommended by the requirer.

    Returns:
        datetime: Time to notify the user about the certificate expiry.
    """
    recommendations_by_precedence = (
        provider_recommended_notification_time,
        requirer_recommended_notification_time,
    )
    for recommended_hours in recommendations_by_precedence:
        if recommended_hours is None:
            continue
        # The sign of the recommendation is ignored.
        candidate = expiry_time - timedelta(hours=abs(recommended_hours))
        if validity_start_time < candidate:
            return candidate

    validity_seconds = (expiry_time - validity_start_time).total_seconds()
    return expiry_time - timedelta(hours=validity_seconds / (3600 * 3))
def get_certificate_extensions(
    authority_key_identifier: bytes,
    csr: x509.CertificateSigningRequest,
    alt_names: Optional[List[str]],
    is_ca: bool,
) -> List[x509.Extension]:
    """Assemble the extensions to put on a certificate issued for a CSR.

    The returned list is ordered as follows: authority key identifier, subject
    key identifier (derived from the CSR public key), critical basic
    constraints, a subject alternative name extension when at least one SAN is
    known, a critical key usage extension when ``is_ca`` is true, and finally
    the CSR's own extensions whose OID is not already in the list.

    Args:
        authority_key_identifier (bytes): Authority key identifier
        csr (x509.CertificateSigningRequest): CSR
        alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR
        is_ca (bool): Whether the certificate is a CA certificate

    Returns:
        List[x509.Extension]: List of extensions
    """
    authority_key_extension = x509.Extension(
        oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER,
        critical=False,
        value=x509.AuthorityKeyIdentifier(
            key_identifier=authority_key_identifier,
            authority_cert_issuer=None,
            authority_cert_serial_number=None,
        ),
    )
    subject_key_extension = x509.Extension(
        oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER,
        critical=False,
        value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()),
    )
    basic_constraints_extension = x509.Extension(
        oid=ExtensionOID.BASIC_CONSTRAINTS,
        critical=True,
        value=x509.BasicConstraints(ca=is_ca, path_length=None),
    )
    extensions: List[x509.Extension] = [
        authority_key_extension,
        subject_key_extension,
        basic_constraints_extension,
    ]

    # Explicit alt names come first (all as DNS names), then the SANs carried by the CSR.
    subject_alt_names: List[x509.GeneralName] = [x509.DNSName(name) for name in alt_names or []]
    with suppress(x509.ExtensionNotFound):
        csr_san = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
        # Only these three general-name kinds are copied over, in this order.
        for general_name_type in (x509.DNSName, x509.IPAddress, x509.RegisteredID):
            subject_alt_names.extend(
                [general_name_type(item) for item in csr_san.get_values_for_type(general_name_type)]
            )

    if subject_alt_names:
        extensions.append(
            x509.Extension(
                oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
                critical=False,
                value=x509.SubjectAlternativeName(subject_alt_names),
            )
        )

    if is_ca:
        ca_key_usage = x509.KeyUsage(
            digital_signature=False,
            content_commitment=False,
            key_encipherment=False,
            data_encipherment=False,
            key_agreement=False,
            key_cert_sign=True,
            crl_sign=True,
            encipher_only=False,
            decipher_only=False,
        )
        extensions.append(
            x509.Extension(oid=ExtensionOID.KEY_USAGE, critical=True, value=ca_key_usage)
        )

    # OIDs set above take precedence over whatever the CSR requests for them.
    managed_oids = {managed.oid for managed in extensions}
    for requested in csr.extensions:
        if requested.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
            # SANs were already merged into the extension built above.
            continue
        if requested.oid not in managed_oids:
            extensions.append(requested)
        else:
            logger.warning("Extension %s is managed by the TLS provider, ignoring.", requested.oid)

    return extensions
def generate_private_key(
    password: Optional[bytes] = None,
    key_size: int = 2048,
    public_exponent: int = 65537,
) -> bytes:
    """Generate a PEM-encoded RSA private key.

    Args:
        password (bytes): Optional password used to encrypt the serialized key.
            A falsy value (``None`` or ``b""``) produces an unencrypted key.
        key_size (int): RSA key size in bits.
        public_exponent (int): RSA public exponent.

    Returns:
        bytes: Private key, PEM-encoded in the ``TraditionalOpenSSL`` format.
    """
    private_key = rsa.generate_private_key(
        public_exponent=public_exponent,
        key_size=key_size,
    )
    # Only wrap the key in encryption when a non-empty password was supplied.
    if password:
        encryption_algorithm = serialization.BestAvailableEncryption(password)
    else:
        encryption_algorithm = serialization.NoEncryption()
    return private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=encryption_algorithm,
    )
Always leave to "True" when the CSR is used to request certificates + using the tls-certificates relation. + organization (str): Name of organization. + email_address (str): Email address. + country_name (str): Country Name. + state_or_province_name (str): State or Province Name. + locality_name (str): Locality Name. + private_key_password (bytes): Private key password + sans (list): Use sans_dns - this will be deprecated in a future release + List of DNS subject alternative names (keeping it for now for backward compatibility) + sans_oid (list): List of registered ID SANs + sans_dns (list): List of DNS subject alternative names (similar to the arg: sans) + sans_ip (list): List of IP subject alternative names + additional_critical_extensions (list): List of critical additional extension objects. + Object must be a x509 ExtensionType. + + Returns: + bytes: CSR + """ + signing_key = serialization.load_pem_private_key(private_key, password=private_key_password) + subject_name = [x509.NameAttribute(x509.NameOID.COMMON_NAME, subject)] + if add_unique_id_to_subject_name: + unique_identifier = uuid.uuid4() + subject_name.append( + x509.NameAttribute(x509.NameOID.X500_UNIQUE_IDENTIFIER, str(unique_identifier)) + ) + if organization: + subject_name.append(x509.NameAttribute(x509.NameOID.ORGANIZATION_NAME, organization)) + if email_address: + subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) + if country_name: + subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) + csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) + + _sans: List[x509.GeneralName] = [] + if sans_oid: + _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in 
def csr_matches_certificate(csr: str, cert: str) -> bool:
    """Check whether a certificate carries the same public key as a CSR.

    The comparison is done on the PEM-encoded SubjectPublicKeyInfo of both
    objects, so it does not depend on the key algorithm. (The previous extra
    comparison of ``public_numbers().n`` only exists on RSA keys and raised
    ``AttributeError`` for any other key type; it was also redundant, because
    identical SubjectPublicKeyInfo bytes already imply identical key material.)

    Args:
        csr (str): Certificate Signing Request as a PEM string
        cert (str): Certificate as a PEM string

    Returns:
        bool: True when both hold the same public key, False when they differ
            or when either input cannot be loaded.
    """
    try:
        csr_object = x509.load_pem_x509_csr(csr.encode("utf-8"))
        cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8"))
        csr_public_key = csr_object.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        cert_public_key = cert_object.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
    except ValueError:
        logger.warning("Could not load certificate or CSR.")
        return False
    return csr_public_key == cert_public_key
def __init__(self, charm: CharmBase, relationship_name: str):
    """Attach the provider side of the relation to a charm.

    Registers ``_on_relation_changed`` as the observer of the relation's
    ``relation_changed`` event, then stores the charm and the relation name
    on the instance.

    Args:
        charm: Charm object
        relationship_name: Juju relation name
    """
    super().__init__(charm, relationship_name)
    relation_changed = charm.on[relationship_name].relation_changed
    self.framework.observe(relation_changed, self._on_relation_changed)
    self.charm = charm
    self.relationship_name = relationship_name
def _remove_certificate(
    self,
    relation_id: int,
    certificate: Optional[str] = None,
    certificate_signing_request: Optional[str] = None,
) -> None:
    """Remove certificate entries from a relation, matched by certificate or CSR.

    Every entry whose ``certificate`` equals ``certificate`` or whose
    ``certificate_signing_request`` equals ``certificate_signing_request`` is
    dropped, and the remaining list is written back to the application databag.

    Args:
        relation_id (int): Relation id
        certificate (str): Certificate (optional)
        certificate_signing_request: Certificate signing request (optional)

    Raises:
        RuntimeError: If the relation with the given id does not exist.

    Returns:
        None
    """
    relation = self.model.get_relation(
        relation_name=self.relationship_name,
        relation_id=relation_id,
    )
    if not relation:
        raise RuntimeError(
            f"Relation {self.relationship_name} with relation id {relation_id} does not exist"
        )
    provider_relation_data = self._load_app_relation_data(relation)
    provider_certificates = provider_relation_data.get("certificates", [])

    def _is_targeted(entry: dict) -> bool:
        """Tell whether an entry matches the requested certificate or CSR."""
        if certificate and entry["certificate"] == certificate:
            return True
        return bool(
            certificate_signing_request
            and entry["certificate_signing_request"] == certificate_signing_request
        )

    # Build a fresh list instead of calling list.remove() while iterating:
    # removing in place skipped the element following each match and raised
    # ValueError when one entry matched both the certificate and the CSR.
    remaining = [entry for entry in provider_certificates if not _is_targeted(entry)]
    relation.data[self.model.app]["certificates"] = json.dumps(remaining)
def remove_certificate(self, certificate: str) -> None:
    """Drop a certificate from every relation of this provider.

    Args:
        certificate (str): TLS Certificate

    Raises:
        RuntimeError: If there is no relation under ``relationship_name``.

    Returns:
        None
    """
    relations = self.model.relations[self.relationship_name]
    if relations:
        for relation in relations:
            self._remove_certificate(certificate=certificate, relation_id=relation.id)
        return
    raise RuntimeError(f"Relation {self.relationship_name} does not exist")
def _on_relation_changed(self, event: RelationChangedEvent) -> None:
    """React to a change in a requirer unit's relation data.

    On the leader, and only when the unit databag validates against
    ``REQUIRER_JSON_SCHEMA``:
    - emits ``certificate_creation_request`` for every requirer CSR that has
      no matching entry among this relation's provider certificates;
    - then delegates to ``_revoke_certificates_for_which_no_csr_exists`` to
      handle certificates whose CSR is no longer requested.

    Args:
        event: Juju event

    Returns:
        None
    """
    remote_unit = event.unit
    if remote_unit is None:
        logger.error("Relation_changed event does not have a unit.")
        return
    if not self.model.unit.is_leader():
        return
    if not _relation_data_is_valid(event.relation, remote_unit, REQUIRER_JSON_SCHEMA):
        logger.debug("Relation data did not pass JSON Schema validation")
        return

    issued_certificates = self.get_provider_certificates(relation_id=event.relation.id)
    requested_csrs = self.get_requirer_csrs(relation_id=event.relation.id)
    already_served = [issued.csr for issued in issued_certificates]
    for request in requested_csrs:
        if request.csr in already_served:
            continue
        self.on.certificate_creation_request.emit(
            certificate_signing_request=request.csr,
            relation_id=request.relation_id,
            is_ca=request.is_ca,
        )
    self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id)
def get_outstanding_certificate_requests(
    self, relation_id: Optional[int] = None
) -> List[RequirerCSR]:
    """Return the requirer CSRs that do not have an issued certificate yet.

    Args:
        relation_id (int): Relation id

    Returns:
        list: List of RequirerCSR objects.
    """
    return [
        pending
        for pending in self.get_requirer_csrs(relation_id=relation_id)
        if not self.certificate_issued_for_csr(
            app_name=pending.application_name,
            csr=pending.csr,
            relation_id=relation_id,
        )
    ]
def certificate_issued_for_csr(
    self, app_name: str, csr: str, relation_id: Optional[int]
) -> bool:
    """Tell whether a valid certificate was already issued for a CSR.

    The first issued (non revoked) certificate recorded for the same CSR and
    application decides the answer: the result is whether that certificate
    actually matches the CSR.

    Args:
        app_name (str): Application name that the CSR belongs to.
        csr (str): Certificate Signing Request.
        relation_id (Optional[int]): Relation ID

    Returns:
        bool: True/False depending on whether a certificate has been issued for the given CSR.
    """
    for candidate in self.get_issued_certificates(relation_id=relation_id):
        if candidate.csr != csr or candidate.application_name != app_name:
            continue
        return csr_matches_certificate(csr, candidate.certificate)
    return False
def get_provider_certificates(self) -> List[ProviderCertificate]:
    """Return the certificates published in the provider's application databag.

    Entries without a certificate, with a certificate that cannot be parsed,
    or without a CSR are skipped (a message is logged for each). The expiry
    notification time of every returned certificate is computed from the
    certificate validity window, the provider's recommendation and this
    requirer's own ``expiry_notification_time``.

    Returns:
        List: List of ProviderCertificate objects; empty when the relation or
            its remote application is missing.
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        logger.debug("No relation: %s", self.relationship_name)
        return []
    if not relation.app:
        logger.debug("No remote app in relation: %s", self.relationship_name)
        return []

    collected: List[ProviderCertificate] = []
    databag = _load_relation_data(relation.data[relation.app])
    for entry in databag.get("certificates", []):
        pem = entry.get("certificate")
        if not pem:
            logger.warning("No certificate found in relation data - Skipping")
            continue
        try:
            parsed = x509.load_pem_x509_certificate(data=pem.encode())
        except ValueError as e:
            logger.error("Could not load certificate - Skipping: %s", e)
            continue
        ca = entry.get("ca")
        chain = entry.get("chain", [])
        csr = entry.get("certificate_signing_request")
        provider_recommendation = entry.get("recommended_expiry_notification_time")
        expiry_time = parsed.not_valid_after_utc
        notify_at = calculate_expiry_notification_time(
            validity_start_time=parsed.not_valid_before_utc,
            expiry_time=expiry_time,
            provider_recommended_notification_time=provider_recommendation,
            requirer_recommended_notification_time=self.expiry_notification_time,
        )
        if not csr:
            logger.warning("No CSR found in relation data - Skipping")
            continue
        collected.append(
            ProviderCertificate(
                relation_id=relation.id,
                application_name=relation.app.name,
                csr=csr,
                certificate=pem,
                ca=ca,
                chain=chain,
                revoked=entry.get("revoked", False),
                expiry_time=expiry_time,
                expiry_notification_time=notify_at,
            )
        )
    return collected
def _remove_requirer_csr_from_relation_data(self, csr: str) -> None:
    """Remove every occurrence of a CSR from this unit's relation data.

    Args:
        csr (str): Certificate signing request

    Raises:
        RuntimeError: If the relation does not exist.

    Returns:
        None
    """
    relation = self.model.get_relation(self.relationship_name)
    if not relation:
        raise RuntimeError(
            f"Relation {self.relationship_name} does not exist - "
            f"The certificate request can't be completed"
        )
    if not self.get_requirer_csrs():
        logger.info("No CSRs in relation data - Doing nothing")
        return
    requirer_relation_data = _load_relation_data(relation.data[self.model.unit])
    existing_relation_data = requirer_relation_data.get("certificate_signing_requests", [])
    # Filter into a new list rather than removing while iterating, which
    # skipped the entry following each removed one (e.g. a duplicated CSR).
    new_relation_data = [
        requirer_csr
        for requirer_csr in existing_relation_data
        if requirer_csr["certificate_signing_request"] != csr
    ]
    relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(
        new_relation_data
    )
def get_assigned_certificates(self) -> List[ProviderCertificate]:
    """Return the provider certificates matching this unit's fulfilled CSRs.

    Returns:
        List: List[ProviderCertificate]
    """
    fulfilled_csrs = self.get_certificate_signing_requests(fulfilled_only=True)
    return [
        certificate
        for fulfilled in fulfilled_csrs
        if (certificate := self._find_certificate_in_relation_data(fulfilled.csr))
    ]
def _on_relation_changed(self, event: RelationChangedEvent) -> None:
    """React to a change in the provider's relation data.

    Only provider certificates whose CSR was requested by this unit are
    considered. For each of them:
    - revoked: the matching secret (if any) is removed and
      ``certificate_invalidated`` is emitted with reason "revoked";
    - otherwise: the matching secret is updated, or created when missing,
      with the certificate and its next expiry time, and
      ``certificate_available`` is emitted.

    Nothing happens when the event has no remote app or the provider databag
    fails validation against ``PROVIDER_JSON_SCHEMA``.

    Args:
        event: Juju event

    Returns:
        None
    """
    if not event.app:
        logger.warning("No remote app in relation - Skipping")
        return
    if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA):
        logger.debug("Relation data did not pass JSON Schema validation")
        return

    provider_certificates = self.get_provider_certificates()
    own_csrs = [request.csr for request in self.get_requirer_csrs()]
    for certificate in provider_certificates:
        if certificate.csr not in own_csrs:
            continue
        secret_label = f"{LIBID}-{certificate.csr}"

        if certificate.revoked:
            # A missing secret is fine here: there is simply nothing to clean up.
            with suppress(SecretNotFoundError):
                secret = self.model.get_secret(label=secret_label)
                secret.remove_all_revisions()
            self.on.certificate_invalidated.emit(
                reason="revoked",
                certificate=certificate.certificate,
                certificate_signing_request=certificate.csr,
                ca=certificate.ca,
                chain=certificate.chain,
            )
            continue

        try:
            secret = self.model.get_secret(label=secret_label)
            secret.set_content({"certificate": certificate.certificate})
            secret.set_info(
                expire=self._get_next_secret_expiry_time(certificate),
            )
        except SecretNotFoundError:
            logger.debug("Adding secret with label %s", secret_label)
            secret = self.charm.unit.add_secret(
                {"certificate": certificate.certificate},
                label=secret_label,
                expire=self._get_next_secret_expiry_time(certificate),
            )
        self.on.certificate_available.emit(
            certificate_signing_request=certificate.csr,
            certificate=certificate.certificate,
            ca=certificate.ca,
            chain=certificate.chain,
        )
+ + Args: + certificate: ProviderCertificate object + + Returns: + Optional[datetime]: None if the certificate expiry time cannot be read, + next expiry time otherwise. + """ + if not certificate.expiry_time or not certificate.expiry_notification_time: + return None + return _get_closest_future_time( + certificate.expiry_notification_time, + certificate.expiry_time, + ) + + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: + """Handle Relation Broken Event. + + Emitting `all_certificates_invalidated` from `relation-broken` rather + than `relation-departed` since certs are stored in app data. + + Args: + event: Juju event + + Returns: + None + """ + self.on.all_certificates_invalidated.emit() + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Handle Secret Expired Event. + + Loads the certificate from the secret, and will emit 1 of 2 + events. + + If the certificate is not yet expired, emits CertificateExpiringEvent + and updates the expiry time of the secret to the exact expiry time on + the certificate. + + If the certificate is expired, emits CertificateInvalidedEvent and + deletes the secret. + + Args: + event (SecretExpiredEvent): Juju event + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): + return + csr = event.secret.label[len(f"{LIBID}-") :] + provider_certificate = self._find_certificate_in_relation_data(csr) + if not provider_certificate: + # A secret expired but we did not find matching certificate. Cleaning up + event.secret.remove_all_revisions() + return + + if not provider_certificate.expiry_time: + # A secret expired but matching certificate is invalid. 
Cleaning up + event.secret.remove_all_revisions() + return + + if datetime.now(timezone.utc) < provider_certificate.expiry_time: + logger.warning("Certificate almost expired") + self.on.certificate_expiring.emit( + certificate=provider_certificate.certificate, + expiry=provider_certificate.expiry_time.isoformat(), + ) + event.secret.set_info( + expire=provider_certificate.expiry_time, + ) + else: + logger.warning("Certificate is expired") + self.on.certificate_invalidated.emit( + reason="expired", + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + self.request_certificate_revocation(provider_certificate.certificate.encode()) + event.secret.remove_all_revisions() + + def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: + """Return the certificate that match the given CSR.""" + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.csr != csr: + continue + return provider_certificate + return None diff --git a/pyproject.toml b/pyproject.toml index 78ac657e..b0d3eb84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ show_missing = true [tool.pytest.ini_options] minversion = "6.0" log_cli_level = "INFO" +markers = ["setup", "teardown"] # Formatting tools configuration [tool.isort] diff --git a/requirements.txt b/requirements.txt index 456929af..af2ddbf7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,10 @@ tenacity==8.2.3 # lib/charms/tempo_k8s/v1/charm_tracing.py opentelemetry-exporter-otlp-proto-http==1.21.0 -# lib/charms/tempo_k8s/v1/tracing.py requires pydantic; we have higher standards: +# lib/charms/tls_certificates_interface/v2/tls_certificates.py +jsonschema +cryptography +# lib/charms/tempo_k8s/v1/tracing.py pydantic>=2 # lib/charms/prometheus_k8s/v0/prometheus_scrape.py cosl diff --git a/scripts/tracegen.py b/scripts/tracegen.py index 
8b1e53f1..3505be48 100644 --- a/scripts/tracegen.py +++ b/scripts/tracegen.py @@ -1,9 +1,11 @@ +import os import time +from pathlib import Path +from typing import Any, Literal from opentelemetry import trace -from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( - OTLPSpanExporter, -) +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPExporter from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import ( @@ -12,16 +14,43 @@ ) -def emit_trace(endpoint: str, log_trace_to_console: bool = False): - span_exporter = OTLPSpanExporter( - endpoint=endpoint, - insecure=True, +def emit_trace( + endpoint: str, + log_trace_to_console: bool = False, + cert: Path = None, + protocol: Literal["grpc", "http", "ALL"] = "grpc", + nonce: Any = None +): + os.environ['OTEL_EXPORTER_OTLP_TRACES_CERTIFICATE'] = str(Path(cert).absolute()) if cert else "" + + if protocol == "grpc": + span_exporter = GRPCExporter( + endpoint=endpoint, + insecure=not cert, + ) + elif protocol == "http": + span_exporter = HTTPExporter( + endpoint=endpoint, + ) + else: # ALL + return (emit_trace(endpoint, log_trace_to_console, cert, "grpc", nonce=nonce) and + emit_trace(endpoint, log_trace_to_console, cert, "http", nonce=nonce)) + + return _export_trace(span_exporter, log_trace_to_console=log_trace_to_console, nonce=nonce) + + +def _export_trace(span_exporter, log_trace_to_console: bool = False, nonce: Any = None): + resource = Resource.create(attributes={ + "service.name": "tracegen", + "nonce": str(nonce) + } ) - resource = Resource.create(attributes={"service.name": "tracegen"}) provider = TracerProvider(resource=resource) + if log_trace_to_console: processor = BatchSpanProcessor(ConsoleSpanExporter()) provider.add_span_processor(processor) + span_processor = 
BatchSpanProcessor(span_exporter) provider.add_span_processor(span_processor) trace.set_tracer_provider(provider) @@ -33,8 +62,14 @@ def emit_trace(endpoint: str, log_trace_to_console: bool = False): with tracer.start_as_current_span("baz"): time.sleep(.1) + return span_exporter.force_flush() -if __name__ == '__main__': - import os - emit_trace(os.getenv("TEMPO", "http://127.0.0.1:8080")) +if __name__ == '__main__': + emit_trace( + endpoint=os.getenv("TRACEGEN_ENDPOINT", "http://127.0.0.1:8080"), + cert=os.getenv("TRACEGEN_CERT", None), + log_trace_to_console=os.getenv("TRACEGEN_VERBOSE", False), + protocol=os.getenv("TRACEGEN_PROTOCOL", "http"), + nonce=os.getenv("TRACEGEN_NONCE", "24") + ) diff --git a/src/charm.py b/src/charm.py index b22748d0..29a8ff29 100755 --- a/src/charm.py +++ b/src/charm.py @@ -7,6 +7,7 @@ import logging import re import socket +from pathlib import Path from typing import Optional, Tuple import charms.tempo_k8s.v1.tracing as tracing_v1 @@ -15,6 +16,7 @@ from charms.grafana_k8s.v0.grafana_source import GrafanaSourceProvider from charms.loki_k8s.v0.loki_push_api import LogProxyConsumer from charms.observability_libs.v0.kubernetes_service_patch import KubernetesServicePatch +from charms.observability_libs.v1.cert_handler import CertHandler from charms.prometheus_k8s.v0.prometheus_scrape import MetricsEndpointProvider from charms.tempo_k8s.v1.charm_tracing import trace_charm from charms.tempo_k8s.v2.tracing import ( @@ -33,7 +35,6 @@ ) from ops.main import main from ops.model import ActiveStatus, MaintenanceStatus, Relation, WaitingStatus - from tempo import Tempo logger = logging.getLogger(__name__) @@ -48,6 +49,7 @@ @trace_charm( tracing_endpoint="tempo_otlp_http_endpoint", + server_cert="server_cert", extra_types=(Tempo, TracingEndpointProvider, tracing_v1.TracingEndpointProvider), ) class TempoCharm(CharmBase): @@ -57,12 +59,32 @@ def __init__(self, *args): super().__init__(*args) self.tempo = tempo = Tempo( 
self.unit.get_container("tempo"), + external_host=self.hostname, # we need otlp_http receiver for charm_tracing enable_receivers=["otlp_http"], ) + + # TODO: + # ingress route provisioning a separate TCP ingress for each receiver if GRPC doesn't work directly + self.ingress = TraefikRouteRequirer(self, self.model.get_relation("ingress"), "ingress") # type: ignore + + self.cert_handler = CertHandler( + self, + key="tempo-server-cert", + sans=[self.hostname], + ) + # configure this tempo as a datasource in grafana self.grafana_source_provider = GrafanaSourceProvider( - self, source_type="tempo", source_port=str(tempo.tempo_server_port) + self, + source_type="tempo", + source_url=self._external_http_server_url, + refresh_event=[ + # refresh the source url when TLS config might be changing + self.on[self.cert_handler.certificates_relation_name].relation_changed, + # or when ingress changes + self.ingress.on.ready, + ], ) # # Patch the juju-created Kubernetes service to contain the right ports external_ports = tempo.get_external_ports(self.app.name) @@ -71,7 +93,7 @@ def __init__(self, *args): self._scraping = MetricsEndpointProvider( self, relation_name="metrics-endpoint", - jobs=[{"static_configs": [{"targets": [f"*:{tempo.tempo_server_port}"]}]}], + jobs=[{"static_configs": [{"targets": [f"*:{tempo.tempo_http_server_port}"]}]}], ) # Enable log forwarding for Loki and other charms that implement loki_push_api self._logging = LogProxyConsumer( @@ -80,18 +102,12 @@ def __init__(self, *args): self._grafana_dashboards = GrafanaDashboardProvider( self, relation_name="grafana-dashboard" ) - # Enable profiling over a relation with Parca - # self._profiling = ProfilingEndpointProvider( - # self, jobs=[{"static_configs": [{"targets": ["*:4080"]}]}] - # ) - - self._ingress = TraefikRouteRequirer(self, self.model.get_relation("ingress"), "ingress") # type: ignore - self._tracing = TracingEndpointProvider( - # TODO set internal_scheme based on whether TLS is enabled + + 
self.tracing = TracingEndpointProvider( self, - host=self.tempo.host, - external_url=self._ingress.external_host, - internal_scheme="http", + host=self.hostname, + external_url=self._external_url, + internal_scheme="https" if self.tls_available else "http", ) self.framework.observe( @@ -101,22 +117,89 @@ def __init__(self, *args): self.on["ingress"].relation_joined, self._on_ingress_relation_joined ) self.framework.observe(self.on.leader_elected, self._on_leader_elected) - self.framework.observe(self._ingress.on.ready, self._on_ingress_ready) + self.framework.observe(self.ingress.on.ready, self._on_ingress_ready) self.framework.observe(self.on.tempo_pebble_ready, self._on_tempo_pebble_ready) self.framework.observe( self.on.tempo_pebble_custom_notice, self._on_tempo_pebble_custom_notice ) self.framework.observe(self.on.update_status, self._on_update_status) - self.framework.observe(self._tracing.on.request, self._on_tracing_request) + self.framework.observe(self.tracing.on.request, self._on_tracing_request) self.framework.observe(self.on.tracing_relation_created, self._on_tracing_relation_created) self.framework.observe(self.on.tracing_relation_joined, self._on_tracing_relation_joined) self.framework.observe(self.on.tracing_relation_changed, self._on_tracing_relation_changed) self.framework.observe(self.on.collect_unit_status, self._on_collect_unit_status) self.framework.observe(self.on.list_receivers_action, self._on_list_receivers_action) + self.framework.observe(self.cert_handler.on.cert_changed, self._on_cert_handler_changed) + self.framework.observe(self.on.config_changed, self._on_config_changed) + + @property + def _external_http_server_url(self) -> str: + """External url of the http(s) server.""" + return f"{self._external_url}:{self.tempo.tempo_http_server_port}" + + @property + def _external_url(self) -> str: + """Return the external url.""" + if self.ingress.is_ready(): + ingress_url = f"{self.ingress.scheme}://{self.ingress.external_host}" + 
logger.debug("This unit's ingress URL: %s", ingress_url) + return ingress_url + + # If we do not have an ingress, then use the pod hostname. + # The reason to prefer this over the pod name (which is the actual + # hostname visible from the pod) or a K8s service, is that those + # are routable virtually exclusively inside the cluster (as they rely) + # on the cluster's DNS service, while the ip address is _sometimes_ + # routable from the outside, e.g., when deploying on MicroK8s on Linux. + return self._internal_url + + @property + def _internal_url(self) -> str: + scheme = "https" if self.tls_available else "http" + return f"{scheme}://{self.hostname}" + + @property + def tls_available(self) -> bool: + """Return True if tls is enabled and the necessary certs are found.""" + return ( + self.cert_handler.enabled + and (self.cert_handler.server_cert is not None) + and (self.cert_handler.private_key is not None) + and (self.cert_handler.ca_cert is not None) + ) + + def _on_cert_handler_changed(self, _): + was_ready = self.tempo.tls_ready + + if self.tls_available: + logger.debug("enabling TLS") + self.tempo.configure_tls( + cert=self.cert_handler.server_cert, # type: ignore + key=self.cert_handler.private_key, # type: ignore + ca=self.cert_handler.ca_cert, # type: ignore + ) + else: + logger.debug("disabling TLS") + self.tempo.clear_tls_config() + + if was_ready != self.tempo.tls_ready: + # tls readiness change means config change. + self.tempo.update_config(self._requested_receivers()) + # sync scheme change with traefik and related consumers + self._configure_ingress(_) + self.tempo.restart() + + # sync the server cert with the charm container. 
+ # technically, because of charm tracing, this will be called first thing on each event + self._update_server_cert() + + # update relations to reflect the new certificate + self._update_tracing_v1_relations() + self._update_tracing_v2_relations() def _is_legacy_v1_relation(self, relation): - if self._tracing.is_v2(relation): + if self.tracing.is_v2(relation): return False juju_keys = {"egress-subnets", "ingress-address", "private-address"} @@ -133,11 +216,11 @@ def _configure_ingress(self, _) -> None: if not self.unit.is_leader(): return - if self._ingress.is_ready(): - self._ingress.submit_to_traefik( + if self.ingress.is_ready(): + self.ingress.submit_to_traefik( self._ingress_config, static=self._static_ingress_config ) - if self._ingress.external_host: + if self.ingress.external_host: self._update_tracing_v1_relations() self._update_tracing_v2_relations() @@ -152,18 +235,18 @@ def _on_tracing_request(self, e: RequestEvent): self._update_tracing_v2_relations() def _on_tracing_relation_created(self, e: RelationEvent): - if not self._tracing.is_v2(e.relation): + if not self.tracing.is_v2(e.relation): self._publish_v1_data(e.relation) # if this is the first legacy relation we get, we need to update ALL other relations # as we might need to add all legacy protocols to the mix self._update_tracing_v2_relations() def _on_tracing_relation_joined(self, e: RelationEvent): - if not self._tracing.is_v2(e.relation): + if not self.tracing.is_v2(e.relation): self._publish_v1_data(e.relation) def _on_tracing_relation_changed(self, e: RelationEvent): - if not self._tracing.is_v2(e.relation): + if not self.tracing.is_v2(e.relation): self._publish_v1_data(e.relation) def _on_ingress_relation_created(self, e: RelationEvent): @@ -176,9 +259,19 @@ def _on_leader_elected(self, e: HookEvent): # as traefik_route goes through app data, we need to take lead of traefik_route if our leader dies. 
self._configure_ingress(e) + def _on_config_changed(self, _): + # check if certificate files haven't disappeared and recreate them if needed + if self.tls_available and not self.tempo.tls_ready: + logger.debug("enabling TLS") + self.tempo.configure_tls( + cert=self.cert_handler.server_cert, # type: ignore + key=self.cert_handler.private_key, # type: ignore + ca=self.cert_handler.ca_cert, # type: ignore + ) + def _update_tracing_v1_relations(self): - for relation in self.model.relations[self._tracing._relation_name]: - if not self._tracing.is_v2(relation): + for relation in self.model.relations[self.tracing._relation_name]: + if not self.tracing.is_v2(relation): self._publish_v1_data(relation) def _publish_v1_data(self, relation: Relation): @@ -196,9 +289,11 @@ def _publish_v1_data(self, relation: Relation): tracing_v1.Ingester(protocol=p, port=self.tempo.receiver_ports[p]) for p in LEGACY_RECEIVER_PROTOCOLS ] - tracing_v1.TracingProviderAppData(host=self.tempo.host, ingesters=receivers).dump( - relation.data[self.app] - ) + # this should be behind a leader guard + if self.unit.is_leader(): + tracing_v1.TracingProviderAppData(host=self.hostname, ingesters=receivers).dump( + relation.data[self.app] + ) def _update_tracing_v2_relations(self): tracing_relations = self.model.relations["tracing"] @@ -210,9 +305,10 @@ def _update_tracing_v2_relations(self): requested_receivers = self._requested_receivers() # publish requested protocols to all v2 relations - self._tracing.publish_receivers( - [(p, self.tempo.receiver_ports[p]) for p in requested_receivers] - ) + if self.unit.is_leader(): + self.tracing.publish_receivers( + [(p, self.tempo.receiver_ports[p]) for p in requested_receivers] + ) self._restart_if_receivers_changed(requested_receivers) @@ -231,7 +327,7 @@ def _restart_if_receivers_changed(self, requested_receivers): def _requested_receivers(self) -> Tuple[ReceiverProtocol, ...]: """List what receivers we should activate, based on the active tracing relations.""" 
# we start with the sum of the requested endpoints from the v2 requirers - requested_protocols = set(self._tracing.requested_protocols()) + requested_protocols = set(self.tracing.requested_protocols()) # if we have any v0/v1 requirer, we'll need to activate all supported legacy endpoints # and publish them too (only to v1 requirers). @@ -319,13 +415,28 @@ def _get_version(self) -> Optional[str]: return return version + def server_cert(self): + """For charm tracing.""" + self._update_server_cert() + return self.tempo.server_cert_path + + def _update_server_cert(self): + """Server certificate for charm tracing tls, if tls is enabled.""" + server_cert = Path(self.tempo.server_cert_path) + if self.tls_available: + if not server_cert.exists(): + server_cert.parent.mkdir(parents=True, exist_ok=True) + if self.cert_handler.server_cert: + server_cert.write_text(self.cert_handler.server_cert) + else: # tls unavailable: delete local cert + server_cert.unlink(missing_ok=True) + def tempo_otlp_http_endpoint(self) -> Optional[str]: """Endpoint at which the charm tracing information will be forwarded.""" # the charm container and the tempo workload container have apparently the same # IP, so we can talk to tempo at localhost. 
- # TODO switch to HTTPS once SSL support is added if self.tempo.is_ready(): - return f"http://localhost:{self.tempo.receiver_ports['otlp_http']}" + return f"{self._internal_url}:{self.tempo.receiver_ports['otlp_http']}" return None @@ -345,9 +456,9 @@ def hostname(self) -> str: def _on_list_receivers_action(self, event: ops.ActionEvent): res = {} for receiver in self._requested_receivers(): - res[ - receiver.replace("_", "-") - ] = f"{self._ingress.external_host or self.tempo.url}/{receiver}" + res[receiver.replace("_", "-")] = ( + f"{self.ingress.external_host or self.tempo.url}/{receiver}" + ) event.set_results(res) @property @@ -362,43 +473,29 @@ def _static_ingress_config(self) -> dict: @property def _ingress_config(self) -> dict: """Build a raw ingress configuration for Traefik.""" - tcp_routers = {} - tcp_services = {} http_routers = {} http_services = {} for protocol, port in self.tempo.all_ports.items(): sanitized_protocol = protocol.replace("_", "-") - if sanitized_protocol.endswith("grpc"): - # grpc handling - tcp_routers[ - f"juju-{self.model.name}-{self.model.app.name}-{sanitized_protocol}" - ] = { - "entryPoints": [sanitized_protocol], - "service": f"juju-{self.model.name}-{self.model.app.name}-service-{sanitized_protocol}", - # TODO better matcher - "rule": "ClientIP(`0.0.0.0/0`)", - } - tcp_services[ + http_routers[f"juju-{self.model.name}-{self.model.app.name}-{sanitized_protocol}"] = { + "entryPoints": [sanitized_protocol], + "service": f"juju-{self.model.name}-{self.model.app.name}-service-{sanitized_protocol}", + # TODO better matcher + "rule": "ClientIP(`0.0.0.0/0`)", + } + if sanitized_protocol.endswith("grpc") and not self.tls_available: + # to send traces to unsecured GRPC endpoints, we need h2c + # see https://doc.traefik.io/traefik/v2.0/user-guides/grpc/#with-http-h2c + http_services[ f"juju-{self.model.name}-{self.model.app.name}-service-{sanitized_protocol}" - ] = {"loadBalancer": {"servers": [{"address": f"{self.hostname}:{port}"}]}} + 
] = {"loadBalancer": {"servers": [{"url": f"h2c://{self.hostname}:{port}"}]}} else: - # it's a http protocol, so we use a http section of the dynamic configuration - http_routers[ - f"juju-{self.model.name}-{self.model.app.name}-{sanitized_protocol}" - ] = { - "entryPoints": [sanitized_protocol], - "service": f"juju-{self.model.name}-{self.model.app.name}-service-{sanitized_protocol}", - # TODO better matcher - "rule": "ClientIP(`0.0.0.0/0`)", - } + # anything else, including secured GRPC, can use _internal_url + # ref https://doc.traefik.io/traefik/v2.0/user-guides/grpc/#with-https http_services[ f"juju-{self.model.name}-{self.model.app.name}-service-{sanitized_protocol}" - ] = {"loadBalancer": {"servers": [{"url": f"http://{self.hostname}:{port}"}]}} + ] = {"loadBalancer": {"servers": [{"url": f"{self._internal_url}:{port}"}]}} return { - "tcp": { - "routers": tcp_routers, - "services": tcp_services, - }, "http": { "routers": http_routers, "services": http_services, diff --git a/src/tempo.py b/src/tempo.py index 0de2e7f5..84a08df7 100644 --- a/src/tempo.py +++ b/src/tempo.py @@ -5,6 +5,7 @@ """Tempo workload configuration and client.""" import logging import socket +from pathlib import Path from subprocess import CalledProcessError, getoutput from typing import Dict, List, Optional, Sequence, Tuple @@ -21,13 +22,26 @@ class Tempo: """Class representing the Tempo client workload configuration.""" config_path = "/etc/tempo/tempo.yaml" + + # cert path on charm container + server_cert_path = "/usr/local/share/ca-certificates/ca.crt" + + # cert paths on tempo container + tls_cert_path = "/etc/tempo/tls/server.crt" + tls_key_path = "/etc/tempo/tls/server.key" + tls_ca_path = "/usr/local/share/ca-certificates/ca.crt" + + _tls_min_version = "" + # cfr https://grafana.com/docs/enterprise-traces/latest/configure/reference/#supported-contents-and-default-values + # "VersionTLS12" + wal_path = "/etc/tempo/tempo_wal" log_path = "/var/log/tempo.log" tempo_ready_notice_key = 
"canonical.com/tempo/workload-ready" server_ports = { "tempo_http": 3200, - # "tempo_grpc": 9096, # default grpc listen port is 9095, but that conflicts with promtail. + "tempo_grpc": 9096, # default grpc listen port is 9095, but that conflicts with promtail. } receiver_ports: Dict[ReceiverProtocol, int] = { @@ -48,19 +62,26 @@ class Tempo: def __init__( self, container: ops.Container, - local_host: str = "0.0.0.0", + external_host: Optional[str] = None, enable_receivers: Optional[Sequence[ReceiverProtocol]] = None, ): # ports source: https://github.com/grafana/tempo/blob/main/example/docker-compose/local/docker-compose.yaml - self._local_hostname = local_host + + # fqdn, if an ingress is not available, else the ingress address. + self._external_hostname = external_host or socket.getfqdn() self.container = container self.enabled_receivers = enable_receivers or [] @property - def tempo_server_port(self) -> int: + def tempo_http_server_port(self) -> int: """Return the receiver port for the built-in tempo_http protocol.""" return self.server_ports["tempo_http"] + @property + def tempo_grpc_server_port(self) -> int: + """Return the receiver port for the built-in tempo_http protocol.""" + return self.server_ports["tempo_grpc"] + def get_external_ports(self, service_name_prefix: str) -> List[Tuple[str, int, int]]: """List of service names and port mappings for the kubernetes service patch. 
@@ -77,15 +98,11 @@ def get_external_ports(self, service_name_prefix: str) -> List[Tuple[str, int, i for service_name in all_ports ] - @property - def host(self) -> str: - """Hostname at which tempo is running.""" - return socket.getfqdn() - @property def url(self) -> str: """Base url at which the tempo server is locally reachable over http.""" - return f"http://{self.host}" + scheme = "https" if self.tls_ready else "http" + return f"{scheme}://{self._external_hostname}" def plan(self): """Update pebble plan and start the tempo-ready service.""" @@ -161,17 +178,59 @@ def get_current_config(self) -> Optional[dict]: except ops.pebble.PathError: return None + def configure_tls(self, *, cert: str, key: str, ca: str): + """Push cert, key and CA to the tempo container.""" + # we save the cacert in the charm container too (for notices) + Path(self.server_cert_path).write_text(ca) + + self.container.push(self.tls_cert_path, cert, make_dirs=True) + self.container.push(self.tls_key_path, key, make_dirs=True) + self.container.push(self.tls_ca_path, ca, make_dirs=True) + self.container.exec(["update-ca-certificates"]) + + def clear_tls_config(self): + """Remove cert, key and CA files from the tempo container.""" + self.container.remove_path(self.tls_cert_path, recursive=True) + self.container.remove_path(self.tls_key_path, recursive=True) + self.container.remove_path(self.tls_ca_path, recursive=True) + + @property + def tls_ready(self) -> bool: + """Whether cert, key, and ca paths are found on disk and Tempo is ready to use tls.""" + if not self.container.can_connect(): + return False + return all( + self.container.exists(tls_path) + for tls_path in (self.tls_cert_path, self.tls_key_path, self.tls_ca_path) + ) + + def _build_server_config(self): + server_config = { + "http_listen_port": self.tempo_http_server_port, + # we need to specify a grpc server port even if we're not using the grpc server, + # otherwise it will default to 9595 and make promtail bork + 
"grpc_listen_port": self.tempo_grpc_server_port, + } + if self.tls_ready: + for cfg in ("http_tls_config", "grpc_tls_config"): + server_config[cfg] = { # type: ignore + "cert_file": str(self.tls_cert_path), + "key_file": str(self.tls_key_path), + "client_ca_file": str(self.tls_ca_path), + "client_auth_type": "VerifyClientCertIfGiven", + } + server_config["tls_min_version"] = self._tls_min_version # type: ignore + + return server_config + def generate_config(self, receivers: Sequence[ReceiverProtocol]) -> dict: """Generate the Tempo configuration. Only activate the provided receivers. """ - return { + config = { "auth_enabled": False, - "server": { - "http_listen_port": self.tempo_server_port, - # "grpc_listen_port": self.receiver_ports["tempo_grpc"], - }, + "server": self._build_server_config(), # more configuration information can be found at # https://github.com/open-telemetry/opentelemetry-collector/tree/overlord/receiver "distributor": {"receivers": self._build_receivers_config(receivers)}, @@ -215,6 +274,28 @@ def generate_config(self, receivers: Sequence[ReceiverProtocol]) -> dict: }, } + if self.tls_ready: + # cfr: + # https://grafana.com/docs/tempo/latest/configuration/network/tls/#client-configuration + tls_config = { + "tls_enabled": True, + "tls_cert_path": self.tls_cert_path, + "tls_key_path": self.tls_key_path, + "tls_ca_path": self.tls_ca_path, + # try with fqdn? + "tls_server_name": self._external_hostname, + } + config["ingester_client"] = {"grpc_client_config": tls_config} + config["metrics_generator_client"] = {"grpc_client_config": tls_config} + + # docs say it's `querier.query-frontend` but tempo complains about that + config["querier"] = {"frontend_worker": {"grpc_client_config": tls_config}} + + # this is not an error. 
+ config["memberlist"] = tls_config + + return config + @property def pebble_layer(self) -> Layer: """Generate the pebble layer for the Tempo container.""" @@ -234,14 +315,15 @@ def pebble_layer(self) -> Layer: @property def tempo_ready_layer(self) -> Layer: """Generate the pebble layer to fire the tempo-ready custom notice.""" + s = "s" if self.tls_ready else "" return Layer( { "services": { "tempo-ready": { "override": "replace", "summary": "Notify charm when tempo is ready", - "command": f"""watch -n 5 '[ $(wget -q -O- localhost:{self.tempo_server_port}/ready) = "ready" ] && - ( /charm/bin/pebble notify {self.tempo_ready_notice_key} ) || + "command": f"""watch -n 5 '[ $(wget -q -O- --no-check-certificate http{s}://localhost:{self.tempo_http_server_port}/ready) = "ready" ] && + ( /charm/bin/pebble notify {self.tempo_ready_notice_key} ) || ( echo "tempo not ready" )'""", "startup": "disabled", } @@ -251,10 +333,16 @@ def tempo_ready_layer(self) -> Layer: def is_ready(self): """Whether the tempo built-in readiness check reports 'ready'.""" + if self.tls_ready: + tls, s = f" --cacert {self.server_cert_path}", "s" + else: + tls = s = "" + + # cert is for fqdn/ingress, not for IP + cmd = f"curl{tls} http{s}://{self._external_hostname}:{self.tempo_http_server_port}/ready" + try: - out = getoutput( - f"curl http://{self._local_hostname}:{self.tempo_server_port}/ready" - ).split("\n")[-1] + out = getoutput(cmd).split("\n")[-1] except (CalledProcessError, IndexError): return False return out == "ready" @@ -268,31 +356,42 @@ def _build_receivers_config(self, receivers: Sequence[ReceiverProtocol]): # noq if not receivers_set: logger.warning("No receivers set. 
Tempo will be up but not functional.") + if self.tls_ready: + receiver_config = { + "tls": { + "ca_file": str(self.tls_ca_path), + "cert_file": str(self.tls_cert_path), + "key_file": str(self.tls_key_path), + "min_version": self._tls_min_version, + } + } + else: + receiver_config = None + config = {} - # TODO: how do we pass the ports into this config? if "zipkin" in receivers_set: - config["zipkin"] = None + config["zipkin"] = receiver_config if "opencensus" in receivers_set: - config["opencensus"] = None + config["opencensus"] = receiver_config otlp_config = {} if "otlp_http" in receivers_set: - otlp_config["http"] = None + otlp_config["http"] = receiver_config if "otlp_grpc" in receivers_set: - otlp_config["grpc"] = None + otlp_config["grpc"] = receiver_config if otlp_config: config["otlp"] = {"protocols": otlp_config} jaeger_config = {} if "jaeger_thrift_http" in receivers_set: - jaeger_config["thrift_http"] = None + jaeger_config["thrift_http"] = receiver_config if "jaeger_grpc" in receivers_set: - jaeger_config["grpc"] = None + jaeger_config["grpc"] = receiver_config if "jaeger_thrift_binary" in receivers_set: - jaeger_config["thrift_binary"] = None + jaeger_config["thrift_binary"] = receiver_config if "jaeger_thrift_compact" in receivers_set: - jaeger_config["thrift_compact"] = None + jaeger_config["thrift_compact"] = receiver_config if jaeger_config: config["jaeger"] = {"protocols": jaeger_config} diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d6511f5e..0ff45dc2 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -33,6 +33,8 @@ def tempo_oci_image(ops_test: OpsTest, tempo_metadata): def copy_charm_libs_into_tester_charm(ops_test): """Ensure the tester charm has the libraries it uses.""" libraries = [ + "observability_libs/v1/cert_handler.py", + "tls_certificates_interface/v3/tls_certificates.py", "tempo_k8s/v1/charm_tracing.py", "tempo_k8s/v1/tracing.py", "tempo_k8s/v2/tracing.py", diff --git 
a/tests/integration/helpers.py b/tests/integration/helpers.py new file mode 100644 index 00000000..ef9c6b02 --- /dev/null +++ b/tests/integration/helpers.py @@ -0,0 +1,159 @@ +import subprocess +from dataclasses import dataclass +from typing import Dict + +import yaml + +_JUJU_DATA_CACHE = {} +_JUJU_KEYS = ("egress-subnets", "ingress-address", "private-address") + + +def purge(data: dict): + for key in _JUJU_KEYS: + if key in data: + del data[key] + + +def get_unit_info(unit_name: str, model: str = None) -> dict: + """Return unit-info data structure. + + for example: + + traefik-k8s/0: + opened-ports: [] + charm: local:focal/traefik-k8s-1 + leader: true + relation-info: + - endpoint: ingress-per-unit + related-endpoint: ingress + application-data: + _supported_versions: '- v1' + related-units: + prometheus-k8s/0: + in-scope: true + data: + egress-subnets: 10.152.183.150/32 + ingress-address: 10.152.183.150 + private-address: 10.152.183.150 + provider-id: traefik-k8s-0 + address: 10.1.232.144 + """ + cmd = f"juju show-unit {unit_name}".split(" ") + if model: + cmd.insert(2, "-m") + cmd.insert(3, model) + + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE) + raw_data = proc.stdout.read().decode("utf-8").strip() + + data = yaml.safe_load(raw_data) if raw_data else None + + if not data: + raise ValueError( + f"no unit info could be grabbed for {unit_name}; " + f"are you sure it's a valid unit name?" 
+ f"cmd={' '.join(proc.args)}" + ) + + if unit_name not in data: + raise KeyError(unit_name, f"not in {data!r}") + + unit_data = data[unit_name] + _JUJU_DATA_CACHE[unit_name] = unit_data + return unit_data + + +def get_relation_by_endpoint(relations, local_endpoint, remote_endpoint, remote_obj): + matches = [ + r + for r in relations + if ( + (r["endpoint"] == local_endpoint and r["related-endpoint"] == remote_endpoint) + or (r["endpoint"] == remote_endpoint and r["related-endpoint"] == local_endpoint) + ) + and remote_obj in r["related-units"] + ] + if not matches: + raise ValueError( + f"no matches found with endpoint==" + f"{local_endpoint} " + f"in {remote_obj} (matches={matches})" + ) + if len(matches) > 1: + raise ValueError( + "multiple matches found with endpoint==" + f"{local_endpoint} " + f"in {remote_obj} (matches={matches})" + ) + return matches[0] + + +@dataclass +class UnitRelationData: + unit_name: str + endpoint: str + leader: bool + application_data: Dict[str, str] + unit_data: Dict[str, str] + + +def get_content( + obj: str, other_obj, include_default_juju_keys: bool = False, model: str = None +) -> UnitRelationData: + """Get the content of the databag of `obj`, as seen from `other_obj`.""" + unit_name, endpoint = obj.split(":") + other_unit_name, other_endpoint = other_obj.split(":") + + unit_data, app_data, leader = get_databags( + unit_name, endpoint, other_unit_name, other_endpoint, model + ) + + if not include_default_juju_keys: + purge(unit_data) + + return UnitRelationData(unit_name, endpoint, leader, app_data, unit_data) + + +def get_databags(local_unit, local_endpoint, remote_unit, remote_endpoint, model): + """Get the databags of local unit and its leadership status. + + Given a remote unit and the remote endpoint name. 
+ """ + local_data = get_unit_info(local_unit, model) + leader = local_data["leader"] + + data = get_unit_info(remote_unit, model) + relation_info = data.get("relation-info") + if not relation_info: + raise RuntimeError(f"{remote_unit} has no relations") + + raw_data = get_relation_by_endpoint(relation_info, local_endpoint, remote_endpoint, local_unit) + unit_data = raw_data["related-units"][local_unit]["data"] + app_data = raw_data["application-data"] + return unit_data, app_data, leader + + +@dataclass +class RelationData: + provider: UnitRelationData + requirer: UnitRelationData + + +def get_relation_data( + *, + provider_endpoint: str, + requirer_endpoint: str, + include_default_juju_keys: bool = False, + model: str = None, +): + """Get relation databags for a juju relation. + + >>> get_relation_data('prometheus/0:ingress', 'traefik/1:ingress-per-unit') + """ + provider_data = get_content( + provider_endpoint, requirer_endpoint, include_default_juju_keys, model + ) + requirer_data = get_content( + requirer_endpoint, provider_endpoint, include_default_juju_keys, model + ) + return RelationData(provider=provider_data, requirer=requirer_data) diff --git a/tests/integration/test_ingressed_tls.py b/tests/integration/test_ingressed_tls.py new file mode 100644 index 00000000..dc919916 --- /dev/null +++ b/tests/integration/test_ingressed_tls.py @@ -0,0 +1,180 @@ +import asyncio +import json +import logging +import random +import subprocess +import tempfile +from pathlib import Path + +import pytest +import requests +import yaml +from pytest_operator.plugin import OpsTest +from tempo import Tempo +from tenacity import retry, stop_after_attempt, wait_exponential + +from tests.integration.helpers import get_relation_data + +METADATA = yaml.safe_load(Path("./charmcraft.yaml").read_text()) +APP_NAME = "tempo" +SSC = "self-signed-certificates" +SSC_APP_NAME = "ssc" +TRAEFIK = "traefik-k8s" +TRAEFIK_APP_NAME = "trfk" +TRACEGEN_SCRIPT_PATH = Path() / "scripts" / "tracegen.py" 
+ +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="function") +def nonce(): + """Generate an integer nonce for easier trace querying.""" + return str(random.random())[2:] + + +@pytest.fixture(scope="function") +def server_cert(ops_test: OpsTest): + data = get_relation_data( + requirer_endpoint=f"{APP_NAME}/0:certificates", + provider_endpoint=f"{SSC_APP_NAME}/0:certificates", + model=ops_test.model.name, + ) + cert = json.loads(data.provider.application_data["certificates"])[0]["certificate"] + + with tempfile.NamedTemporaryFile() as f: + p = Path(f.name) + p.write_text(cert) + yield p + + +def get_traces(tempo_host: str, nonce, service_name="tracegen"): + req = requests.get( + "https://" + tempo_host + ":3200/api/search", + params={"service.name": service_name, "nonce": nonce}, + verify=False, + ) + assert req.status_code == 200 + return json.loads(req.text)["traces"] + + +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)) +async def get_traces_patiently(tempo_host, nonce): + assert get_traces(tempo_host, nonce=nonce) + + +async def get_tempo_host(ops_test: OpsTest): + status = await ops_test.model.get_status() + app = status["applications"][TRAEFIK_APP_NAME] + return app.public_address + + +async def emit_trace( + endpoint, ops_test: OpsTest, nonce, proto: str = "http", verbose=0, use_cert=False +): + """Use juju ssh to run tracegen from the tempo charm; to avoid any DNS issues.""" + cmd = ( + f"juju ssh -m {ops_test.model_name} {APP_NAME}/0 " + f"TRACEGEN_ENDPOINT={endpoint} " + f"TRACEGEN_VERBOSE={verbose} " + f"TRACEGEN_PROTOCOL={proto} " + f"TRACEGEN_CERT={Tempo.server_cert_path if use_cert else ''} " + f"TRACEGEN_NONCE={nonce} " + "python3 tracegen.py" + ) + + return subprocess.getoutput(cmd) + + +@pytest.mark.setup +@pytest.mark.abort_on_fail +async def test_build_and_deploy(ops_test: OpsTest): + tempo_charm = await ops_test.build_charm(".") + resources = { + "tempo-image": 
METADATA["resources"]["tempo-image"]["upstream-source"], + } + await asyncio.gather( + ops_test.model.deploy(tempo_charm, resources=resources, application_name=APP_NAME), + ops_test.model.deploy(SSC, application_name=SSC_APP_NAME), + ops_test.model.deploy(TRAEFIK, application_name=TRAEFIK_APP_NAME, channel="edge"), + ) + + await asyncio.gather( + ops_test.model.wait_for_idle( + apps=[APP_NAME, SSC_APP_NAME, TRAEFIK_APP_NAME], + status="active", + raise_on_blocked=True, + timeout=10000, + raise_on_error=False, + ), + ) + + +@pytest.mark.setup +@pytest.mark.abort_on_fail +async def test_push_tracegen_script_and_deps(ops_test: OpsTest): + await ops_test.juju("scp", TRACEGEN_SCRIPT_PATH, f"{APP_NAME}/0:tracegen.py") + await ops_test.juju( + "ssh", + f"{APP_NAME}/0", + "python3 -m pip install opentelemetry-exporter-otlp-proto-grpc opentelemetry-exporter-otlp-proto-http", + ) + + +@pytest.mark.setup +@pytest.mark.abort_on_fail +async def test_relate(ops_test: OpsTest): + await ops_test.model.integrate(APP_NAME + ":certificates", SSC_APP_NAME + ":certificates") + await ops_test.model.integrate( + SSC_APP_NAME + ":certificates", TRAEFIK_APP_NAME + ":certificates" + ) + await ops_test.model.integrate(APP_NAME + ":ingress", TRAEFIK_APP_NAME + ":traefik-route") + await ops_test.model.wait_for_idle( + apps=[APP_NAME, SSC_APP_NAME, TRAEFIK_APP_NAME], + status="active", + timeout=1000, + ) + + +@pytest.mark.abort_on_fail +async def test_verify_ingressed_trace_http_upgrades_to_tls(ops_test: OpsTest, nonce): + tempo_host = await get_tempo_host(ops_test) + # IF tempo is related to SSC + # WHEN we emit an http trace, **unsecured** + await emit_trace( + f"http://{tempo_host}:4318/v1/traces", nonce=nonce, ops_test=ops_test + ) # this should fail + # THEN we can verify it's not been ingested + assert get_traces_patiently(tempo_host, nonce=nonce) + + +@pytest.mark.abort_on_fail +async def test_verify_ingressed_trace_http_tls(ops_test: OpsTest, nonce, server_cert): + tempo_host = await 
get_tempo_host(ops_test) + await emit_trace( + f"https://{tempo_host}:4318/v1/traces", nonce=nonce, ops_test=ops_test, use_cert=True + ) + # THEN we can verify it's been ingested + assert get_traces_patiently(tempo_host, nonce=nonce) + + +@pytest.mark.abort_on_fail +async def test_verify_ingressed_traces_grpc_tls(ops_test: OpsTest, nonce, server_cert): + tempo_host = await get_tempo_host(ops_test) + await emit_trace( + f"{tempo_host}:4317", nonce=nonce, proto="grpc", ops_test=ops_test, use_cert=True + ) + # THEN we can verify it's been ingested + assert get_traces_patiently(tempo_host, nonce=nonce) + + +@pytest.mark.teardown +@pytest.mark.abort_on_fail +async def test_remove_relation(ops_test: OpsTest): + await ops_test.juju( + "remove-relation", APP_NAME + ":certificates", SSC_APP_NAME + ":certificates" + ) + await asyncio.gather( + ops_test.model.wait_for_idle( + apps=[APP_NAME], status="active", raise_on_blocked=True, timeout=1000 + ), + ) diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 94edb84d..d0bcac4f 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) +@pytest.mark.setup @pytest.mark.abort_on_fail async def test_build_and_deploy(ops_test: OpsTest): # Given a fresh build of the charm @@ -71,6 +72,7 @@ async def test_build_and_deploy(ops_test: OpsTest): assert ops_test.model.applications[APP_NAME].units[0].workload_status == "active" +@pytest.mark.setup @pytest.mark.abort_on_fail async def test_relate(ops_test: OpsTest): # given a deployed charm @@ -114,7 +116,6 @@ async def test_verify_traces_http(ops_test: OpsTest): assert found, f"There's no trace of charm exec traces in tempo. 
{json.dumps(traces, indent=2)}" -@pytest.mark.abort_on_fail async def test_verify_traces_grpc(ops_test: OpsTest): # the tester-grpc charm emits a single grpc trace in its common exit hook # we verify it's there @@ -145,6 +146,7 @@ async def test_verify_traces_grpc(ops_test: OpsTest): ), f"There's no trace of generated grpc traces in tempo. {json.dumps(traces, indent=2)}" +@pytest.mark.teardown @pytest.mark.abort_on_fail async def test_remove_relation(ops_test: OpsTest): # given related charms diff --git a/tests/integration/test_integration_legacy.py b/tests/integration/test_integration_legacy.py index 7bc2a338..2cb49225 100644 --- a/tests/integration/test_integration_legacy.py +++ b/tests/integration/test_integration_legacy.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) +@pytest.mark.setup @pytest.mark.abort_on_fail async def test_build_and_deploy(ops_test: OpsTest): # Given a fresh build of the charm @@ -56,6 +57,7 @@ async def test_build_and_deploy(ops_test: OpsTest): assert ops_test.model.applications[APP_NAME].units[0].workload_status == "active" +@pytest.mark.setup @pytest.mark.abort_on_fail async def test_relate(ops_test: OpsTest): # given a deployed charm @@ -69,7 +71,6 @@ async def test_relate(ops_test: OpsTest): ) -@pytest.mark.abort_on_fail async def test_verify_traces(ops_test: OpsTest): # given a relation between charms # when traces endpoint is queried @@ -100,6 +101,7 @@ async def test_verify_traces(ops_test: OpsTest): assert found, f"There's no trace of charm exec traces in tempo. 
{json.dumps(traces, indent=2)}" +@pytest.mark.teardown @pytest.mark.abort_on_fail async def test_remove_relation(ops_test: OpsTest): # given related charms diff --git a/tests/integration/test_tls.py b/tests/integration/test_tls.py new file mode 100644 index 00000000..86958e87 --- /dev/null +++ b/tests/integration/test_tls.py @@ -0,0 +1,170 @@ +import asyncio +import json +import logging +import random +import tempfile +from pathlib import Path +from subprocess import getoutput + +import pytest +import requests +import yaml +from pytest_operator.plugin import OpsTest +from tempo import Tempo +from tenacity import retry, stop_after_attempt, wait_exponential + +from tests.integration.helpers import get_relation_data + +METADATA = yaml.safe_load(Path("./charmcraft.yaml").read_text()) +APP_NAME = "tempo" +SSC = "self-signed-certificates" +SSC_APP_NAME = "ssc" +TRACEGEN_SCRIPT_PATH = Path() / "scripts" / "tracegen.py" +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="function") +def nonce(): + """Generate an integer nonce for easier trace querying.""" + return str(random.random())[2:] + + +def get_traces(tempo_host: str, nonce): + url = "https://" + tempo_host + ":3200/api/search" + req = requests.get( + url, + params={"q": f'{{ .nonce = "{nonce}" }}'}, + # it would fail to verify as the cert was issued for fqdn, not IP. 
+ verify=False, + ) + assert req.status_code == 200 + return json.loads(req.text)["traces"] + + +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10)) +async def get_traces_patiently(ops_test, nonce): + assert get_traces(await get_tempo_ip(ops_test), nonce=nonce) + + +async def get_tempo_ip(ops_test: OpsTest): + status = await ops_test.model.get_status() + app = status["applications"][APP_NAME] + return app.public_address + + +async def get_tempo_internal_host(ops_test: OpsTest): + return f"https://{APP_NAME}-0.{APP_NAME}-endpoints.{ops_test.model.name}.svc.cluster.local" + + +@pytest.fixture(scope="function") +def server_cert(ops_test: OpsTest): + data = get_relation_data( + requirer_endpoint=f"{APP_NAME}/0:certificates", + provider_endpoint=f"{SSC_APP_NAME}/0:certificates", + model=ops_test.model.name, + ) + cert = json.loads(data.provider.application_data["certificates"])[0]["certificate"] + + with tempfile.NamedTemporaryFile() as f: + p = Path(f.name) + p.write_text(cert) + yield p + + +async def emit_trace(ops_test: OpsTest, nonce, proto: str = "http", verbose=0, use_cert=False): + """Use juju ssh to run tracegen from the tempo charm; to avoid any DNS issues.""" + hostname = await get_tempo_internal_host(ops_test) + cmd = ( + f"juju ssh -m {ops_test.model_name} {APP_NAME}/0 " + f"TRACEGEN_ENDPOINT={hostname}:4318/v1/traces " + f"TRACEGEN_VERBOSE={verbose} " + f"TRACEGEN_PROTOCOL={proto} " + f"TRACEGEN_CERT={Tempo.server_cert_path if use_cert else ''} " + f"TRACEGEN_NONCE={nonce} " + "python3 tracegen.py" + ) + + return getoutput(cmd) + + +@pytest.mark.setup +@pytest.mark.abort_on_fail +async def test_build_and_deploy(ops_test: OpsTest): + tempo_charm = await ops_test.build_charm(".") + resources = { + "tempo-image": METADATA["resources"]["tempo-image"]["upstream-source"], + } + await asyncio.gather( + ops_test.model.deploy(tempo_charm, resources=resources, application_name=APP_NAME), + ops_test.model.deploy(SSC, 
application_name=SSC_APP_NAME), + ) + + await asyncio.gather( + ops_test.model.wait_for_idle( + apps=[APP_NAME, SSC_APP_NAME], + status="active", + raise_on_blocked=True, + timeout=10000, + raise_on_error=False, + ), + ) + + +@pytest.mark.setup +@pytest.mark.abort_on_fail +async def test_relate(ops_test: OpsTest): + await ops_test.model.integrate(APP_NAME + ":certificates", SSC_APP_NAME + ":certificates") + await ops_test.model.wait_for_idle( + apps=[APP_NAME, SSC_APP_NAME], + status="active", + timeout=1000, + ) + + +@pytest.mark.setup +@pytest.mark.abort_on_fail +async def test_push_tracegen_script_and_deps(ops_test: OpsTest): + await ops_test.juju("scp", TRACEGEN_SCRIPT_PATH, f"{APP_NAME}/0:tracegen.py") + await ops_test.juju( + "ssh", + f"{APP_NAME}/0", + "python3 -m pip install opentelemetry-exporter-otlp-proto-grpc opentelemetry-exporter-otlp-proto-http", + ) + + +async def test_verify_trace_http_no_tls_fails(ops_test: OpsTest, server_cert, nonce): + # IF tempo is related to SSC + # WHEN we emit an http trace, **unsecured** + await emit_trace(ops_test, nonce=nonce) # this should fail + # THEN we can verify it's not been ingested + tempo_ip = await get_tempo_ip(ops_test) + traces = get_traces(tempo_ip, nonce=nonce) + assert not traces + + +async def test_verify_trace_http_tls(ops_test: OpsTest, nonce, server_cert): + # WHEN we emit a trace secured with TLS + await emit_trace(ops_test, nonce=nonce, use_cert=True) + # THEN we can verify it's eventually ingested + await get_traces_patiently(ops_test, nonce) + + +@pytest.mark.xfail # expected to fail because in this context the grpc receiver is not enabled +async def test_verify_traces_grpc_tls(ops_test: OpsTest, nonce, server_cert): + # WHEN we emit a trace secured with TLS + await emit_trace(ops_test, nonce=nonce, verbose=1, proto="grpc", use_cert=True) + # THEN we can verify it's been ingested + await get_traces_patiently(ops_test, nonce) + + +@pytest.mark.teardown +@pytest.mark.abort_on_fail +async def 
test_remove_relation(ops_test: OpsTest): + await ops_test.juju( + "remove-relation", APP_NAME + ":certificates", SSC_APP_NAME + ":certificates" + ) + await asyncio.gather( + ops_test.model.wait_for_idle( + apps=[APP_NAME], status="active", raise_on_blocked=True, timeout=1000 + ), + ) diff --git a/tests/integration/tester-grpc/charmcraft.yaml b/tests/integration/tester-grpc/charmcraft.yaml index 7af5815c..234879ef 100644 --- a/tests/integration/tester-grpc/charmcraft.yaml +++ b/tests/integration/tester-grpc/charmcraft.yaml @@ -8,3 +8,10 @@ bases: run-on: - name: "ubuntu" channel: "22.04" +parts: + charm: + charm-binary-python-packages: + - "pydantic>=2" + - "cryptography" + - "jsonschema" + - "opentelemetry-exporter-otlp-proto-http==1.21.0" \ No newline at end of file diff --git a/tests/integration/tester/charmcraft.yaml b/tests/integration/tester/charmcraft.yaml index 7af5815c..234879ef 100644 --- a/tests/integration/tester/charmcraft.yaml +++ b/tests/integration/tester/charmcraft.yaml @@ -8,3 +8,10 @@ bases: run-on: - name: "ubuntu" channel: "22.04" +parts: + charm: + charm-binary-python-packages: + - "pydantic>=2" + - "cryptography" + - "jsonschema" + - "opentelemetry-exporter-otlp-proto-http==1.21.0" \ No newline at end of file diff --git a/tests/interface/conftest.py b/tests/interface/conftest.py index 4941189c..49841b62 100644 --- a/tests/interface/conftest.py +++ b/tests/interface/conftest.py @@ -3,13 +3,12 @@ from unittest.mock import patch import pytest +from charm import TempoCharm from charms.tempo_k8s.v1.charm_tracing import charm_tracing_disabled from interface_tester import InterfaceTester from ops.pebble import Layer from scenario.state import Container, State -from charm import TempoCharm - # Interface tests are centrally hosted at https://github.com/canonical/charm-relation-interfaces. 
# this fixture is used by the test runner of charm-relation-interfaces to test tempo's compliance diff --git a/tests/scenario/conftest.py b/tests/scenario/conftest.py index 36de650a..726d9e3b 100644 --- a/tests/scenario/conftest.py +++ b/tests/scenario/conftest.py @@ -1,9 +1,8 @@ from unittest.mock import patch import pytest -from scenario import Context - from charm import TempoCharm +from scenario import Context @pytest.fixture diff --git a/tests/scenario/test_charm.py b/tests/scenario/test_charm.py index 2ec9196a..c69a383f 100644 --- a/tests/scenario/test_charm.py +++ b/tests/scenario/test_charm.py @@ -8,7 +8,6 @@ from ops import pebble from scenario import Container, Mount, Relation, State from scenario.sequences import check_builtin_sequences - from tempo import Tempo TEMPO_CHARM_ROOT = Path(__file__).parent.parent.parent @@ -16,7 +15,7 @@ @pytest.fixture(params=(True, False)) def base_state(request): - return State(leader=request.param, containers=[Container("tempo", can_connect=False)]) + return State(leader=request.param, containers=[Container("tempo", can_connect=True)]) def test_builtin_sequences(tempo_charm, base_state): @@ -38,8 +37,12 @@ def test_tempo_restart_on_ingress_v2_changed(context, tmp_path, requested_protoc container = MagicMock() container.can_connect = lambda: True - + # prevent tls_ready from reporting True + container.exists = lambda path: ( + False if path in [Tempo.tls_cert_path, Tempo.tls_key_path, Tempo.tls_ca_path] else True + ) initial_config = Tempo(container).generate_config(["otlp_http"]) + tempo_config.write_text(yaml.safe_dump(initial_config)) tempo = Container( diff --git a/tests/scenario/test_ingressed_tracing.py b/tests/scenario/test_ingressed_tracing.py index 4e231811..c82df609 100644 --- a/tests/scenario/test_ingressed_tracing.py +++ b/tests/scenario/test_ingressed_tracing.py @@ -6,6 +6,7 @@ import yaml from charms.tempo_k8s.v1.charm_tracing import charm_tracing_disabled from scenario import Container, Relation, State 
+from tempo import Tempo @pytest.fixture @@ -38,23 +39,10 @@ def test_ingress_relation_set_with_dynamic_config(context, base_state): ingress = Relation("ingress", remote_app_data={"external_host": "1.2.3.4", "scheme": "http"}) state = base_state.replace(relations=[ingress]) - out = context.run(getattr(ingress, "joined_event"), state) + with patch.object(Tempo, "is_ready", lambda _: False): + out = context.run(ingress.joined_event, state) expected_rel_data = { - "tcp": { - "routers": { - f"juju-{state.model.name}-tempo-k8s-otlp-grpc": { - "entryPoints": ["otlp-grpc"], - "rule": "ClientIP(`0.0.0.0/0`)", - "service": f"juju-{state.model.name}-tempo-k8s-service-otlp-grpc", - } - }, - "services": { - f"juju-{state.model.name}-tempo-k8s-service-otlp-grpc": { - "loadBalancer": {"servers": [{"address": "1.2.3.4:4317"}]} - } - }, - }, "http": { "routers": { f"juju-{state.model.name}-tempo-k8s-jaeger-thrift-http": { @@ -77,6 +65,16 @@ def test_ingress_relation_set_with_dynamic_config(context, base_state): "rule": "ClientIP(`0.0.0.0/0`)", "service": f"juju-{state.model.name}-tempo-k8s-service-zipkin", }, + f"juju-{state.model.name}-tempo-k8s-otlp-grpc": { + "entryPoints": ["otlp-grpc"], + "rule": "ClientIP(`0.0.0.0/0`)", + "service": f"juju-{state.model.name}-tempo-k8s-service-otlp-grpc", + }, + f"juju-{state.model.name}-tempo-k8s-tempo-grpc": { + "entryPoints": ["tempo-grpc"], + "rule": "ClientIP(`0.0.0.0/0`)", + "service": f"juju-{state.model.name}-tempo-k8s-service-tempo-grpc", + }, }, "services": { f"juju-{state.model.name}-tempo-k8s-service-jaeger-thrift-http": { @@ -91,6 +89,12 @@ def test_ingress_relation_set_with_dynamic_config(context, base_state): f"juju-{state.model.name}-tempo-k8s-service-zipkin": { "loadBalancer": {"servers": [{"url": "http://1.2.3.4:9411"}]} }, + f"juju-{state.model.name}-tempo-k8s-service-otlp-grpc": { + "loadBalancer": {"servers": [{"url": "h2c://1.2.3.4:4317"}]}, + }, + f"juju-{state.model.name}-tempo-k8s-service-tempo-grpc": { + 
"loadBalancer": {"servers": [{"url": "h2c://1.2.3.4:9096"}]} + }, }, }, } diff --git a/tests/scenario/test_tls.py b/tests/scenario/test_tls.py new file mode 100644 index 00000000..c0d42339 --- /dev/null +++ b/tests/scenario/test_tls.py @@ -0,0 +1,135 @@ +import socket +from unittest.mock import patch + +import pytest +from charm import TempoCharm +from charms.tempo_k8s.v1.charm_tracing import charm_tracing_disabled +from charms.tempo_k8s.v2.tracing import TracingProviderAppData, TracingRequirerAppData +from scenario import Container, Relation, State + + +@pytest.fixture +def base_state(): + return State(leader=True, containers=[Container("tempo", can_connect=False)]) + + +def update_relations_tls_and_verify( + base_state, + context, + has_ingress, + local_has_tls, + local_scheme, + relations, + remote_scheme, + tracing, +): + state = base_state.replace(relations=relations) + with charm_tracing_disabled(), patch.object(TempoCharm, "tls_available", local_has_tls): + out = context.run(tracing.changed_event, state) + tracing_provider_app_data = TracingProviderAppData.load( + out.get_relations(tracing.endpoint)[0].local_app_data + ) + assert tracing_provider_app_data.host == socket.getfqdn() + assert ( + tracing_provider_app_data.external_url + == f"{remote_scheme if has_ingress else local_scheme}://{socket.getfqdn() if not has_ingress else 'foo.com.org'}" + ) + assert tracing_provider_app_data.internal_scheme == local_scheme + return out + + +@pytest.mark.parametrize("remote_has_tls", (True, False)) +@pytest.mark.parametrize("local_has_tls", (True, False)) +@pytest.mark.parametrize("has_ingress", (True, False)) +def test_tracing_endpoints_with_tls( + context, base_state, has_ingress, local_has_tls, remote_has_tls +): + tracing = Relation( + "tracing", + remote_app_data=TracingRequirerAppData(receivers=["otlp_http"]).dump(), + ) + relations = [tracing] + + local_scheme = "https" if local_has_tls else "http" + remote_scheme = "https" if remote_has_tls else "http" + + if 
has_ingress: + relations.append( + Relation( + "ingress", + remote_app_data={"scheme": remote_scheme, "external_host": "foo.com.org"}, + ) + ) + + update_relations_tls_and_verify( + base_state, + context, + has_ingress, + local_has_tls, + local_scheme, + relations, + remote_scheme, + tracing, + ) + + +@pytest.mark.parametrize("has_ingress", (True, False)) +def test_tracing_endpoints_tls_added_then_removed(context, base_state, has_ingress): + tracing = Relation( + "tracing", + remote_app_data=TracingRequirerAppData(receivers=["otlp_http"]).dump(), + ) + relations = [tracing] + + local_scheme = "http" + remote_scheme = "http" + + if has_ingress: + relations.append( + Relation( + "ingress", + remote_app_data={"scheme": remote_scheme, "external_host": "foo.com.org"}, + ) + ) + + result_state = update_relations_tls_and_verify( + base_state, context, has_ingress, False, local_scheme, relations, remote_scheme, tracing + ) + + # then we check the scenario where TLS gets enabled + + local_scheme = "https" + remote_scheme = "https" + + if has_ingress: + # as remote_scheme changed, we need to update the ingress relation + relations.pop() + relations.append( + Relation( + "ingress", + remote_app_data={"scheme": remote_scheme, "external_host": "foo.com.org"}, + ) + ) + + result_state = update_relations_tls_and_verify( + result_state, context, has_ingress, True, local_scheme, relations, remote_scheme, tracing + ) + + # then we again remove TLS and compare the same thing + + local_scheme = "http" + remote_scheme = "http" + + if has_ingress: + # as remote_scheme changed, we need to update the ingress relation + relations.pop() + relations.append( + Relation( + "ingress", + remote_app_data={"scheme": remote_scheme, "external_host": "foo.com.org"}, + ) + ) + + update_relations_tls_and_verify( + result_state, context, has_ingress, False, local_scheme, relations, remote_scheme, tracing + ) diff --git a/tests/scenario/test_tracing_legacy.py b/tests/scenario/test_tracing_legacy.py index 
bf00a14f..194f2302 100644 --- a/tests/scenario/test_tracing_legacy.py +++ b/tests/scenario/test_tracing_legacy.py @@ -2,6 +2,7 @@ import socket import pytest +from charm import LEGACY_RECEIVER_PROTOCOLS, TempoCharm from charms.tempo_k8s.v1.charm_tracing import charm_tracing_disabled from charms.tempo_k8s.v1.tracing import ( TracingProviderAppData as TracingProviderAppDataV1, @@ -10,8 +11,6 @@ TracingProviderAppData as TracingProviderAppDataV2, ) from scenario import Container, Relation, State - -from charm import LEGACY_RECEIVER_PROTOCOLS, TempoCharm from tempo import Tempo NO_RECEIVERS = 13 @@ -68,6 +67,7 @@ def test_tracing_v2_endpoint_published(context, evt_name, base_state): "receivers": '[{"protocol": "otlp_http", "port": 4318}]', "host": json.dumps(socket.getfqdn()), "internal_scheme": '"http"', + "external_url": f'"http://{socket.getfqdn()}"', } diff --git a/tests/scenario/test_tracing_requirer.py b/tests/scenario/test_tracing_requirer.py index 5c6bf673..fed8ebc3 100644 --- a/tests/scenario/test_tracing_requirer.py +++ b/tests/scenario/test_tracing_requirer.py @@ -11,7 +11,6 @@ ) from ops import CharmBase, Framework, RelationBrokenEvent, RelationChangedEvent from scenario import Context, Relation, State - from tempo import Tempo diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py index 8015564b..13437c09 100644 --- a/tests/unit/test_charm.py +++ b/tests/unit/test_charm.py @@ -31,6 +31,7 @@ def test_entrypoints_are_generated_with_sanitized_names(self): expected_entrypoints = { "entryPoints": { "tempo-http": {"address": ":3200"}, + "tempo-grpc": {"address": ":9096"}, "zipkin": {"address": ":9411"}, "otlp-grpc": {"address": ":4317"}, "otlp-http": {"address": ":4318"}, diff --git a/tests/unit/test_tempo.py b/tests/unit/test_tempo.py index 100074aa..05e3a675 100644 --- a/tests/unit/test_tempo.py +++ b/tests/unit/test_tempo.py @@ -1,5 +1,6 @@ -import pytest +from unittest.mock import patch +import pytest from tempo import Tempo @@ -44,4 +45,5 @@ ), ) 
def test_tempo_receivers_config(protocols, expected_config): - assert Tempo(None)._build_receivers_config(protocols) == expected_config + with patch.object(Tempo, "tls_ready", False): + assert Tempo(None)._build_receivers_config(protocols) == expected_config diff --git a/tox.ini b/tox.ini index c2bb6a15..a56427cf 100644 --- a/tox.ini +++ b/tox.ini @@ -26,23 +26,22 @@ passenv = description = Apply coding style standards to code deps = black + ruff isort commands = isort {[vars]all_path} black {[vars]all_path} + ruff check {[vars]all_path} --fix [testenv:lint] description = Check code against coding style standards deps = - # renovate: datasource=pypi - black==23.1.0 - # renovate: datasource=pypi - ruff==0.0.243 - # renovate: datasource=pypi - codespell==2.2.2 + black + ruff + codespell commands = codespell {[vars]all_path} - ruff {[vars]all_path} --fix + ruff check {[vars]all_path} black --check --diff {[vars]all_path} [testenv:unit] @@ -79,6 +78,8 @@ deps = pytest-operator requests -r{toxinidir}/requirements.txt + # tracegen + opentelemetry-exporter-otlp-proto-grpc commands = pytest -v --tb native --log-cli-level=INFO {[vars]tst_path}integration -s {posargs}