Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patching wrong environment variables when resuming #1923

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 27 additions & 39 deletions renku_notebooks/api/classes/k8s_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,18 +246,8 @@ def patch_image_pull_secret(self, server_name: str, gitlab_token: GitlabToken):
patch,
)

def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"""Patch the Renku and Gitlab access tokens that are used in the session statefulset."""
try:
sts = self._apps_v1.read_namespaced_stateful_set(name, self.namespace)
except ApiException as err:
if err.status == 404:
# NOTE: It can happen potentially that another request or something else
# deleted the session as this request was going on, in this case we ignore
# the missing statefulset
return
raise

@staticmethod
def _get_statefulset_token_patches(sts: client.V1StatefulSet, renku_tokens: RenkuTokens) -> list[dict[str, str]]:
containers: list[V1Container] = sts.spec.template.spec.containers
init_containers: list[V1Container] = sts.spec.template.spec.init_containers

Expand All @@ -266,15 +256,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
(None, None),
)
git_clone_container_index, git_clone_container = next(
((i, c) for i, c in enumerate(init_containers) if c.name == "git-proxy"),
((i, c) for i, c in enumerate(init_containers) if c.name == "git-clone"),
(None, None),
)
secrets_container_index, secrets_container = next(
(
(i, c)
for i, c in enumerate(init_containers)
if c.name == "init-user-secrets"
),
((i, c) for i, c in enumerate(init_containers) if c.name == "init-user-secrets"),
(None, None),
)

Expand All @@ -294,16 +280,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
else None
)
secrets_renku_access_token_env = (
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN")
if secrets_container is not None
else None
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN") if secrets_container is not None else None
)

patches = list()
if (
git_proxy_container_index is not None
and git_proxy_renku_access_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -314,10 +295,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.access_token,
}
)
if (
git_proxy_container_index is not None
and git_proxy_renku_refresh_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_refresh_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -328,35 +306,45 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.refresh_token,
},
)
if (
git_clone_container_index is not None
and git_clone_renku_access_token_env is not None
):
if git_clone_container_index is not None and git_clone_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
"path": (
f"/spec/template/spec/containers/{git_clone_container_index}"
f"/spec/template/spec/initContainers/{git_clone_container_index}"
f"/env/{git_clone_renku_access_token_env[0]}/value"
),
"value": renku_tokens.access_token,
},
)
if (
secrets_container_index is not None
and secrets_renku_access_token_env is not None
):
if secrets_container_index is not None and secrets_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
"path": (
f"/spec/template/spec/containers/{secrets_container_index}"
f"/spec/template/spec/initContainers/{secrets_container_index}"
f"/env/{secrets_renku_access_token_env[0]}/value"
),
"value": renku_tokens.access_token,
},
)

return patches

def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"""Patch the Renku and Gitlab access tokens that are used in the session statefulset."""
try:
sts = self._apps_v1.read_namespaced_stateful_set(name, self.namespace)
except ApiException as err:
if err.status == 404:
# NOTE: It can happen potentially that another request or something else
# deleted the session as this request was going on, in this case we ignore
# the missing statefulset
return
raise

patches = self._get_statefulset_token_patches(sts, renku_tokens)

if not patches:
return

Expand Down
32 changes: 10 additions & 22 deletions renku_notebooks/util/kubernetes_.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Any

import escapism
from kubernetes.client import V1Container
from kubernetes.client import V1Container, V1EnvVarSource


