Skip to content

Commit

Permalink
fix: patching wrong environment variables when resuming (#1923)
Browse files Browse the repository at this point in the history
Fixes #1921. Reported by a user. The wrong environment variable was
patches when the session was hibernated and the access tokens were
expired.
  • Loading branch information
olevski authored Jul 1, 2024
1 parent 625a0eb commit 1c9204b
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 68 deletions.
66 changes: 27 additions & 39 deletions renku_notebooks/api/classes/k8s_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,18 +246,8 @@ def patch_image_pull_secret(self, server_name: str, gitlab_token: GitlabToken):
patch,
)

def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"""Patch the Renku and Gitlab access tokens that are used in the session statefulset."""
try:
sts = self._apps_v1.read_namespaced_stateful_set(name, self.namespace)
except ApiException as err:
if err.status == 404:
# NOTE: It can happen potentially that another request or something else
# deleted the session as this request was going on, in this case we ignore
# the missing statefulset
return
raise

@staticmethod
def _get_statefulset_token_patches(sts: client.V1StatefulSet, renku_tokens: RenkuTokens) -> list[dict[str, str]]:
containers: list[V1Container] = sts.spec.template.spec.containers
init_containers: list[V1Container] = sts.spec.template.spec.init_containers

Expand All @@ -266,15 +256,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
(None, None),
)
git_clone_container_index, git_clone_container = next(
((i, c) for i, c in enumerate(init_containers) if c.name == "git-proxy"),
((i, c) for i, c in enumerate(init_containers) if c.name == "git-clone"),
(None, None),
)
secrets_container_index, secrets_container = next(
(
(i, c)
for i, c in enumerate(init_containers)
if c.name == "init-user-secrets"
),
((i, c) for i, c in enumerate(init_containers) if c.name == "init-user-secrets"),
(None, None),
)

Expand All @@ -294,16 +280,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
else None
)
secrets_renku_access_token_env = (
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN")
if secrets_container is not None
else None
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN") if secrets_container is not None else None
)

patches = list()
if (
git_proxy_container_index is not None
and git_proxy_renku_access_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -314,10 +295,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.access_token,
}
)
if (
git_proxy_container_index is not None
and git_proxy_renku_refresh_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_refresh_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -328,35 +306,45 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.refresh_token,
},
)
if (
git_clone_container_index is not None
and git_clone_renku_access_token_env is not None
):
if git_clone_container_index is not None and git_clone_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
"path": (
f"/spec/template/spec/containers/{git_clone_container_index}"
f"/spec/template/spec/initContainers/{git_clone_container_index}"
f"/env/{git_clone_renku_access_token_env[0]}/value"
),
"value": renku_tokens.access_token,
},
)
if (
secrets_container_index is not None
and secrets_renku_access_token_env is not None
):
if secrets_container_index is not None and secrets_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
"path": (
f"/spec/template/spec/containers/{secrets_container_index}"
f"/spec/template/spec/initContainers/{secrets_container_index}"
f"/env/{secrets_renku_access_token_env[0]}/value"
),
"value": renku_tokens.access_token,
},
)

return patches

def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"""Patch the Renku and Gitlab access tokens that are used in the session statefulset."""
try:
sts = self._apps_v1.read_namespaced_stateful_set(name, self.namespace)
except ApiException as err:
if err.status == 404:
# NOTE: It can happen potentially that another request or something else
# deleted the session as this request was going on, in this case we ignore
# the missing statefulset
return
raise

patches = self._get_statefulset_token_patches(sts, renku_tokens)

if not patches:
return

Expand Down
32 changes: 10 additions & 22 deletions renku_notebooks/util/kubernetes_.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Any

import escapism
from kubernetes.client import V1Container
from kubernetes.client import V1Container, V1EnvVarSource


