Skip to content

Commit

Permalink
Add RBAC to workspaces and add default user workspace setting
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Feb 24, 2025
1 parent 42c7921 commit f820ab2
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 90 deletions.
17 changes: 17 additions & 0 deletions src/zenml/models/v2/core/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class UserBase(BaseModel):
default=None,
title="The metadata associated with the user.",
)
default_workspace_id: Optional[UUID] = Field(
default=None,
title="The default workspace ID for the user.",
)

@classmethod
def _get_crypt_context(cls) -> "CryptContext":
Expand Down Expand Up @@ -279,6 +283,10 @@ class UserResponseBody(BaseDatedResponseBody):
is_admin: bool = Field(
title="Whether the account is an administrator.",
)
default_workspace_id: Optional[UUID] = Field(
default=None,
title="The default workspace ID for the user.",
)


class UserResponseMetadata(BaseResponseMetadata):
Expand Down Expand Up @@ -395,6 +403,15 @@ def is_admin(self) -> bool:
"""
return self.get_body().is_admin

@property
def default_workspace_id(self) -> Optional[UUID]:
"""The `default_workspace_id` property.
Returns:
the value of the property.
"""
return self.get_body().default_workspace_id

@property
def email(self) -> Optional[str]:
"""The `email` property.
Expand Down
18 changes: 13 additions & 5 deletions src/zenml/zen_server/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ class ResourceType(StrEnum):
TAG = "tag"
TRIGGER = "trigger"
TRIGGER_EXECUTION = "trigger_execution"
WORKSPACE = "workspace"
# Deactivated for now
# USER = "user"
# WORKSPACE = "workspace"

def is_workspace_scoped(self) -> bool:
"""Check if a resource type is workspace scoped.
Expand All @@ -93,9 +93,9 @@ def is_workspace_scoped(self) -> bool:
self.STACK,
self.STACK_COMPONENT,
self.SERVICE_ACCOUNT,
self.WORKSPACE,
# Deactivated for now
# cls.USER,
# cls.WORKSPACE,
# self.USER,
]


Expand All @@ -112,8 +112,15 @@ def __str__(self) -> str:
Returns:
Resource string representation.
"""
if self.workspace_id:
representation = f"{self.workspace_id}:"
workspace_id = self.workspace_id
if self.type == ResourceType.WORKSPACE and self.id:
# TODO: For now, we duplicate the workspace ID in the string
# representation when describing a workspace instance, because
# this is what is expected by the RBAC implementation.
workspace_id = self.id