def filter_resources_by_annotations(
Expand All @@ -37,10 +37,7 @@ def filter_resources_by_annotations(
def filter_resource(resource):
res = []
for annotation_name in annotations:
res.append(
resource["metadata"]["annotations"].get(annotation_name)
== annotations[annotation_name]
)
res.append(resource["metadata"]["annotations"].get(annotation_name) == annotations[annotation_name])
if len(res) == 0:
return True
else:
Expand All @@ -49,16 +46,12 @@ def filter_resource(resource):
return list(filter(filter_resource, resources))


def renku_1_make_server_name(
safe_username: str, namespace: str, project: str, branch: str, commit_sha: str
) -> str:
def renku_1_make_server_name(safe_username: str, namespace: str, project: str, branch: str, commit_sha: str) -> str:
"""Form a unique server name for Renku 1.0 sessions.

This is used in naming all the k8s resources created by amalthea.
"""
server_string_for_hashing = (
f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}"
)
server_string_for_hashing = f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}"
server_hash = md5(server_string_for_hashing.encode()).hexdigest().lower()
prefix = _make_server_name_prefix(safe_username)
# NOTE: A K8s object name can only contain lowercase alphanumeric characters, hyphens, or dots.
Expand All @@ -75,9 +68,7 @@ def renku_1_make_server_name(
)


