Skip to content

Commit

Permalink
refactor: Updated to accommodate member update
Browse files Browse the repository at this point in the history
  • Loading branch information
mike-pisman committed Oct 22, 2023
1 parent 8199204 commit 73e4017
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 53 deletions.
9 changes: 7 additions & 2 deletions src/unipoll_api/actions/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from unipoll_api.schemas import GroupSchemas, WorkspaceSchemas
from unipoll_api.exceptions import GroupExceptions, WorkspaceExceptions, ResourceExceptions
from unipoll_api.utils import Permissions
from unipoll_api.dependencies import get_member


# Get list of groups
Expand All @@ -22,7 +23,7 @@ async def get_groups(workspace: Workspace | None = None,
if workspace:
search_filter['workspace._id'] = workspace.id # type: ignore
if account:
search_filter['members._id'] = account.id # type: ignore
search_filter['members.account._id'] = account.id # type: ignore
search_result = await Group.find(search_filter, fetch_links=True).to_list()

# TODO: Rewrite to iterate over list of workspaces
Expand All @@ -47,6 +48,10 @@ async def create_group(workspace: Workspace,
await Permissions.check_permissions(workspace, "add_groups", check_permissions)
account = AccountManager.active_user.get()



member = await get_member(account, workspace)

# Check if group name is unique
group: Group # For type hinting, until Link type is supported
for group in workspace.groups: # type: ignore
Expand All @@ -63,7 +68,7 @@ async def create_group(workspace: Workspace,
raise GroupExceptions.ErrorWhileCreating(new_group)

# Add the account to group member list
await new_group.add_member(account, Permissions.GROUP_ALL_PERMISSIONS)
await new_group.add_member(member, Permissions.GROUP_ALL_PERMISSIONS)

# Create a policy for the new group
await workspace.add_policy(new_group, Permissions.WORKSPACE_BASIC_PERMISSIONS, False)
Expand Down
33 changes: 26 additions & 7 deletions src/unipoll_api/actions/members.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from beanie import WriteRules
from beanie.operators import In
from unipoll_api.documents import Account, Group, ResourceID, Workspace
from unipoll_api.documents import Account, Group, ResourceID, Workspace, Member
from unipoll_api.utils import Permissions
from unipoll_api.schemas import MemberSchemas
# from unipoll_api import AccountManager
from unipoll_api.exceptions import ResourceExceptions
from unipoll_api.dependencies import get_member


async def get_members(resource: Workspace | Group, check_permissions: bool = True) -> MemberSchemas.MemberList:
# Check if the user has permission to add members
await Permissions.check_permissions(resource, "get_members", check_permissions)

def build_member_scheme(member: Account) -> MemberSchemas.Member:
member_data = member.model_dump(include={'id', 'first_name', 'last_name', 'email'})
member_scheme = MemberSchemas.Member(**member_data)
return member_scheme
def build_member_scheme(member: Member) -> MemberSchemas.Member:
return MemberSchemas.Member(id=member.id,
account_id=member.account.id,
first_name=member.account.first_name,
last_name=member.account.last_name,
email=member.account.email)

member_list = [build_member_scheme(member) for member in resource.members] # type: ignore
# Return the list of members
Expand All @@ -36,13 +39,29 @@ async def add_members(resource: Workspace | Group,
account_list = await Account.find(In(Account.id, accounts)).to_list()
# Add the accounts to the group member list with basic permissions

new_members = []

for account in account_list:
default_permissions = eval("Permissions." + resource.resource_type.upper() + "_BASIC_PERMISSIONS")
await resource.add_member(account, default_permissions, save=False)
if resource.resource_type == "group":
member = await get_member(account, resource.workspace)
new_member = await resource.add_member(member, default_permissions, save=False)
new_members.append(new_member)
elif resource.resource_type == "workspace":
new_member = await resource.add_member(account, default_permissions, save=False)
new_members.append(new_member)
await resource.save(link_rule=WriteRules.WRITE) # type: ignore

member_list = []
for new_member in new_members:
member_list.append(MemberSchemas.Member(id=new_member.id,
account_id=new_member.account.id,
first_name=new_member.account.first_name,
last_name=new_member.account.last_name,
email=new_member.account.email))

# Return the list of members added to the group
return MemberSchemas.MemberList(members=[MemberSchemas.Member(**account.model_dump()) for account in account_list])
return MemberSchemas.MemberList(members=member_list)


# Remove a member from a workspace
Expand Down
44 changes: 34 additions & 10 deletions src/unipoll_api/actions/policy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from unipoll_api import AccountManager
from unipoll_api.documents import Account, Workspace, Group, Policy, Resource
from unipoll_api.schemas import MemberSchemas, PolicySchemas
from unipoll_api.documents import Account, Workspace, Group, Policy, Resource, Member
from unipoll_api.schemas import MemberSchemas, PolicySchemas, GroupSchemas
from unipoll_api.exceptions import ResourceExceptions
from unipoll_api.utils import Permissions
from unipoll_api.utils.permissions import check_permissions
from unipoll_api.dependencies import get_member


# Helper function to get policies from a resource
Expand All @@ -14,15 +15,17 @@ async def get_policies_from_resource(resource: Resource) -> list[Policy]:
await check_permissions(resource, "get_policies")
return resource.policies # type: ignore
except ResourceExceptions.UserNotAuthorized:
print("User not authorized")
account = AccountManager.active_user.get()
member = await get_member(account, resource)
for policy in resource.policies:
if policy.policy_holder.ref.id == account.id: # type: ignore
if policy.policy_holder.ref.id == member.id: # type: ignore
policies.append(policy) # type: ignore
return policies


# Get all policies of a workspace
async def get_policies(policy_holder: Account | Group | None = None,
async def get_policies(policy_holder: Member | Group | None = None,
resource: Resource | None = None) -> PolicySchemas.PolicyList:
policy_list = []
policy: Policy
Expand Down Expand Up @@ -58,7 +61,18 @@ async def get_policy(policy: Policy, permission_check: bool = True) -> PolicySch

# Get the policy holder
policy_holder = await policy.get_policy_holder()
member = MemberSchemas.Member(**policy_holder.model_dump())
member, group = None, None
if policy_holder.document_type == "member":
await policy_holder.fetch_link("account")
member = MemberSchemas.Member(id=policy_holder.id,
account_id=policy_holder.account.id,
email=policy_holder.account.email,
first_name=policy_holder.account.first_name,
last_name=policy_holder.account.last_name)
elif policy_holder.document_type == "group":
group = GroupSchemas.Group(id=policy_holder.id,
name=policy_holder.name,
description=policy_holder.description)

# Get the permissions based on the resource type and convert it to a list of strings
permission_type = Permissions.PermissionTypes[parent_resource.resource_type]
Expand All @@ -67,7 +81,7 @@ async def get_policy(policy: Policy, permission_check: bool = True) -> PolicySch
# Return the policy
return PolicySchemas.PolicyShort(id=policy.id,
policy_holder_type=policy.policy_holder_type,
policy_holder=member.model_dump(exclude_unset=True),
policy_holder=member or group,
permissions=permissions)


Expand All @@ -93,7 +107,17 @@ async def update_policy(policy: Policy,
await Policy.save(policy)

policy_holder = await policy.get_policy_holder()

return PolicySchemas.PolicyOutput(
permissions=permission_type(policy.permissions).name.split('|'), # type: ignore
policy_holder=policy_holder.model_dump())
if policy_holder.document_type == "member":
await policy_holder.fetch_link("account")
policy_holder_schema = MemberSchemas.Member(id=policy_holder.id,
account_id=policy_holder.account.id,
email=policy_holder.account.email,
first_name=policy_holder.account.first_name,
last_name=policy_holder.account.last_name)
elif policy_holder.document_type == "group":
policy_holder_schema = GroupSchemas.Group(id=policy_holder.id,
name=policy_holder.name,
description=policy_holder.description)

return PolicySchemas.PolicyOutput(permissions=permission_type(policy.permissions).name.split('|'), # type: ignore
policy_holder=policy_holder_schema)
16 changes: 9 additions & 7 deletions src/unipoll_api/actions/workspace.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from bson import DBRef
from unipoll_api import AccountManager
from unipoll_api import actions
from unipoll_api.documents import Workspace, Account, Policy
from unipoll_api.documents import Workspace, Account, Policy, Member
from unipoll_api.utils import Permissions
from unipoll_api.schemas import WorkspaceSchemas
from unipoll_api.exceptions import WorkspaceExceptions
from unipoll_api.dependencies import get_member


# Get a list of workspaces where the account is a owner/member
async def get_workspaces(account: Account | None = None) -> WorkspaceSchemas.WorkspaceList:
account = AccountManager.active_user.get()
workspace_list = []

search_result = await Workspace.find(Workspace.members.id == account.id).to_list() # type: ignore
members = await Member.find(Member.account.id == account.id, fetch_links=True).to_list()
workspaces = [member.workspace for member in members]

# Create a workspace list for output schema using the search results
for workspace in search_result:
for workspace in workspaces:
workspace_list.append(WorkspaceSchemas.WorkspaceShort(
**workspace.model_dump(exclude={'members', 'groups', 'permissions'})))
**workspace.model_dump(exclude={'groups', 'permissions'})))

return WorkspaceSchemas.WorkspaceList(workspaces=workspace_list)

Expand Down Expand Up @@ -71,22 +73,22 @@ async def update_workspace(workspace: Workspace,
await Permissions.check_permissions(workspace, "update_workspace", check_permissions)
save_changes = False

# Check if user suplied a name
# Check if user supplied a name
if input_data.name and input_data.name != workspace.name:
# Check if workspace name is unique
if await Workspace.find_one({"name": input_data.name}) and workspace.name != input_data.name:
raise WorkspaceExceptions.NonUniqueName(input_data.name)
workspace.name = input_data.name # Update the name
save_changes = True
# Check if user suplied a description
# Check if user supplied a description
if input_data.description and input_data.description != workspace.description:
workspace.description = input_data.description # Update the description
save_changes = True
# Save the updated workspace
if save_changes:
await Workspace.save(workspace)
# Return the updated workspace
return WorkspaceSchemas.Workspace(**workspace.model_dump())
return WorkspaceSchemas.Workspace(**workspace.model_dump(include={'id', 'name', 'description'}))


# Delete a workspace
Expand Down
14 changes: 13 additions & 1 deletion src/unipoll_api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Annotated
from functools import wraps
from bson import DBRef
from fastapi import Cookie, Depends, Query, HTTPException, WebSocket
from unipoll_api.account_manager import active_user, get_current_active_user
from unipoll_api.documents import ResourceID, Workspace, Group, Account, Poll, Policy
from unipoll_api.documents import ResourceID, Workspace, Group, Account, Poll, Policy, Member, Resource
from unipoll_api import exceptions as Exceptions


Expand All @@ -29,6 +30,17 @@ async def get_account(account_id: ResourceID) -> Account:
return account


async def get_member(account: Account, resource: Resource) -> Member:
"""
Returns a member with the given id.
"""

for member in resource.members:
if member.account.id == account.id:
return member
raise Exceptions.ResourceExceptions.ResourceNotFound("member", account.id)


async def websocket_auth(websocket: WebSocket,
session: Annotated[str | None, Cookie()] = None,
token: Annotated[str | None, Query()] = None) -> dict:
Expand Down
31 changes: 20 additions & 11 deletions src/unipoll_api/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class AccessToken(BeanieBaseAccessToken, Document): # type: ignore
class Resource(Document):
id: ResourceID = Field(default_factory=ResourceID, alias="_id")
resource_type: Literal["workspace", "group", "poll"]
document_type: Literal["resource"] = "resource"
name: str = Field(
title="Name", description="Name of the resource", min_length=3, max_length=50)
description: str = Field(default="", title="Description", max_length=1000)
Expand Down Expand Up @@ -71,6 +72,7 @@ async def remove_policy_by_holder(self, policy_holder: "Group | Member", save: b

class Account(BeanieBaseUser, Document): # type: ignore
id: ResourceID = Field(default_factory=ResourceID, alias="_id")
document_type: Literal["account"] = "account"
first_name: str = Field(
default_factory=str,
max_length=20,
Expand All @@ -85,18 +87,21 @@ class Account(BeanieBaseUser, Document): # type: ignore

class Workspace(Resource):
resource_type: Literal["workspace"] = "workspace"
document_type: Literal["workspace"] = "workspace"
members: list[Link["Member"]] = []
groups: list[Link["Group"]] = []
polls: list[Link["Poll"]] = []

async def add_member(self, account: "Account", permissions, save: bool = True) -> "Account":
async def add_member(self, account: "Account", permissions, save: bool = True) -> "Member":
new_member = await Member(account=account, resource=(await create_link(self))).create() # type: ignore
new_policy = await self.add_policy(new_member, permissions, save=False) # type: ignore
new_member.policies.append(new_policy) # type: ignore

self.members.append(new_member) # type: ignore

if save:
await self.save(link_rule=WriteRules.WRITE) # type: ignore
return account
return new_member

async def remove_member(self, member: "Member", save: bool = True) -> bool:
# Remove the account from the workspace
Expand Down Expand Up @@ -125,24 +130,24 @@ async def remove_member(self, member: "Member", save: bool = True) -> bool:

class Group(Resource):
resource_type: Literal["group"] = "group"
document_type: Literal["group"] = "group"
workspace: BackLink[Workspace] = Field(original_field="groups")
members: list[Link["Member"]] = []
groups: list[Link["Group"]] = []

async def add_member(self, account: "Account", permissions, save: bool = True) -> "Account":
if account.id not in [i.id for i in self.workspace.members]: # type: ignore
async def add_member(self, member: "Member", permissions, save: bool = True) -> "Member":
if member.workspace.id != self.workspace.id:
from unipoll_api.exceptions import WorkspaceExceptions
raise WorkspaceExceptions.UserNotMember(
self.workspace, account) # type: ignore
self.workspace, member) # type: ignore

new_member = await Member(account=account, resource=(await create_link(self))).create() # type: ignore
# Add the account to the group
self.members.append(new_member) # type: ignore
self.members.append(member) # type: ignore
# Create a policy for the new member
await self.add_policy(new_member, permissions, save=False) # type: ignore
await self.add_policy(member, permissions, save=False) # type: ignore
if save:
await self.save(link_rule=WriteRules.WRITE) # type: ignore
return account
return member

async def remove_member(self, account, save: bool = True) -> bool:
# Remove the account from the group
Expand All @@ -164,6 +169,7 @@ async def remove_member(self, account, save: bool = True) -> bool:

class Poll(Resource):
id: ResourceID = Field(default_factory=ResourceID, alias="_id")
document_type: Literal["poll"] = "poll"
workspace: Link[Workspace]
resource_type: Literal["poll"] = "poll"
public: bool
Expand All @@ -174,6 +180,7 @@ class Poll(Resource):

class Policy(Document):
id: ResourceID = Field(default_factory=ResourceID, alias="_id")
document_type: Literal["policy"] = "policy"
parent_resource: Link[Workspace] | Link[Group] | Link[Poll]
policy_holder_type: Literal["member", "group"]
policy_holder: Link["Group"] | Link["Member"]
Expand Down Expand Up @@ -201,6 +208,8 @@ async def get_policy_holder(self, fetch_links: bool = False) -> "Group | Member"

class Member(Document):
id: ResourceID = Field(default_factory=ResourceID, alias="_id")
document_type: Literal["member"] = "member"
account: Link[Account]
resource: Link[Workspace] | Link[Group] | Link[Poll]
# policy: Link[Policy]
workspace: BackLink[Workspace] = Field(original_field="members")
groups: list[BackLink[Group]] = Field(original_field="members")
policies: list[Link[Policy]] = []
11 changes: 7 additions & 4 deletions src/unipoll_api/routes/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ async def get_workspace(workspace: Workspace = Depends(Dependencies.get_workspac


# Update a workspace with the given id
@router.patch("/{workspace_id}", response_description="Updated workspace", response_model=WorkspaceSchemas.Workspace)
@router.patch("/{workspace_id}",
response_description="Updated workspace",
response_model=WorkspaceSchemas.Workspace,
response_model_exclude_none=True)
async def update_workspace(workspace: Workspace = Depends(Dependencies.get_workspace),
input_data: WorkspaceSchemas.WorkspaceUpdateRequest = Body(...)
):
input_data: WorkspaceSchemas.WorkspaceUpdateRequest = Body(...)):
"""
Updates the workspace with the given id.
Query parameters:
Expand Down Expand Up @@ -222,7 +224,8 @@ async def get_workspace_policies(workspace: Workspace = Depends(Dependencies.get
account_id: ResourceID = Query(None)):
try:
account = await Dependencies.get_account(account_id) if account_id else None
return await actions.PolicyActions.get_policies(resource=workspace, policy_holder=account)
member = await Dependencies.get_member(account, workspace) if account else None
return await actions.PolicyActions.get_policies(resource=workspace, policy_holder=member)
except APIException as e:
raise HTTPException(status_code=e.code, detail=str(e))

Expand Down
Loading

0 comments on commit 73e4017

Please sign in to comment.