Skip to content

Commit

Permalink
Implement RBAC scoping for workspaces
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Feb 18, 2025
1 parent e3efd14 commit cbc0ba7
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 23 deletions.
8 changes: 8 additions & 0 deletions src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
BaseDatedResponseBody,
)
from zenml.models.v2.base.scoped import (
FlexibleScopedFilter,
FlexibleScopedRequest,
FlexibleScopedResponse,
FlexibleScopedUpdate,
TaggableFilter,
UserScopedRequest,
UserScopedFilter,
Expand Down Expand Up @@ -490,6 +494,10 @@
"BaseDatedResponseBody",
"BaseZenModel",
"BasePluginFlavorResponse",
"FlexibleScopedFilter",
"FlexibleScopedRequest",
"FlexibleScopedResponse",
"FlexibleScopedUpdate",
"UserScopedRequest",
"UserScopedFilter",
"UserScopedResponse",
Expand Down
47 changes: 43 additions & 4 deletions src/zenml/zen_server/rbac/endpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""High-level helper functions to write endpoints with RBAC."""

from typing import Any, Callable, List, TypeVar, Union
from typing import Any, Callable, List, Optional, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel
Expand All @@ -25,8 +25,11 @@
BaseFilter,
BaseIdentifiedResponse,
BaseRequest,
FlexibleScopedFilter,
Page,
UserScopedRequest,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
)
from zenml.zen_server.auth import get_auth_context
from zenml.zen_server.feature_gate.endpoint_utils import (
Expand Down Expand Up @@ -70,10 +73,19 @@ def verify_permissions_and_create_entity(
assert auth_context

# Ignore the user field set in the request model, if any, and set it to
# the current user's ID instead.
# the current user's ID instead. This prevents the current user from
# being able to create entities on behalf of other users.
request_model.user = auth_context.user.id

verify_permission(resource_type=resource_type, action=Action.CREATE)
if isinstance(request_model, WorkspaceScopedRequest):
# A workspace scoped request is always scoped to a specific workspace
workspace_id = request_model.workspace

verify_permission(
resource_type=resource_type,
action=Action.CREATE,
workspace_id=workspace_id,
)

needs_usage_increment = (
resource_type in server_config().reportable_resources
Expand Down Expand Up @@ -111,13 +123,25 @@ def verify_permissions_and_batch_create_entity(
auth_context = get_auth_context()
assert auth_context

workspace_ids = set()
for request_model in batch:
if isinstance(request_model, UserScopedRequest):
# Ignore the user field set in the request model, if any, and set it to
# the current user's ID instead.
request_model.user = auth_context.user.id

verify_permission(resource_type=resource_type, action=Action.CREATE)
if isinstance(request_model, WorkspaceScopedRequest):
# A workspace scoped request is always scoped to a specific workspace
workspace_ids.add(request_model.workspace)
else:
workspace_ids.add(None)

for workspace_id in workspace_ids:
verify_permission(
resource_type=resource_type,
action=Action.CREATE,
workspace_id=workspace_id,
)

if resource_type in server_config().reportable_resources:
raise RuntimeError(
Expand Down Expand Up @@ -164,10 +188,25 @@ def verify_permissions_and_list_entities(
Returns:
A page of entity models.
Raises:
ValueError: If the workspace ID is not set for workspace-scoped resources.
"""
auth_context = get_auth_context()
assert auth_context

workspace_id: Optional[UUID] = None
if isinstance(filter_model, WorkspaceScopedFilter):
# A workspace scoped request is always scoped to a specific workspace
workspace_id = filter_model.workspace
if workspace_id is None:
raise ValueError("Workspace ID is required for workspace-scoped resources.")

elif isinstance(filter_model, FlexibleScopedFilter):
# A flexible scoped request is always scoped to a specific workspace
workspace_id = filter_model.workspace


allowed_ids = get_allowed_resource_ids(resource_type=resource_type)
filter_model.configure_rbac(
authenticated_user_id=auth_context.user.id, id=allowed_ids
Expand Down
87 changes: 84 additions & 3 deletions src/zenml/zen_server/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@
# permissions and limitations under the License.
"""RBAC model classes."""

from typing import Optional
from typing import Any, Dict, Optional
from uuid import UUID

from pydantic import BaseModel, ConfigDict
from pydantic import (
BaseModel,
ConfigDict,
ValidationInfo,
field_validator,
model_validator,
)

from zenml.utils.enum_utils import StrEnum

Expand Down Expand Up @@ -73,23 +79,98 @@ class ResourceType(StrEnum):
# USER = "user"
# WORKSPACE = "workspace"

def is_flexible_scoped(self) -> bool:
"""Check if a resource type may flexibly be scoped to a workspace.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type may flexibly be scoped to a workspace.
"""
return self in [
self.FLAVOR,
self.SECRET,
self.SERVICE_CONNECTOR,
self.STACK,
self.STACK_COMPONENT,
]

def is_workspace_scoped(self) -> bool:
"""Check if a resource type is workspace scoped.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type is workspace scoped.
"""
return not self.is_flexible_scoped() and not self.is_unscoped()

def is_unscoped(self) -> bool:
"""Check if a resource type is unscoped.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type is unscoped.
"""
return self in [
self.SERVICE_ACCOUNT,
# Deactivated for now
# cls.USER,
# cls.WORKSPACE,
]


class Resource(BaseModel):
"""RBAC resource model."""

type: str
id: Optional[UUID] = None
workspace_id: Optional[UUID] = None

def __str__(self) -> str:
"""Convert to a string.
Returns:
Resource string representation.
"""
representation = self.type
if self.workspace_id:
representation = f"{self.workspace_id}:"
else:
representation = ""
representation += self.type
if self.id:
representation += f"/{self.id}"

return representation

@model_validator(mode="after")
def validate_workspace_id(self) -> "Resource":
"""Validate that workspace_id is set in combination with the correct resource types.
Raises:
ValueError: If workspace_id is not set for a workspace-scoped
resource or set for an unscoped resource.
Returns:
The validated resource.
"""
resource_type = ResourceType(self.type)
if resource_type.is_workspace_scoped() and not self.workspace_id:
raise ValueError(
"workspace_id must be set for workspace-scoped resource type "
f"'{self.type}'"
)

if resource_type.is_unscoped() and self.workspace_id:
raise ValueError(
"workspace_id must not be set for unscoped resource type "
f"'{self.type}'"
)

return self

model_config = ConfigDict(frozen=True)
33 changes: 31 additions & 2 deletions src/zenml/zen_server/rbac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
Page,
UserResponse,
UserScopedResponse,
FlexibleScopedResponse,
WorkspaceScopedResponse,
)
from zenml.zen_server.auth import get_auth_context
from zenml.zen_server.rbac.models import Action, Resource, ResourceType
Expand Down Expand Up @@ -283,6 +285,7 @@ def verify_permission(
resource_type: str,
action: Action,
resource_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None,
) -> None:
"""Verifies if a user has permission to perform an action on a resource.
Expand All @@ -291,8 +294,12 @@ def verify_permission(
action on.
action: The action the user wants to perform.
resource_id: ID of the resource the user wants to perform the action on.
workspace_id: ID of the workspace the user wants to perform the action
on. Only used for workspace scoped resources.
"""
resource = Resource(type=resource_type, id=resource_id)
resource = Resource(
type=resource_type, id=resource_id, workspace_id=workspace_id
)
batch_verify_permissions(resources={resource}, action=action)


Expand Down Expand Up @@ -346,7 +353,11 @@ def get_resource_for_model(model: AnyResponse) -> Optional[Resource]:
# This model is not tied to any RBAC resource type
return None

return Resource(type=resource_type, id=model.id)
workspace_id: Optional[UUID] = None
if isinstance(model, WorkspaceScopedResponse):
workspace_id = model.workspace.id

return Resource(type=resource_type, id=model.id, workspace_id=workspace_id)


def get_surrogate_permission_model_for_model(
Expand Down Expand Up @@ -448,6 +459,24 @@ def get_resource_type_for_model(
return mapping.get(type(model))


def is_resource_type_workspace_scoped(resource_type: ResourceType) -> bool:
"""Check if a resource type is workspace scoped.
Args:
resource_type: The resource type to check.
Returns:
Whether the resource type is workspace scoped.
"""
return resource_type in [
ResourceType.STACK,
ResourceType.PIPELINE,
ResourceType.CODE_REPOSITORY,
ResourceType.SECRET,
ResourceType.MODEL,
]


def is_owned_by_authenticated_user(model: AnyResponse) -> bool:
"""Returns whether the currently authenticated user owns the model.
Expand Down
30 changes: 19 additions & 11 deletions src/zenml/zen_server/rbac/zenml_cloud_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""Cloud RBAC implementation."""

from typing import TYPE_CHECKING, Dict, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from zenml.zen_server.cloud_utils import cloud_connection
from zenml.zen_server.rbac.models import Action, Resource
Expand Down Expand Up @@ -42,12 +42,7 @@ def _convert_to_cloud_resource(resource: Resource) -> str:
Returns:
The converted resource.
"""
resource_string = f"{SERVER_ID}@{SERVER_SCOPE_IDENTIFIER}:{resource.type}"

if resource.id:
resource_string += f"/{resource.id}"

return resource_string
return f"{SERVER_ID}@{SERVER_SCOPE_IDENTIFIER}:{resource}"


def _convert_from_cloud_resource(cloud_resource: str) -> Resource:
Expand All @@ -62,16 +57,29 @@ def _convert_from_cloud_resource(cloud_resource: str) -> Resource:
Returns:
The converted resource.
"""
scope, resource_type_and_id = cloud_resource.rsplit(":", maxsplit=1)
scope, workspace_resource_type_and_id = cloud_resource.rsplit(
":", maxsplit=1
)

if scope != f"{SERVER_ID}@{SERVER_SCOPE_IDENTIFIER}":
raise ValueError("Invalid scope for server resource.")

workspace_id: Optional[str] = None
if ":" in workspace_resource_type_and_id:
workspace_id, resource_type_and_id = (
workspace_resource_type_and_id.split(":", maxsplit=1)
)
else:
workspace_id = None
resource_type_and_id = workspace_resource_type_and_id

resource_id: Optional[str] = None
if "/" in resource_type_and_id:
resource_type, resource_id = resource_type_and_id.split("/")
return Resource(type=resource_type, id=resource_id)
else:
return Resource(type=resource_type_and_id)

return Resource(
type=resource_type, id=resource_id, workspace_id=workspace_id
)


class ZenMLCloudRBAC(RBACInterface):
Expand Down
6 changes: 3 additions & 3 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11372,7 +11372,7 @@ def _attach_tags_to_resource_new(
resource_type = self._get_taggable_resource_type(
resource=resource
)
self.create_tag_resource(
self._create_tag_resource(
TagResourceRequest(
tag_id=tag.id,
resource_id=resource.id,
Expand Down Expand Up @@ -11401,7 +11401,7 @@ def _attach_tags_to_resource(
except KeyError:
tag = self.create_tag(TagRequest(name=tag_name))
try:
self.create_tag_resource(
self._create_tag_resource(
TagResourceRequest(
tag_id=tag.id,
resource_id=resource_id,
Expand Down Expand Up @@ -11583,7 +11583,7 @@ def update_tag(
# Tags <> resources
####################

def create_tag_resource(
def _create_tag_resource(
self, tag_resource: TagResourceRequest
) -> TagResourceResponse:
"""Creates a new tag resource relationship.
Expand Down

0 comments on commit cbc0ba7

Please sign in to comment.