def renku_2_make_server_name(
safe_username: str, project_id: str, launcher_id: str
) -> str:
def renku_2_make_server_name(safe_username: str, project_id: str, launcher_id: str) -> str:
"""Form a unique server name for Renku 2.0 sessions.

This is used in naming all the k8s resources created by amalthea.
Expand All @@ -95,7 +86,7 @@ def renku_2_make_server_name(
return f"{prefix[:12]}-renku-2-{server_hash[:21]}"


def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | None:
def find_env_var(container: V1Container, env_name: str) -> tuple[int, str | V1EnvVarSource] | None:
"""Find the index and value of a specific environment variable by name from a Kubernetes container."""
env_var = next(
filter(
Expand All @@ -108,16 +99,15 @@ def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | Non
return None
ind = env_var[0]
val = env_var[1].value
if val is None:
val = env_var[1].value_from
return ind, val


def _make_server_name_prefix(safe_username: str):
safe_username_lowercase = safe_username.lower()
prefix = ""
if (
not safe_username_lowercase[0].isalpha()
or not safe_username_lowercase[0].isascii()
):
if not safe_username_lowercase[0].isalpha() or not safe_username_lowercase[0].isascii():
# NOTE: Username starts with an invalid character. This has to be modified because a
# k8s service object cannot start with anything other than a lowercase alphabet character.
# NOTE: We do not have worry about collisions with already existing servers from older
Expand All @@ -130,9 +120,7 @@ def _make_server_name_prefix(safe_username: str):
return prefix


def find_container(
patches: list[dict[str, Any]], container_name: str
) -> dict[str, Any] | None:
def find_container(patches: list[dict[str, Any]], container_name: str) -> dict[str, Any] | None:
"""Find the json patch corresponding a given container."""
for patch_obj in patches:
inner_patches = patch_obj.get("patch", [])
Expand Down
101 changes: 94 additions & 7 deletions tests/unit/test_k8s_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
import pytest

from kubernetes.client import (
V1Container,
V1EnvVar,
V1EnvVarSource,
V1LabelSelector,
V1PodSpec,
V1PodTemplateSpec,
V1StatefulSet,
V1StatefulSetSpec,
)

from renku_notebooks.api.classes.auth import RenkuTokens
from renku_notebooks.api.classes.k8s_client import JsServerCache, K8sClient, NamespacedK8sClient
from renku_notebooks.errors.intermittent import JSCacheError
from renku_notebooks.errors.programming import ProgrammingError
from renku_notebooks.util.kubernetes_ import find_env_var


@pytest.fixture
Expand Down Expand Up @@ -37,9 +49,7 @@ def test_list_cache_preference(mock_server_cache, mock_namespaced_client):
renku_ns_client = mock_namespaced_client("renku")
sessions_ns_client = mock_namespaced_client("renku-sessions")
sample_server_manifest = {"metadata": {"labels": {"username": "username"}, "name": "server1"}}
sample_server_manifest_preferred = {
"metadata": {"labels": {"username": "username"}, "name": "preferred"}
}
sample_server_manifest_preferred = {"metadata": {"labels": {"username": "username"}, "name": "preferred"}}
mock_server_cache.list_servers.return_value = [sample_server_manifest_preferred]
renku_ns_client.list_servers.return_value = []
sessions_ns_client.list_servers.return_value = [sample_server_manifest]
Expand Down Expand Up @@ -86,9 +96,7 @@ def test_get_two_results_raises_error(mock_server_cache, mock_namespaced_client)
def test_get_cache_is_preferred(mock_server_cache, mock_namespaced_client):
renku_ns_client = mock_namespaced_client("renku")
sessions_ns_client = mock_namespaced_client("renku-sessions")
sample_server_manifest_cache = {
"metadata": {"labels": {"username": "username"}, "name": "server"}
}
sample_server_manifest_cache = {"metadata": {"labels": {"username": "username"}, "name": "server"}}
sample_server_manifest_non_cache = {
"metadata": {
"labels": {"username": "username", "not_preferred": True},
Expand All @@ -112,3 +120,82 @@ def test_get_server_no_match(mock_server_cache, mock_namespaced_client):
client = K8sClient(mock_server_cache, renku_ns_client, "username", sessions_ns_client)
server = client.get_server("server", "username")
assert server is None


def test_find_env_var():
container = V1Container(
name="test", env=[V1EnvVar(name="key1", value="val1"), V1EnvVar(name="key2", value_from=V1EnvVarSource())]
)
assert find_env_var(container, "key1") == (0, "val1")
assert find_env_var(container, "key2") == (1, V1EnvVarSource())
assert find_env_var(container, "missing") is None


def test_patch_statefulset_tokens():
git_clone_access_env = "GIT_CLONE_USER__RENKU_TOKEN"
git_proxy_access_env = "GIT_PROXY_RENKU_ACCESS_TOKEN"
git_proxy_refresh_env = "GIT_PROXY_RENKU_REFRESH_TOKEN"
secrets_access_env = "RENKU_ACCESS_TOKEN"
git_clone = V1Container(
name="git-clone",
env=[
V1EnvVar(name="test", value="value"),
V1EnvVar(git_clone_access_env, "old_value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
],
)
git_proxy = V1Container(
name="git-proxy",
env=[
V1EnvVar(name="test", value="value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
V1EnvVar(git_proxy_refresh_env, "old_value"),
V1EnvVar(git_proxy_access_env, "old_value"),
],
)
secrets = V1Container(
name="init-user-secrets",
env=[
V1EnvVar(secrets_access_env, "old_value"),
V1EnvVar(name="test", value="value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
],
)
random1 = V1Container(name="random1")
random2 = V1Container(
name="random2",
env=[
V1EnvVar(name="test", value="value"),
V1EnvVar(name="test-from-source", value_from=V1EnvVarSource()),
],
)

new_renku_tokens = RenkuTokens(access_token="new_renku_access_token", refresh_token="new_renku_refresh_token")

sts = V1StatefulSet(
spec=V1StatefulSetSpec(
service_name="test",
selector=V1LabelSelector(),
template=V1PodTemplateSpec(
spec=V1PodSpec(
containers=[git_proxy, random1, random2], init_containers=[git_clone, random1, secrets, random2]
)
),
)
)
patches = NamespacedK8sClient._get_statefulset_token_patches(sts, new_renku_tokens)

# Order of patches should be git proxy access, git proxy refresh, git clone, secrets
assert len(patches) == 4
# Git proxy access token
assert patches[0]["path"] == "/spec/template/spec/containers/0/env/3/value"
assert patches[0]["value"] == new_renku_tokens.access_token
# Git proxy refresh token
assert patches[1]["path"] == "/spec/template/spec/containers/0/env/2/value"
assert patches[1]["value"] == new_renku_tokens.refresh_token
# Git clone
assert patches[2]["path"] == "/spec/template/spec/initContainers/0/env/1/value"
assert patches[2]["value"] == new_renku_tokens.access_token
# Secrets init
assert patches[3]["path"] == "/spec/template/spec/initContainers/2/env/0/value"
assert patches[3]["value"] == new_renku_tokens.access_token
Loading