def filter_resources_by_annotations(
Expand All @@ -37,10 +37,7 @@ def filter_resources_by_annotations(
def filter_resource(resource):
res = []
for annotation_name in annotations:
res.append(
resource["metadata"]["annotations"].get(annotation_name)
== annotations[annotation_name]
)
res.append(resource["metadata"]["annotations"].get(annotation_name) == annotations[annotation_name])
if len(res) == 0:
return True
else:
Expand All @@ -49,16 +46,12 @@ def filter_resource(resource):
return list(filter(filter_resource, resources))


def renku_1_make_server_name(
safe_username: str, namespace: str, project: str, branch: str, commit_sha: str
) -> str:
def renku_1_make_server_name(safe_username: str, namespace: str, project: str, branch: str, commit_sha: str) -> str:
"""Form a unique server name for Renku 1.0 sessions.
This is used in naming all the k8s resources created by amalthea.
"""
server_string_for_hashing = (
f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}"
)
server_string_for_hashing = f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}"
server_hash = md5(server_string_for_hashing.encode()).hexdigest().lower()
prefix = _make_server_name_prefix(safe_username)
# NOTE: A K8s object name can only contain lowercase alphanumeric characters, hyphens, or dots.
Expand All @@ -75,9 +68,7 @@ def renku_1_make_server_name(
)


