Skip to content

Commit

Permalink
Ensure SANs from apiserver cert and refresh certs if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
HomayoonAlimohammadi committed Jan 17, 2025
1 parent 8f72caa commit f2af08d
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 20 deletions.
92 changes: 92 additions & 0 deletions charms/worker/k8s/lib/charms/k8s/v0/k8sd_api_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class utilises different connection factories (UnixSocketConnectionFactory
import logging
import socket
from contextlib import contextmanager
from datetime import datetime
from http.client import HTTPConnection, HTTPException
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar

Expand Down Expand Up @@ -598,6 +599,66 @@ class GetKubeConfigResponse(BaseRequestModel):
metadata: KubeConfigMetadata


class RefreshCertificatesPlanMetadata(BaseModel, allow_population_by_field_name=True):
    """Payload carried by the refresh-certificates plan response.

    Attributes:
        seed (int): Seed value for the new certificates.
        certificate_signing_requests (Optional[list[str]]): Names of the
            CertificateSigningRequests that have to be signed externally
            (for worker nodes). Serialized as ``certificate-signing-requests``.
    """

    seed: int
    certificate_signing_requests: Optional[list[str]] = Field(
        None,
        alias="certificate-signing-requests",
    )


class RefreshCertificatesPlanResponse(BaseRequestModel):
    """Envelope for the reply to a refresh-certificates plan request.

    Attributes:
        metadata (RefreshCertificatesPlanMetadata): The plan details carried
            by the response.
    """

    metadata: RefreshCertificatesPlanMetadata


class RefreshCertificatesRunRequest(BaseModel, allow_population_by_field_name=True):
    """Request model for the refresh certificates run step.

    Attributes:
        seed (int): The seed for the new certificates, taken from the plan response.
        expiration_seconds (int): The duration of the new certificates.
            Serialized as ``expiration-seconds``.
        extra_sans (Optional[list[str]]): List of extra SANs for the new
            certificates; may be omitted. Serialized as ``extra-sans``.
    """

    seed: int
    expiration_seconds: int = Field(alias="expiration-seconds")
    # Explicit default so the field is optional by declaration (matching
    # RefreshCertificatesPlanMetadata) instead of relying on the implicit
    # "Optional means not required" behaviour.
    extra_sans: Optional[list[str]] = Field(default=None, alias="extra-sans")


class RefreshCertificatesRunMetadata(BaseModel, allow_population_by_field_name=True):
    """Payload carried by RefreshCertificatesRunResponse.

    Attributes:
        expiration_seconds (int): Duration of the newly issued certificates;
            this might not match the value that was requested.
            Serialized as ``expiration-seconds``.
    """

    expiration_seconds: int = Field(
        alias="expiration-seconds",
    )


class RefreshCertificatesRunResponse(BaseRequestModel):
    """Envelope for the reply to a refresh-certificates run request.

    Attributes:
        metadata (RefreshCertificatesRunMetadata): The run details carried
            by the response.
    """

    metadata: RefreshCertificatesRunMetadata


T = TypeVar("T", bound=BaseRequestModel)


Expand Down Expand Up @@ -920,3 +981,34 @@ def get_kubeconfig(self, server: Optional[str]) -> str:
body = {"server": server or ""}
response = self._send_request(endpoint, "GET", GetKubeConfigResponse, body)
return response.metadata.kubeconfig

def refresh_certs(
    self, extra_sans: list[str], expiration_seconds: Optional[int] = None
) -> None:
    """Refresh the certificates for the cluster.

    First requests a refresh plan, then runs the refresh using the seed
    returned by that plan.

    Args:
        extra_sans (list[str]): List of extra SANs for the certificates.
        expiration_seconds (Optional[int]): The duration of the new certificates.
            When omitted, the duration of 20 calendar years from now is used.
    """
    plan_endpoint = "/1.0/k8sd/refresh-certs/plan"
    plan_resp = self._send_request(plan_endpoint, "POST", RefreshCertificatesPlanResponse, {})

    # NOTE(Hue): Default certificate expiration is set to 20 years:
    # https://github.com/canonical/k8s-snap/blob/32e35128394c0880bcc4ce87447f4247cc315ba5/src/k8s/pkg/k8sd/app/hooks_bootstrap.go#L331-L338
    if expiration_seconds is None:
        now = datetime.now()
        target_year = now.year + 20
        try:
            # replace() keeps every other field (including microseconds), so the
            # difference below is a whole number of seconds.
            twenty_years_later = now.replace(year=target_year)
        except ValueError:
            # Today is Feb 29 and the target year is not a leap year
            # (e.g. 2080 -> 2100): roll over to Mar 1 instead of failing.
            twenty_years_later = now.replace(year=target_year, month=3, day=1)
        expiration_seconds = int((twenty_years_later - now).total_seconds())

    run_endpoint = "/1.0/k8sd/refresh-certs/run"
    run_req = RefreshCertificatesRunRequest(  # type: ignore
        seed=plan_resp.metadata.seed,
        expiration_seconds=expiration_seconds,
        extra_sans=extra_sans,
    )

    # Serialize with the dashed aliases and drop unset optional fields.
    body = run_req.dict(exclude_none=True, by_alias=True)
    self._send_request(run_endpoint, "POST", RefreshCertificatesRunResponse, body)
1 change: 1 addition & 0 deletions charms/worker/k8s/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ poetry-core==1.9.1
lightkube==0.16.0
httpx==0.27.2
loadbalancer_interface == 1.2.0
cryptography==44.0.0
87 changes: 70 additions & 17 deletions charms/worker/k8s/src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
)
from loadbalancer_interface import LBProvider
from ops.interface_kube_control import KubeControlProvides
from pki import extract_sans_from_cert, get_api_server_cert
from pydantic import SecretStr
from snap import management as snap_management
from snap import version as snap_version
Expand Down Expand Up @@ -369,18 +370,29 @@ def _check_k8sd_ready(self):
self.api_manager.check_k8sd_ready()