if workspace_id:
representation = f"{workspace_id}:"
else:
representation = ""
representation += self.type
Expand All @@ -134,6 +141,7 @@ def validate_workspace_id(self) -> "Resource":
The validated resource.
"""
resource_type = ResourceType(self.type)

if resource_type.is_workspace_scoped() and not self.workspace_id:
raise ValueError(
"workspace_id must be set for workspace-scoped resource type "
Expand Down
13 changes: 2 additions & 11 deletions src/zenml/zen_server/rbac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Page,
UserResponse,
UserScopedResponse,
WorkspaceResponse,
WorkspaceScopedRequest,
WorkspaceScopedResponse,
)
Expand Down Expand Up @@ -342,16 +343,6 @@ def get_allowed_resource_ids(
if not server_config().rbac_enabled:
return None

if ResourceType(resource_type).is_workspace_scoped() and not workspace_id:
raise ValueError(
"Workspace ID is required to list workspace scoped resources."
)
if not ResourceType(resource_type).is_workspace_scoped() and workspace_id:
raise ValueError(
"Workspace ID is not allowed to list resources that are not "
"workspace scoped."
)

auth_context = get_auth_context()
assert auth_context

Expand Down Expand Up @@ -540,7 +531,7 @@ def get_resource_type_for_model(
TriggerResponse: ResourceType.TRIGGER,
TriggerExecutionRequest: ResourceType.TRIGGER_EXECUTION,
TriggerExecutionResponse: ResourceType.TRIGGER_EXECUTION,
# WorkspaceResponse: ResourceType.WORKSPACE,
WorkspaceResponse: ResourceType.WORKSPACE,
# UserResponse: ResourceType.USER,
}

Expand Down
8 changes: 7 additions & 1 deletion src/zenml/zen_server/rbac/zenml_cloud_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from zenml.zen_server.cloud_utils import cloud_connection
from zenml.zen_server.rbac.models import Action, Resource
from zenml.zen_server.rbac.models import Action, Resource, ResourceType
from zenml.zen_server.rbac.rbac_interface import RBACInterface
from zenml.zen_server.utils import server_config

Expand Down Expand Up @@ -77,6 +77,12 @@ def _convert_from_cloud_resource(cloud_resource: str) -> Resource:
if "/" in resource_type_and_id:
resource_type, resource_id = resource_type_and_id.split("/")

if resource_type == ResourceType.WORKSPACE and workspace_id is not None:
# TODO: For now, we duplicate the workspace ID in the string
# representation when describing a workspace instance, because
# this is what is expected by the RBAC implementation.
workspace_id = None

return Resource(
type=resource_type, id=resource_id, workspace_id=workspace_id
)
Expand Down
38 changes: 37 additions & 1 deletion src/zenml/zen_server/routers/users_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
)
from zenml.zen_server.exceptions import error_response
from zenml.zen_server.rate_limit import RequestLimiter
from zenml.zen_server.rbac.endpoint_utils import (
verify_permissions_and_get_entity,
)
from zenml.zen_server.rbac.models import Action, Resource, ResourceType
from zenml.zen_server.rbac.utils import (
dehydrate_page,
Expand Down Expand Up @@ -751,6 +754,7 @@ def update_user_resource_membership(
"Not allowed to call endpoint with the authenticated user."
)

resource_type = ResourceType(resource_type)
schema_class = get_schema_for_resource_type(resource_type)
model = zen_store().get_entity_by_id(
entity_id=resource_id, schema_class=schema_class
Expand All @@ -762,7 +766,6 @@ def update_user_resource_membership(
"not exist."
)

resource_type = ResourceType(resource_type)
workspace_id = None
if isinstance(model, WorkspaceScopedResponse):
workspace_id = model.workspace.id
Expand All @@ -782,3 +785,36 @@ def update_user_resource_membership(
resource=resource,
actions=[Action(action) for action in actions],
)


@current_user_router.put(
"/default-workspace",
response_model=UserResponse,
responses={
401: error_response,
404: error_response,
422: error_response,
},
)
@handle_exceptions
def update_user_default_workspace(
workspace_name_or_id: Union[str, UUID],
auth_context: AuthContext = Security(authorize),
) -> UserResponse:
"""Updates the default workspace of the current user.
Args:
workspace_name_or_id: Name or ID of the workspace.
auth_context: Authentication context.
"""
workspace = verify_permissions_and_get_entity(
id=workspace_name_or_id,
get_method=zen_store().get_workspace,
)

user = zen_store().update_user(
user_id=auth_context.user.id,
user_update=UserUpdate(default_workspace_id=workspace.id),
)

return dehydrate_response_model(user)
50 changes: 34 additions & 16 deletions src/zenml/zen_server/routers/workspaces_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@
)
from zenml.zen_server.auth import AuthContext, authorize
from zenml.zen_server.exceptions import error_response
from zenml.zen_server.rbac.endpoint_utils import (
verify_permissions_and_create_entity,
verify_permissions_and_delete_entity,
verify_permissions_and_get_entity,
verify_permissions_and_list_entities,
verify_permissions_and_update_entity,
)
from zenml.zen_server.rbac.models import ResourceType
from zenml.zen_server.rbac.utils import (
dehydrate_page,
dehydrate_response_model,
get_allowed_resource_ids,
)
from zenml.zen_server.utils import (
Expand Down Expand Up @@ -80,10 +85,12 @@ def list_workspaces(
Returns:
A list of workspaces.
"""
workspaces = zen_store().list_workspaces(
workspace_filter_model, hydrate=hydrate
return verify_permissions_and_list_entities(
filter_model=workspace_filter_model,
resource_type=ResourceType.WORKSPACE,
list_method=zen_store().list_workspaces,
hydrate=hydrate,
)
return dehydrate_page(workspaces)


@router.post(
Expand All @@ -105,8 +112,10 @@ def create_workspace(
Returns:
The created workspace.
"""
workspace = zen_store().create_workspace(workspace_request)
return dehydrate_response_model(workspace)
return verify_permissions_and_create_entity(
request_model=workspace_request,
create_method=zen_store().create_workspace,
)


@router.get(
Expand All @@ -132,10 +141,11 @@ def get_workspace(
Returns:
The requested workspace.
"""
workspace = zen_store().get_workspace(
workspace_name_or_id, hydrate=hydrate
return verify_permissions_and_get_entity(
id=workspace_name_or_id,
get_method=zen_store().get_workspace,
hydrate=hydrate,
)
return dehydrate_response_model(workspace)


@router.put(
Expand All @@ -159,11 +169,12 @@ def update_workspace(
Returns:
The updated workspace.
"""
workspace = zen_store().get_workspace(workspace_name_or_id, hydrate=False)
updated_workspace = zen_store().update_workspace(
workspace_id=workspace.id, workspace_update=workspace_update
return verify_permissions_and_update_entity(
id=workspace_name_or_id,
update_model=workspace_update,
get_method=zen_store().get_workspace,
update_method=zen_store().update_workspace,
)
return dehydrate_response_model(updated_workspace)


@router.delete(
Expand All @@ -180,7 +191,11 @@ def delete_workspace(
Args:
workspace_name_or_id: Name or ID of the workspace.
"""
zen_store().delete_workspace(workspace_name_or_id)
verify_permissions_and_delete_entity(
id=workspace_name_or_id,
get_method=zen_store().get_workspace,
delete_method=zen_store().delete_workspace,
)


@router.get(
Expand All @@ -204,7 +219,10 @@ def get_workspace_statistics(
Returns:
All pipelines within the workspace.
"""
workspace = zen_store().get_workspace(workspace_name_or_id)
workspace = verify_permissions_and_get_entity(
id=workspace_name_or_id,
get_method=zen_store().get_workspace,
)

user_id = auth_context.user.id
component_filter = ComponentFilter(workspace_id=workspace.id)
Expand Down
16 changes: 14 additions & 2 deletions src/zenml/zen_stores/schemas/user_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
)
from zenml.utils.time_utils import utc_now
from zenml.zen_stores.schemas.base_schemas import NamedSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema

if TYPE_CHECKING:
from zenml.zen_stores.schemas import (
Expand Down Expand Up @@ -80,6 +82,14 @@ class UserSchema(NamedSchema, table=True):
external_user_id: Optional[UUID] = Field(nullable=True)
is_admin: bool = Field(default=False)
user_metadata: Optional[str] = Field(nullable=True)
default_workspace_id: Optional[UUID] = build_foreign_key_field(
source=__tablename__,
target=WorkspaceSchema.__tablename__,
source_column="default_workspace_id",
target_column="id",
ondelete="SET NULL",
nullable=True,
)

stacks: List["StackSchema"] = Relationship(back_populates="user")
components: List["StackComponentSchema"] = Relationship(
Expand Down Expand Up @@ -179,6 +189,7 @@ def from_user_request(cls, model: UserRequest) -> "UserSchema":
user_metadata=json.dumps(model.user_metadata)
if model.user_metadata
else None,
default_workspace_id=model.default_workspace_id,
)

@classmethod
Expand Down Expand Up @@ -266,8 +277,8 @@ def to_model(
include_resources: Whether the resources will be filled.
**kwargs: Keyword arguments to allow schema specific logic
include_private: Whether to include the user private information
this is to limit the amount of data one can get
about other users
this is to limit the amount of data one can get about other
users.
Returns:
The converted `UserResponse`.
Expand All @@ -293,6 +304,7 @@ def to_model(
created=self.created,
updated=self.updated,
is_admin=self.is_admin,
default_workspace_id=self.default_workspace_id,
),
metadata=metadata,
)
Expand Down
Loading

0 comments on commit f820ab2

Please sign in to comment.