diff --git a/renku_notebooks/api/classes/k8s_client.py b/renku_notebooks/api/classes/k8s_client.py index 87a4903ce..65361d1a3 100644 --- a/renku_notebooks/api/classes/k8s_client.py +++ b/renku_notebooks/api/classes/k8s_client.py @@ -53,6 +53,10 @@ def __init__( except ConfigException: load_config() self._custom_objects = client.CustomObjectsApi(client.ApiClient()) + self._custom_objects_patch = client.CustomObjectsApi(client.ApiClient()) + self._custom_objects_patch.api_client.set_default_header( + "Content-Type", "application/json-patch+json" + ) self._core_v1 = client.CoreV1Api() self._apps_v1 = client.AppsV1Api() @@ -120,9 +124,17 @@ def create_server(self, manifest: Dict[str, Any]) -> Dict[str, Any]: server = retry_with_exponential_backoff(lambda x: x is None)(self.get_server)(server_name) return server - def patch_server(self, server_name: str, patch: Dict[str, Any]): + def patch_server(self, server_name: str, patch: Dict[str, Any] | List[Dict[str, Any]]): try: - server = self._custom_objects.patch_namespaced_custom_object( + if isinstance(patch, list): + # NOTE: The _custom_objects_patch will only accept rfc6902 json-patch. + # We can recognize the type of patch because this is the only one that uses a list + client = self._custom_objects_patch + else: + # NOTE: The _custom_objects will accept the usual rfc7386 merge patches + client = self._custom_objects + + server = client.patch_namespaced_custom_object( group=self.amalthea_group, version=self.amalthea_version, namespace=self.namespace, @@ -130,12 +142,31 @@ def patch_server(self, server_name: str, patch: Dict[str, Any]): name=server_name, body=patch, ) + except ApiException as e: logging.exception(f"Cannot patch server {server_name} because of {e}") raise PatchServerError() return server + def patch_statefulset( + self, server_name: str, patch: Dict[str, Any] | List[Dict[str, Any]] | client.V1StatefulSet + ) -> client.V1StatefulSet | None: + try: + ss = self._apps_v1.patch_namespaced_stateful_set( + server_name, + self.namespace, + patch, + ) + except ApiException as err: + if err.status == 404: + # NOTE: It can happen potentially that another request or something else + # deleted the session as this request was going on, in this case we ignore + # the missing statefulset + return + raise + return ss + def delete_server(self, server_name: str, forced: bool = False): try: status = self._custom_objects.delete_namespaced_custom_object( @@ -479,6 +510,15 @@ def patch_server(self, server_name: str, safe_username: str, patch: Dict[str, An else: return self.session_ns_client.patch_server(server_name=server_name, patch=patch) + def patch_statefulset( + self, server_name: str, patch: Dict[str, Any] + ) -> client.V1StatefulSet | None: + if self.session_ns_client: + client = self.session_ns_client + else: + client = self.renku_ns_client + return client.patch_statefulset(server_name=server_name, patch=patch) + def delete_server(self, server_name: str, safe_username: str, forced: bool = False): server = self.get_server(server_name, safe_username) if not server: diff --git a/renku_notebooks/api/classes/server.py b/renku_notebooks/api/classes/server.py index 847a34497..b303ddec5 100644 --- a/renku_notebooks/api/classes/server.py +++ b/renku_notebooks/api/classes/server.py @@ -151,25 +151,6 @@ def _commit_sha_exists(self): return True return False - def _get_session_k8s_resources(self): - cpu_request = float(self.server_options.cpu) - mem = self.server_options.memory - gpu_req = self.server_options.gpu - gpu = {"nvidia.com/gpu": str(gpu_req)} if gpu_req > 0 else None - resources = { - "requests": {"memory": mem, "cpu": cpu_request}, - "limits": {"memory": mem}, - } - if config.sessions.enforce_cpu_limits == "lax": - lax_cpu_limit_allowance_factor = 3 - resources["limits"]["cpu"] = lax_cpu_limit_allowance_factor * cpu_request - elif config.sessions.enforce_cpu_limits == "strict": - resources["limits"]["cpu"] = cpu_request - if gpu: - resources["requests"] = {**resources["requests"], **gpu} - resources["limits"] = {**resources["limits"], **gpu} - return resources - def _get_session_manifest(self): """Compose the body of the user session for the k8s operator""" patches = list( @@ -266,7 +247,9 @@ def _get_session_manifest(self): "defaultUrl": self.server_options.default_url, "image": self.image, "rootDir": self.work_dir.absolute().as_posix(), - "resources": self._get_session_k8s_resources(), + "resources": self.server_options.to_k8s_resources( + enforce_cpu_limits=config.sessions.enforce_cpu_limits + ), }, "routing": { "host": urlparse(self.server_url).netloc, @@ -375,6 +358,8 @@ def get_annotations(self): f"{prefix}lastActivityDate": "", f"{prefix}idleSecondsThreshold": str(self.idle_seconds_threshold), } + if self.server_options.resource_class_id: + annotations[f"{prefix}resourceClassId"] = str(self.server_options.resource_class_id) if self.gl_project is not None: annotations[f"{prefix}gitlabProjectId"] = str(self.gl_project.id) annotations[f"{prefix}repository"] = self.gl_project.web_url diff --git a/renku_notebooks/api/notebooks.py b/renku_notebooks/api/notebooks.py index cdd19e788..f2f7aed98 100644 --- a/renku_notebooks/api/notebooks.py +++ b/renku_notebooks/api/notebooks.py @@ -34,7 +34,7 @@ from .classes.auth import GitlabToken, RenkuTokens from ..errors.intermittent import AnonymousUserPatchError, PVDisabledError from ..errors.programming import ProgrammingError -from ..errors.user import InvalidPatchArgumentError, MissingResourceError, UserInputError +from ..errors.user import MissingResourceError, UserInputError from ..util.kubernetes_ import make_server_name from .auth import authenticated from .classes.image import Image @@ -384,9 +384,9 @@ def launch_notebook( @bp.route("servers/", methods=["PATCH"]) @use_args({"server_name": fields.Str(required=True)}, location="view_args", as_kwargs=True) -@use_args(PatchServerRequest(), location="json", as_kwargs=True) +@use_args(PatchServerRequest(), location="json", arg_name="patch_body") @authenticated -def patch_server(user, server_name, state): +def patch_server(user, server_name, patch_body): """ Patch a user server by name based on the query param. @@ -405,7 +405,7 @@ def patch_server(user, server_name, state): required: true description: The name of the server that should be patched. responses: - 204: + 200: description: The server was patched successfully. content: application/json: @@ -435,13 +435,71 @@ def patch_server(user, server_name, state): raise AnonymousUserPatchError() server = config.k8s.client.get_server(server_name, user.safe_username) + new_server = server + currently_hibernated = server.get("spec", {}).get("jupyterServer", {}).get("hibernated", False) + currently_failing = server.get("status", {}).get("state", "running") == "failed" + state = patch_body.get("state") + resource_class_id = patch_body.get("resource_class_id") + if server and not (currently_hibernated or currently_failing) and resource_class_id: + raise UserInputError( + "The resource class can be changed only if the server is hibernated or failing" + ) + + if resource_class_id: + parsed_server_options = config.crc_validator.validate_class_storage( + user, resource_class_id, storage=None # we do not care about validating storage + ) + js_patch = [ + { + "op": "replace", + "path": "/spec/jupyterServer/resources", + "value": parsed_server_options.to_k8s_resources(config.sessions.enforce_cpu_limits), + }, + { + "op": "replace", + # NOTE: ~1 is how you escape '/' in json-patch + "path": "/metadata/annotations/renku.io~1resourceClassId", + "value": str(resource_class_id), + }, + ] + if parsed_server_options.priority_class: + js_patch.append( + { + "op": "replace", + # NOTE: ~1 is how you escape '/' in json-patch + "path": "/metadata/labels/renku.io~1quota", + "value": parsed_server_options.priority_class, + } + ) + elif server.get("metadata", {}).get("labels", {}).get("renku.io/quota"): + js_patch.append( + { + "op": "remove", + # NOTE: ~1 is how you escape '/' in json-patch + "path": "/metadata/labels/renku.io~1quota", + } + ) + new_server = config.k8s.client.patch_server( + server_name=server_name, safe_username=user.safe_username, patch=js_patch + ) + ss_patch = [ + { + "op": "replace", + "path": "/spec/template/spec/priorityClassName", + "value": parsed_server_options.priority_class, + } + ] + config.k8s.client.patch_statefulset(server_name=server_name, patch=ss_patch) if state == PatchServerStatusEnum.Hibernated.value: # NOTE: Do nothing if server is already hibernated - if server and server.get("spec", {}).get("jupyterServer", {}).get("hibernated", False): + currently_hibernated = ( + server.get("spec", {}).get("jupyterServer", {}).get("hibernated", False) + ) + if server and currently_hibernated: logging.warning(f"Server {server_name} is already hibernated.") - return NotebookResponse().dump(UserServerManifest(server)), 204 + return NotebookResponse().dump(UserServerManifest(server)), 200 hibernation = {"branch": "", "commit": "", "dirty": "", "synchronized": ""} @@ -474,7 +532,7 @@ def patch_server(user, server_name, state): }, } - server = config.k8s.client.patch_server( + new_server = config.k8s.client.patch_server( server_name=server_name, safe_username=user.safe_username, patch=patch ) elif state == PatchServerStatusEnum.Running.value: @@ -494,13 +552,11 @@ def patch_server(user, server_name, state): access_token=user.git_token, expires_at=user.git_token_expires_at ) config.k8s.client.patch_tokens(server_name, renku_tokens, gitlab_token) - server = config.k8s.client.patch_server( + new_server = config.k8s.client.patch_server( server_name=server_name, safe_username=user.safe_username, patch=patch ) - else: - raise InvalidPatchArgumentError(f"Invalid PATCH argument value: '{state}'") - return NotebookResponse().dump(UserServerManifest(server)), 204 + return NotebookResponse().dump(UserServerManifest(new_server)), 200 @bp.route("servers/", methods=["DELETE"]) diff --git a/renku_notebooks/api/schemas/server_options.py b/renku_notebooks/api/schemas/server_options.py index 599cffb72..09bfb1942 100644 --- a/renku_notebooks/api/schemas/server_options.py +++ b/renku_notebooks/api/schemas/server_options.py @@ -4,8 +4,9 @@ from marshmallow import Schema, fields, post_load from ...config import config -from ...errors.programming import ProgrammingError +from ...config.dynamic import CPUEnforcement from .custom_fields import ByteSizeField, CpuField, GpuField +from ...errors.programming import ProgrammingError @dataclass @@ -49,6 +50,7 @@ class ServerOptions: priority_class: Optional[str] = None node_affinities: List[NodeAffinity] = field(default_factory=list) tolerations: List[Toleration] = field(default_factory=list) + resource_class_id: Optional[int] = None def __post_init__(self): if self.default_url is None: @@ -149,6 +151,27 @@ def __eq__(self, other: "ServerOptions"): and self.priority_class == other.priority_class ) + def to_k8s_resources( + self, enforce_cpu_limits: CPUEnforcement = CPUEnforcement.OFF + ) -> Dict[str, Any]: + """Convert to the K8s resource requests and limits for cpu, memory and gpus.""" + cpu_request = float(self.cpu) + mem = f"{self.memory}G" if self.gigabytes else self.memory + gpu_req = self.gpu + gpu = {"nvidia.com/gpu": str(gpu_req)} if gpu_req > 0 else None + resources = { + "requests": {"memory": mem, "cpu": cpu_request}, + "limits": {"memory": mem}, + } + if enforce_cpu_limits == CPUEnforcement.LAX: + resources["limits"]["cpu"] = 3 * cpu_request + elif enforce_cpu_limits == CPUEnforcement.STRICT: + resources["limits"]["cpu"] = cpu_request + if gpu: + resources["requests"] = {**resources["requests"], **gpu} + resources["limits"] = {**resources["limits"], **gpu} + return resources + @classmethod def from_resource_class(cls, data: Dict[str, Any]) -> "ServerOptions": """Convert a CRC resource class to server options. CRC users GB for storage and memory @@ -160,6 +183,7 @@ def from_resource_class(cls, data: Dict[str, Any]) -> "ServerOptions": storage=data["default_storage"] * 1000000000, node_affinities=[NodeAffinity(**a) for a in data.get("node_affinities", [])], tolerations=[Toleration(t) for t in data.get("tolerations", [])], + resource_class_id=data.get("id"), ) @classmethod diff --git a/renku_notebooks/api/schemas/servers_patch.py b/renku_notebooks/api/schemas/servers_patch.py index dc63a626c..4fb195703 100644 --- a/renku_notebooks/api/schemas/servers_patch.py +++ b/renku_notebooks/api/schemas/servers_patch.py @@ -1,6 +1,6 @@ from enum import Enum -from marshmallow import Schema, fields, validate +from marshmallow import EXCLUDE, Schema, fields, validate class PatchServerStatusEnum(Enum): @@ -17,4 +17,9 @@ def list(cls): class PatchServerRequest(Schema): """Simple Enum for server status.""" - state = fields.String(required=True, validate=validate.OneOf(PatchServerStatusEnum.list())) + class Meta: + # passing unknown params does not error, but the params are ignored + unknown = EXCLUDE + + state = fields.String(required=False, validate=validate.OneOf(PatchServerStatusEnum.list())) + resource_class_id = fields.Int(required=False, validate=lambda x: x > 0) diff --git a/renku_notebooks/config/__init__.py b/renku_notebooks/config/__init__.py index 558030c13..1a95dbad1 100644 --- a/renku_notebooks/config/__init__.py +++ b/renku_notebooks/config/__init__.py @@ -228,7 +228,7 @@ def get_config(default_config: str) -> _NotebooksConfig: ] } ssh {} - enforce_cpu_limits: false + enforce_cpu_limits: off termination_warning_duration_seconds: 43200 image_default_workdir: /home/jovyan node_selector: "{}" diff --git a/renku_notebooks/config/dynamic.py b/renku_notebooks/config/dynamic.py index 95f1ba5b8..7bbb1623e 100644 --- a/renku_notebooks/config/dynamic.py +++ b/renku_notebooks/config/dynamic.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from enum import Enum from typing import Any, Dict, List, Optional, Text, Union import yaml @@ -24,6 +25,12 @@ def _parse_value_as_float(val: Any) -> float: return float(val) +class CPUEnforcement(str, Enum): + LAX: str = "lax" # CPU limit equals 3x cpu request + STRICT: str = "strict" # CPU limit equals cpu request + OFF: str = "off" # no CPU limit at all + + @dataclass class _ServerOptionsConfig: defaults_path: Text @@ -195,7 +202,7 @@ class _SessionConfig: containers: _SessionContainers ssh: _SessionSshConfig default_image: Text = "renku/singleuser:latest" - enforce_cpu_limits: Union[Text, bool] = False + enforce_cpu_limits: CPUEnforcement = CPUEnforcement.OFF termination_warning_duration_seconds: int = 12 * 60 * 60 image_default_workdir: Text = "/home/jovyan" node_selector: Text = "{}" diff --git a/renku_notebooks/config/static.py b/renku_notebooks/config/static.py index e6a575f76..5c9bc47fe 100644 --- a/renku_notebooks/config/static.py +++ b/renku_notebooks/config/static.py @@ -65,6 +65,7 @@ class _ServersGetEndpointAnnotations: "renku.io/username", "renku.io/git-host", "renku.io/gitlabProjectId", + "renku.io/resourceClassId", "jupyter.org/servername", "jupyter.org/username", ] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 984dedd38..622c65ee6 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -120,7 +120,7 @@ def anonymous_gitlab_client(): @pytest.fixture( scope="session", - params=[os.environ["SESSION_TYPE"]], + params=[os.environ.get("SESSION_TYPE")], ) def gitlab_client(request, anonymous_gitlab_client, registered_gitlab_client): if request.param == "registered":