def _get_extra_sans(self):
"""Retrieve the certificate extra SANs.
Raises:
ReconcilerError: If the public address cannot be retrieved.
"""
"""Retrieve the certificate extra SANs."""
# Get the extra SANs from the configuration
extra_sans_str = str(self.config.get("kube-apiserver-extra-sans") or "")
extra_sans = {san for san in extra_sans_str.strip().split() if san}
if public_address := self._get_public_address():
log.info("Public address %s found, adding it to extra SANs", public_address)
extra_sans.add(public_address)
else:
raise ReconcilerError("Failed to get public address")
extra_sans = set(extra_sans_str.strip().split())

# Add the ingress addresses of all units
extra_sans.add(_get_juju_public_address())
binding = self.model.get_binding(CLUSTER_RELATION)
try:
addresses = binding and binding.network.ingress_addresses
if addresses:
for addr in addresses:
extra_sans.add(str(addr))
except ops.RelationNotFoundError as e:
log.error(f"Failed to get ingress addresses for extra SANs: {e}")

# Add the external load balancer address
if self.is_control_plane and self.external_load_balancer.is_available:
if external_lb_addr := self._get_external_load_balancer_address():
extra_sans.add(external_lb_addr)
else:
log.warning("Failed to get external load balancer address for extra SANs")

return sorted(extra_sans)

def _assemble_bootstrap_config(self):
Expand Down Expand Up @@ -949,6 +961,7 @@ def _reconcile(self, event: ops.EventBase):
if self.is_control_plane:
self._copy_internal_kubeconfig()
self._expose_ports()
self._ensure_sans()

