Skip to content

Commit

Permalink
feat: allow patching of resource class (#1728)
Browse files Browse the repository at this point in the history
Co-authored-by: Ralf Grubenmann <ralf.grubenmann@sdsc.ethz.ch>
  • Loading branch information
olevski and Panaetius authored Jan 16, 2024
1 parent 97e3a50 commit 5219719
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 39 deletions.
44 changes: 42 additions & 2 deletions renku_notebooks/api/classes/k8s_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def __init__(
except ConfigException:
load_config()
self._custom_objects = client.CustomObjectsApi(client.ApiClient())
self._custom_objects_patch = client.CustomObjectsApi(client.ApiClient())
self._custom_objects_patch.api_client.set_default_header(
"Content-Type", "application/json-patch+json"
)
self._core_v1 = client.CoreV1Api()
self._apps_v1 = client.AppsV1Api()

Expand Down Expand Up @@ -120,22 +124,49 @@ def create_server(self, manifest: Dict[str, Any]) -> Dict[str, Any]:
server = retry_with_exponential_backoff(lambda x: x is None)(self.get_server)(server_name)
return server

def patch_server(self, server_name: str, patch: Dict[str, Any]):
def patch_server(self, server_name: str, patch: Dict[str, Any] | List[Dict[str, Any]]):
try:
server = self._custom_objects.patch_namespaced_custom_object(
if isinstance(patch, list):
# NOTE: The _custom_objects_patch will only accept rfc6902 json-patch.
# We can recognize the type of patch because this is the only one that uses a list
client = self._custom_objects_patch
else:
# NOTE: The _custom_objects will accept the usual rfc7386 merge patches
client = self._custom_objects

server = client.patch_namespaced_custom_object(
group=self.amalthea_group,
version=self.amalthea_version,
namespace=self.namespace,
plural=self.amalthea_plural,
name=server_name,
body=patch,
)

except ApiException as e:
logging.exception(f"Cannot patch server {server_name} because of {e}")
raise PatchServerError()

return server

def patch_statefulset(
    self, server_name: str, patch: Dict[str, Any] | List[Dict[str, Any]] | client.V1StatefulSet
) -> client.V1StatefulSet | None:
    """Patch the statefulset named after the server in this client's namespace.

    Returns the patched statefulset, or None when the API answers with a 404.
    Any other ApiException is propagated unchanged to the caller.
    """
    try:
        patched = self._apps_v1.patch_namespaced_stateful_set(
            server_name,
            self.namespace,
            patch,
        )
    except ApiException as err:
        if err.status != 404:
            raise
        # NOTE: It can happen potentially that another request or something else
        # deleted the session as this request was going on, in this case we ignore
        # the missing statefulset
        return None
    else:
        return patched

def delete_server(self, server_name: str, forced: bool = False):
try:
status = self._custom_objects.delete_namespaced_custom_object(
Expand Down Expand Up @@ -479,6 +510,15 @@ def patch_server(self, server_name: str, safe_username: str, patch: Dict[str, An
else:
return self.session_ns_client.patch_server(server_name=server_name, patch=patch)

def patch_statefulset(
    self, server_name: str, patch: Dict[str, Any] | List[Dict[str, Any]] | client.V1StatefulSet
) -> client.V1StatefulSet | None:
    """Patch the statefulset of a session through the matching namespaced client.

    The session namespace client is used when one is configured, otherwise the
    call goes to the renku namespace client.

    Args:
        server_name: Name of the server whose statefulset should be patched.
        patch: The patch body, forwarded unchanged to the namespaced client,
            which accepts a dictionary, a list of operations or a V1StatefulSet.

    Returns:
        The value returned by the namespaced client's ``patch_statefulset``.
    """
    # Use a name that does not shadow the imported ``client`` module.
    ns_client = self.session_ns_client or self.renku_ns_client
    return ns_client.patch_statefulset(server_name=server_name, patch=patch)

def delete_server(self, server_name: str, safe_username: str, forced: bool = False):
server = self.get_server(server_name, safe_username)
if not server:
Expand Down
25 changes: 5 additions & 20 deletions renku_notebooks/api/classes/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,25 +151,6 @@ def _commit_sha_exists(self):
return True
return False

def _get_session_k8s_resources(self):
    """Build the k8s resource requests and limits from the session's server options.

    Memory is both requested and limited to the same value. A CPU limit is only
    added when ``config.sessions.enforce_cpu_limits`` is "lax" (3x the request)
    or "strict" (equal to the request). GPUs, when requested, are added to both
    the requests and the limits.
    """
    cpu = float(self.server_options.cpu)
    memory = self.server_options.memory
    gpu_count = self.server_options.gpu
    # An empty mapping means there is nothing to merge when no GPU is requested.
    gpu = {"nvidia.com/gpu": str(gpu_count)} if gpu_count > 0 else {}

    requests = {"memory": memory, "cpu": cpu}
    limits = {"memory": memory}

    cpu_enforcement = config.sessions.enforce_cpu_limits
    if cpu_enforcement == "lax":
        # In lax mode the session may burst up to 3 times its CPU request.
        limits["cpu"] = 3 * cpu
    elif cpu_enforcement == "strict":
        limits["cpu"] = cpu

    requests.update(gpu)
    limits.update(gpu)
    return {"requests": requests, "limits": limits}

def _get_session_manifest(self):
"""Compose the body of the user session for the k8s operator"""
patches = list(
Expand Down Expand Up @@ -266,7 +247,9 @@ def _get_session_manifest(self):
"defaultUrl": self.server_options.default_url,
"image": self.image,
"rootDir": self.work_dir.absolute().as_posix(),
"resources": self._get_session_k8s_resources(),
"resources": self.server_options.to_k8s_resources(
enforce_cpu_limits=config.sessions.enforce_cpu_limits
),
},
"routing": {
"host": urlparse(self.server_url).netloc,
Expand Down Expand Up @@ -375,6 +358,8 @@ def get_annotations(self):
f"{prefix}lastActivityDate": "",
f"{prefix}idleSecondsThreshold": str(self.idle_seconds_threshold),
}
if self.server_options.resource_class_id:
annotations[f"{prefix}resourceClassId"] = str(self.server_options.resource_class_id)
if self.gl_project is not None:
annotations[f"{prefix}gitlabProjectId"] = str(self.gl_project.id)
annotations[f"{prefix}repository"] = self.gl_project.web_url
Expand Down
78 changes: 67 additions & 11 deletions renku_notebooks/api/notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .classes.auth import GitlabToken, RenkuTokens
from ..errors.intermittent import AnonymousUserPatchError, PVDisabledError
from ..errors.programming import ProgrammingError
from ..errors.user import InvalidPatchArgumentError, MissingResourceError, UserInputError
from ..errors.user import MissingResourceError, UserInputError
from ..util.kubernetes_ import make_server_name
from .auth import authenticated
from .classes.image import Image
Expand Down Expand Up @@ -384,9 +384,9 @@ def launch_notebook(

@bp.route("servers/<server_name>", methods=["PATCH"])
@use_args({"server_name": fields.Str(required=True)}, location="view_args", as_kwargs=True)
@use_args(PatchServerRequest(), location="json", as_kwargs=True)
@use_args(PatchServerRequest(), location="json", arg_name="patch_body")
@authenticated
def patch_server(user, server_name, state):
def patch_server(user, server_name, patch_body):
"""
Patch a user server by name based on the query param.
Expand All @@ -405,7 +405,7 @@ def patch_server(user, server_name, state):
required: true
description: The name of the server that should be patched.
responses:
204:
200:
description: The server was patched successfully.
content:
application/json:
Expand Down Expand Up @@ -435,13 +435,71 @@ def patch_server(user, server_name, state):
raise AnonymousUserPatchError()

server = config.k8s.client.get_server(server_name, user.safe_username)
new_server = server
currently_hibernated = server.get("spec", {}).get("jupyterServer", {}).get("hibernated", False)
currently_failing = server.get("status", {}).get("state", "running") == "failed"
state = patch_body.get("state")
resource_class_id = patch_body.get("resource_class_id")
if server and not (currently_hibernated or currently_failing) and resource_class_id:
raise UserInputError(
"The resource class can be changed only if the server is hibernated or failing"
)

if resource_class_id:
parsed_server_options = config.crc_validator.validate_class_storage(
user, resource_class_id, storage=None # we do not care about validating storage
)
js_patch = [
{
"op": "replace",
"path": "/spec/jupyterServer/resources",
"value": parsed_server_options.to_k8s_resources(config.sessions.enforce_cpu_limits),
},
{
"op": "replace",
# NOTE: ~1 is how you escape '/' in json-patch
"path": "/metadata/annotations/renku.io~1resourceClassId",
"value": str(resource_class_id),
},
]
if parsed_server_options.priority_class:
js_patch.append(
{
"op": "replace",
# NOTE: ~1 is how you escape '/' in json-patch
"path": "/metadata/labels/renku.io~1quota",
"value": parsed_server_options.priority_class,
}
)
elif server.get("metadata", {}).get("labels", {}).get("renku.io/quota"):
js_patch.append(
{
"op": "remove",
# NOTE: ~1 is how you escape '/' in json-patch
"path": "/metadata/labels/renku.io~1quota",
}
)
new_server = config.k8s.client.patch_server(
server_name=server_name, safe_username=user.safe_username, patch=js_patch
)
ss_patch = [
{
"op": "replace",
"path": "/spec/template/spec/priorityClassName",
"value": parsed_server_options.priority_class,
}
]
config.k8s.client.patch_statefulset(server_name=server_name, patch=ss_patch)

if state == PatchServerStatusEnum.Hibernated.value:
# NOTE: Do nothing if server is already hibernated
if server and server.get("spec", {}).get("jupyterServer", {}).get("hibernated", False):
currently_hibernated = (
server.get("spec", {}).get("jupyterServer", {}).get("hibernated", False)
)
if server and currently_hibernated:
logging.warning(f"Server {server_name} is already hibernated.")

return NotebookResponse().dump(UserServerManifest(server)), 204
return NotebookResponse().dump(UserServerManifest(server)), 200

hibernation = {"branch": "", "commit": "", "dirty": "", "synchronized": ""}

Expand Down Expand Up @@ -474,7 +532,7 @@ def patch_server(user, server_name, state):
},
}

server = config.k8s.client.patch_server(
new_server = config.k8s.client.patch_server(
server_name=server_name, safe_username=user.safe_username, patch=patch
)
elif state == PatchServerStatusEnum.Running.value:
Expand All @@ -494,13 +552,11 @@ def patch_server(user, server_name, state):
access_token=user.git_token, expires_at=user.git_token_expires_at
)
config.k8s.client.patch_tokens(server_name, renku_tokens, gitlab_token)
server = config.k8s.client.patch_server(
new_server = config.k8s.client.patch_server(
server_name=server_name, safe_username=user.safe_username, patch=patch
)
else:
raise InvalidPatchArgumentError(f"Invalid PATCH argument value: '{state}'")

return NotebookResponse().dump(UserServerManifest(server)), 204
return NotebookResponse().dump(UserServerManifest(new_server)), 200


@bp.route("servers/<server_name>", methods=["DELETE"])
Expand Down
26 changes: 25 additions & 1 deletion renku_notebooks/api/schemas/server_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from marshmallow import Schema, fields, post_load

from ...config import config
from ...errors.programming import ProgrammingError
from ...config.dynamic import CPUEnforcement
from .custom_fields import ByteSizeField, CpuField, GpuField
from ...errors.programming import ProgrammingError


@dataclass
Expand Down Expand Up @@ -49,6 +50,7 @@ class ServerOptions:
priority_class: Optional[str] = None
node_affinities: List[NodeAffinity] = field(default_factory=list)
tolerations: List[Toleration] = field(default_factory=list)
resource_class_id: Optional[int] = None

def __post_init__(self):
if self.default_url is None:
Expand Down Expand Up @@ -149,6 +151,27 @@ def __eq__(self, other: "ServerOptions"):
and self.priority_class == other.priority_class
)

def to_k8s_resources(
    self, enforce_cpu_limits: CPUEnforcement = CPUEnforcement.OFF
) -> Dict[str, Any]:
    """Return the K8s resource requests and limits for cpu, memory and gpus.

    Memory gets a "G" suffix when ``self.gigabytes`` is set and is used for both
    the request and the limit. A CPU limit is added only for the LAX (3x the
    request) and STRICT (equal to the request) enforcement modes. A positive
    GPU count is added to both the requests and the limits.
    """
    cpu = float(self.cpu)
    memory = f"{self.memory}G" if self.gigabytes else self.memory
    gpu_count = self.gpu
    # An empty mapping means there is nothing to merge when no GPU is requested.
    gpu = {"nvidia.com/gpu": str(gpu_count)} if gpu_count > 0 else {}

    requests = {"memory": memory, "cpu": cpu}
    limits = {"memory": memory}

    if enforce_cpu_limits == CPUEnforcement.LAX:
        limits["cpu"] = 3 * cpu
    elif enforce_cpu_limits == CPUEnforcement.STRICT:
        limits["cpu"] = cpu

    requests.update(gpu)
    limits.update(gpu)
    return {"requests": requests, "limits": limits}

@classmethod
def from_resource_class(cls, data: Dict[str, Any]) -> "ServerOptions":
"""Convert a CRC resource class to server options. CRC users GB for storage and memory
Expand All @@ -160,6 +183,7 @@ def from_resource_class(cls, data: Dict[str, Any]) -> "ServerOptions":
storage=data["default_storage"] * 1000000000,
node_affinities=[NodeAffinity(**a) for a in data.get("node_affinities", [])],
tolerations=[Toleration(t) for t in data.get("tolerations", [])],
resource_class_id=data.get("id"),
)

@classmethod
Expand Down
9 changes: 7 additions & 2 deletions renku_notebooks/api/schemas/servers_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum

from marshmallow import Schema, fields, validate
from marshmallow import EXCLUDE, Schema, fields, validate


class PatchServerStatusEnum(Enum):
Expand All @@ -17,4 +17,9 @@ def list(cls):
class PatchServerRequest(Schema):
    """Schema for the JSON body of a request that patches a server.

    Both fields are optional: ``state`` must be one of the values listed by
    ``PatchServerStatusEnum`` and ``resource_class_id`` must be a positive integer.
    """

    class Meta:
        # passing unknown params does not error, but the params are ignored
        unknown = EXCLUDE

    state = fields.String(required=False, validate=validate.OneOf(PatchServerStatusEnum.list()))
    # Range raises a ValidationError itself instead of relying on a falsy return value
    resource_class_id = fields.Int(required=False, validate=validate.Range(min=1))
2 changes: 1 addition & 1 deletion renku_notebooks/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def get_config(default_config: str) -> _NotebooksConfig:
]
}
ssh {}
enforce_cpu_limits: false
enforce_cpu_limits: off
termination_warning_duration_seconds: 43200
image_default_workdir: /home/jovyan
node_selector: "{}"
Expand Down
9 changes: 8 additions & 1 deletion renku_notebooks/config/dynamic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Text, Union

import yaml
Expand All @@ -24,6 +25,12 @@ def _parse_value_as_float(val: Any) -> float:
return float(val)


class CPUEnforcement(str, Enum):
    """Modes for deriving a session's CPU limit from its CPU request.

    The members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # NOTE: enum members are deliberately left unannotated so that type
    # checkers infer them as members of the enum rather than plain strings.
    LAX = "lax"  # CPU limit equals 3x cpu request
    STRICT = "strict"  # CPU limit equals cpu request
    OFF = "off"  # no CPU limit at all


@dataclass
class _ServerOptionsConfig:
defaults_path: Text
Expand Down Expand Up @@ -195,7 +202,7 @@ class _SessionConfig:
containers: _SessionContainers
ssh: _SessionSshConfig
default_image: Text = "renku/singleuser:latest"
enforce_cpu_limits: Union[Text, bool] = False
enforce_cpu_limits: CPUEnforcement = CPUEnforcement.OFF
termination_warning_duration_seconds: int = 12 * 60 * 60
image_default_workdir: Text = "/home/jovyan"
node_selector: Text = "{}"
Expand Down
1 change: 1 addition & 0 deletions renku_notebooks/config/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class _ServersGetEndpointAnnotations:
"renku.io/username",
"renku.io/git-host",
"renku.io/gitlabProjectId",
"renku.io/resourceClassId",
"jupyter.org/servername",
"jupyter.org/username",
]
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def anonymous_gitlab_client():

@pytest.fixture(
scope="session",
params=[os.environ["SESSION_TYPE"]],
params=[os.environ.get("SESSION_TYPE")],
)
def gitlab_client(request, anonymous_gitlab_client, registered_gitlab_client):
if request.param == "registered":
Expand Down

0 comments on commit 5219719

Please sign in to comment.