def renku_2_make_server_name(
safe_username: str, project_id: str, launcher_id: str
) -> str:
def renku_2_make_server_name(safe_username: str, project_id: str, launcher_id: str) -> str:
"""Form a unique server name for Renku 2.0 sessions.
This is used in naming all the k8s resources created by amalthea.
Expand All @@ -95,7 +86,7 @@ def renku_2_make_server_name(
return f"{prefix[:12]}-renku-2-{server_hash[:21]}"


def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | None:
def find_env_var(container: V1Container, env_name: str) -> tuple[int, str | V1EnvVarSource] | None:
"""Find the index and value of a specific environment variable by name from a Kubernetes container."""
env_var = next(
filter(
Expand All @@ -108,16 +99,15 @@ def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | Non
return None
ind = env_var[0]
val = env_var[1].value
if val is None:
val = env_var[1].value_from
return ind, val


def _make_server_name_prefix(safe_username: str):
safe_username_lowercase = safe_username.lower()
prefix = ""
if (
not safe_username_lowercase[0].isalpha()
or not safe_username_lowercase[0].isascii()
):
if not safe_username_lowercase[0].isalpha() or not safe_username_lowercase[0].isascii():
# NOTE: Username starts with an invalid character. This has to be modified because a
# k8s service object cannot start with anything other than a lowercase alphabet character.
# NOTE: We do not have worry about collisions with already existing servers from older
Expand All @@ -130,9 +120,7 @@ def _make_server_name_prefix(safe_username: str):
return prefix


def find_container(
patches: list[dict[str, Any]], container_name: str
) -> dict[str, Any] | None:
def find_container(patches: list[dict[str, Any]], container_name: str) -> dict[str, Any] | None:
"""Find the json patch corresponding a given container."""
for patch_obj in patches:
inner_patches = patch_obj.get("patch", [])
Expand Down
101 changes: 94 additions & 7 deletions tests/unit/test_k8s_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
import pytest

from kubernetes.client import (
V1Container,
V1EnvVar,
V1EnvVarSource,
V1LabelSelector,
V1PodSpec,
V1PodTemplateSpec,
V1StatefulSet,
V1StatefulSetSpec,
)

from renku_notebooks.api.classes.auth import RenkuTokens
from renku_notebooks.api.classes.k8s_client import JsServerCache, K8sClient, NamespacedK8sClient
from renku_notebooks.errors.intermittent import JSCacheError
from renku_notebooks.errors.programming import ProgrammingError
from renku_notebooks.util.kubernetes_ import find_env_var


@pytest.fixture
Expand Down Expand Up @@ -37,9 +49,7 @@ def test_list_cache_preference(mock_server_cache, mock_namespaced_client):
renku_ns_client = mock_namespaced_client("renku")
sessions_ns_client = mock_namespaced_client("renku-sessions")
sample_server_manifest = {"metadata": {"labels": {"username": "username"}, "name": "server1"}}
sample_server_manifest_preferred = {
"metadata": {"labels": {"username": "username"}, "name": "preferred"}
}
sample_server_manifest_preferred = {"metadata": {"labels": {"username": "username"}, "name": "preferred"}}
mock_server_cache.list_servers.return_value = [sample_server_manifest_preferred]
renku_ns_client.list_servers.return_value = []
sessions_ns_client.list_servers.return_value = [sample_server_manifest]
Expand Down Expand Up @@ -86,9 +96,7 @@ def test_get_two_results_raises_error(mock_server_cache, mock_namespaced_client)
def test_get_cache_is_preferred(mock_server_cache, mock_namespaced_client):
renku_ns_client = mock_namespaced_client("renku")
sessions_ns_client = mock_namespaced_client("renku-sessions")
sample_server_manifest_cache = {
"metadata": {"labels": {"username": "username"}, "name": "server"}
}
sample_server_manifest_cache = {"metadata": {"labels": {"username": "username"}, "name": "server"}}
sample_server_manifest_non_cache = {
"metadata": {
"labels": {"username": "username", "not_preferred": True},
Expand All @@ -112,3 +120,82 @@ def test_get_server_no_match(mock_server_cache, mock_namespaced_client):
client = K8sClient(mock_server_cache, renku_ns_client, "username", sessions_ns_client)
server = client.get_server("server", "username")
assert server is None


def test_find_env_var():
container = V1Container(
name="test", env=[V1EnvVar(name="key1", value="val1"), V1EnvVar(name="key2", value_from=V1EnvVarSource())]
)
assert find_env_var(container, "key1") == (0, "val1")
assert find_env_var(container, "key2") == (1, V1EnvVarSource())
assert find_env_var(container, "missing") is None


def test_patch_statefulset_tokens():
git_clone_access_env = "GIT_CLONE_USER__RENKU_TOKEN"
git_proxy_access_env = "GIT_PROXY_RENKU_ACCESS_TOKEN"
git_proxy_refresh_env = "GIT_PROXY_RENKU_REFRESH_TOKEN"
secrets_access_env = "RENKU_ACCESS_TOKEN"
git_clone = V1Container(
name="git-clone",
env=[
V1EnvVar(name="test", value="value"),
V1EnvVar(git_clone_access_env, "old_value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
],
)
git_proxy = V1Container(
name="git-proxy",
env=[
V1EnvVar(name="test", value="value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
V1EnvVar(git_proxy_refresh_env, "old_value"),
V1EnvVar(git_proxy_access_env, "old_value"),
],
)
secrets = V1Container(
name="init-user-secrets",
env=[
V1EnvVar(secrets_access_env, "old_value"),
V1EnvVar(name="test", value="value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
],
)
random1 = V1Container(name="random1")
random2 = V1Container(
name="random2",
env=[
V1EnvVar(name="test", value="value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
],
)

new_renku_tokens = RenkuTokens(access_token="new_renku_access_token", refresh_token="new_renku_refresh_token")

sts = V1StatefulSet(
spec=V1StatefulSetSpec(
service_name="test",
selector=V1LabelSelector(),
template=V1PodTemplateSpec(
spec=V1PodSpec(
containers=[git_proxy, random1, random2], init_containers=[git_clone, random1, secrets, random2]
)
),
)
)
patches = NamespacedK8sClient._get_statefulset_token_patches(sts, new_renku_tokens)

# Order of patches should be git proxy access, git proxy refresh, git clone, secrets
assert len(patches) == 4
# Git proxy access token
assert patches[0]["path"] == "/spec/template/spec/containers/0/env/3/value"
assert patches[0]["value"] == new_renku_tokens.access_token
# Git proxy refresh token
assert patches[1]["path"] == "/spec/template/spec/containers/0/env/2/value"
assert patches[1]["value"] == new_renku_tokens.refresh_token
# Git clone
assert patches[2]["path"] == "/spec/template/spec/initContainers/0/env/1/value"
assert patches[2]["value"] == new_renku_tokens.access_token
# Secrets init
assert patches[3]["path"] == "/spec/template/spec/initContainers/2/env/0/value"
assert patches[3]["value"] == new_renku_tokens.access_token

0 comments on commit 1c9204b

Please sign in to comment.