def _evaluate_removal(self, event: ops.EventBase) -> bool:
"""Determine if my unit is being removed.
Expand Down Expand Up @@ -1107,15 +1120,18 @@ def _get_external_kubeconfig(self, event: ops.ActionEvent):
server = event.params.get("server")
if not server:
log.info("No server requested, use public address")

server = self._get_public_address()
if not server:
event.fail("Failed to get public address. Check logs for details.")
return
log.info("Found public address: %s", server)
port = str(APISERVER_PORT)
if self.is_control_plane and self.external_load_balancer.is_available:
log.info("Using external load balancer port as the public port")
port = str(EXTERNAL_LOAD_BALANCER_PORT)

port = (
str(EXTERNAL_LOAD_BALANCER_PORT)
if self.external_load_balancer.is_available
else str(APISERVER_PORT)
)

server = build_url(server, port, "https")
log.info("Formatted server address: %s", server)
log.info("Requesting kubeconfig for server=%s", server)
Expand All @@ -1125,7 +1141,12 @@ def _get_external_kubeconfig(self, event: ops.ActionEvent):
event.fail(f"Failed to retrieve kubeconfig: {e}")

def _get_public_address(self) -> Optional[str]:
"""Get public address either from external load balancer or from juju.
"""Get the most public address either from external load balancer or from juju.
If the external load balancer is available and the unit is a control-plane unit,
the external load balancer address will be used. Otherwise, the juju public address
will be used.
NOTE: Don't ignore the unit's IP in the extra SANs just because there's a load balancer.
Returns:
str: public ip address of the unit
Expand Down Expand Up @@ -1162,6 +1183,38 @@ def _get_external_load_balancer_address(self) -> Optional[str]:

return response.address

@on_error(
    ops.WaitingStatus("Ensuring SANs are up-to-date"),
    InvalidResponseError,
    K8sdConnectionError,
)
def _ensure_sans(self):
    """Ensure the extra SANs are up-to-date.

    This method checks if the extra SANs are present in the API server certificate.
    If any of them is missing, the certificates are refreshed once with the
    full list of extra SANs. Does nothing on non control-plane units.
    """
    if not self.is_control_plane:
        return

    extra_sans = self._get_extra_sans()
    if not extra_sans:
        log.info("No extra SANs to update")
        return

    dns_sans, ip_addresses = extract_sans_from_cert(get_api_server_cert())
    # IP SANs come back as ipaddress objects; compare them as strings.
    cert_sans = set(dns_sans)
    cert_sans.update(str(ip) for ip in ip_addresses)

    # Collect every missing SAN up front so a single refresh covers all of
    # them, instead of triggering one refresh per missing entry.
    missing_sans = [san for san in extra_sans if san not in cert_sans]
    if missing_sans:
        log.info(
            "%s not in cert SANs, refreshing certs with new SANs: %s",
            missing_sans,
            extra_sans,
        )
        status.add(ops.MaintenanceStatus("Refreshing Certificates"))
        self.api_manager.refresh_certs(extra_sans)
        log.info("Certificates have been refreshed")

    log.info("Extra SANs are up-to-date")


if __name__ == "__main__": # pragma: nocover
ops.main(K8sCharm)
4 changes: 2 additions & 2 deletions charms/worker/k8s/src/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
import re
from ipaddress import ip_address
from urllib.parse import urlparse
from urllib.parse import urlparse, urlunsplit

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,4 +89,4 @@ def build_url(addr: str, new_port: str, new_scheme: str) -> str:
if is_ipv6:
ip = f"[{ip}]"

return f"{new_scheme}://{ip}:{new_port}"
return urlunsplit((new_scheme, f"{ip}:{new_port}", "", "", ""))
2 changes: 2 additions & 0 deletions charms/worker/k8s/src/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# Charm
CONTAINERD_BASE_PATH = Path("/etc/containerd")
ETC_KUBERNETES = Path("/etc/kubernetes")
PKI_DIR = ETC_KUBERNETES / "pki"
APISERVER_CERT = PKI_DIR / "apiserver.crt"
HOSTSD_PATH = CONTAINERD_BASE_PATH / "hosts.d/"
KUBECONFIG = Path.home() / ".kube/config"
KUBECTL_PATH = Path("/snap/k8s/current/bin/kubectl")
Expand Down
67 changes: 67 additions & 0 deletions charms/worker/k8s/src/pki.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

# Copyright 2025 Canonical Ltd.
# See LICENSE file for licensing details.

# Learn more at: https://juju.is/docs/sdk

"""A module providing PKI related functionalities."""

import ipaddress
import logging
import os
import typing

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.x509.extensions import ExtensionNotFound
from literals import APISERVER_CERT

_IPAddressTypes = typing.Union[
ipaddress.IPv4Address,
ipaddress.IPv6Address,
ipaddress.IPv4Network,
ipaddress.IPv6Network,
]

log = logging.getLogger(__name__)


def get_api_server_cert() -> x509.Certificate:
    """Load the API server certificate from the ``APISERVER_CERT`` path.

    Returns:
        `x509.Certificate`: The certificate parsed from the PEM file contents.

    Raises:
        FileNotFoundError: If the certificate file does not exist.
    """
    if os.path.exists(APISERVER_CERT):
        with open(APISERVER_CERT, "rb") as cert_file:
            pem_bytes = cert_file.read()
        return x509.load_pem_x509_certificate(pem_bytes, default_backend())

    raise FileNotFoundError(f"Certificate file not found: {APISERVER_CERT}")


def extract_sans_from_cert(cert: x509.Certificate) -> tuple[list[str], list[_IPAddressTypes]]:
    """Read the DNS names and IP addresses from a certificate's SAN extension.

    Args:
        cert (`x509.Certificate`): The certificate to inspect.

    Returns:
        `tuple[list[str], list[_IPAddressTypes]]`: The DNS names and the IP
            addresses found in the Subject Alternative Name extension. Both
            lists are empty when the certificate has no such extension.
    """
    try:
        san_values = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
    except ExtensionNotFound as e:
        # A missing SAN extension is not an error: warn and report no SANs.
        log.warning(f"Subject Alternative Name (SAN) extension not found in the certificate: {e}")
        return [], []

    return (
        san_values.get_values_for_type(x509.DNSName),
        san_values.get_values_for_type(x509.IPAddress),
    )
1 change: 1 addition & 0 deletions charms/worker/k8s/terraform/outputs.tf
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ output "requires" {
etcd = "etcd"
external_cloud_provider = "external-cloud-provider"
gcp = "gcp"
external_load_balancer = "external-load-balancer"
}
}

Expand Down
3 changes: 2 additions & 1 deletion charms/worker/k8s/tests/unit/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def mock_reconciler_handlers(harness):
"_ensure_cluster_config",
"_expose_ports",
"_announce_kubernetes_version",
"_ensure_sans",
}

mocked = [mock.patch(f"charm.K8sCharm.{name}") for name in handler_names]
Expand Down Expand Up @@ -198,7 +199,7 @@ def test_configure_datastore_runtime_config_etcd(harness):
assert uccr_config.datastore.type == "external"


def test_configure_boostrap_extra_sans(harness):
def test_configure_bootstrap_extra_sans(harness):
"""Test configuring kube-apiserver-extra-sans on bootstrap.
Args:
Expand Down
Loading

0 comments on commit f2af08d

Please sign in to comment.