Skip to content

Commit

Permalink
Make sure generated API works
Browse files Browse the repository at this point in the history
  • Loading branch information
goFrendiAsgard committed Jan 27, 2025
1 parent 2f721f0 commit f78d8e4
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ async def create_my_entity(self, data: MyEntityCreateWithAudit) -> MyEntityRespo
@BaseService.route(
"/api/v1/my-entities/bulk",
methods=["put"],
response_model=MyEntityResponse,
response_model=list[MyEntityResponse],
)
async def update_my_entity_bulk(
self, my_entity_ids: list[str], data: MyEntityUpdateWithAudit
) -> MyEntityResponse:
) -> list[MyEntityResponse]:
await self.my_entity_repository.update_bulk(my_entity_ids, data)
return await self.my_entity_repository.get_by_ids(my_entity_ids)

Expand All @@ -89,11 +89,11 @@ async def update_my_entity(
@BaseService.route(
"/api/v1/my-entities/bulk",
methods=["delete"],
response_model=MyEntityResponse,
response_model=list[MyEntityResponse],
)
async def delete_my_entity_bulk(
self, my_entity_ids: list[str], deleted_by: str
) -> MyEntityResponse:
) -> list[MyEntityResponse]:
my_entities = await self.my_entity_repository.get_by_ids(my_entity_ids)
await self.my_entity_repository.delete_bulk(my_entity_ids)
return my_entities
Expand All @@ -106,6 +106,6 @@ async def delete_my_entity_bulk(
async def delete_my_entity(
self, my_entity_id: str, deleted_by: str
) -> MyEntityResponse:
my_entity = await self.my_entity_repository.get_by_id(my_entity.id)
my_entity = await self.my_entity_repository.get_by_id(my_entity_id)
await self.my_entity_repository.delete(my_entity_id)
return my_entity
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ async def delete_bulk(self, id_list: list[str]) -> list[DBModel]:
async def update(self, id: str, data: UpdateModel) -> DBModel:
now = datetime.datetime.now(datetime.timezone.utc)
update_data = self._model_to_data_dict(data, updated_at=now)
update_data = {k: v for k, v in update_data.items() if v is not None}
async with self._session_scope() as session:
statement = (
update(self.db_model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,56 +160,62 @@ async def client_method(self, *args, **kwargs):
return client_method


def _create_api_client_method(logger: Logger, param: RouteParam, base_url: str):
def _create_api_client_method(logger: Logger, route_param: RouteParam, base_url: str):
async def client_method(*args, **kwargs):
url = base_url + param.path
url = base_url + route_param.path
method = (
param.methods[0].lower()
if isinstance(param.methods, list)
else param.methods.lower()
route_param.methods[0].lower()
if isinstance(route_param.methods, list)
else route_param.methods.lower()
)
# Get the signature of the original function
sig = inspect.signature(param.func)
sig = inspect.signature(route_param.func)
# Bind the arguments to the signature
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
# Analyze parameters
params = list(sig.parameters.values())
body_params = [
p
for p in params
if p.name != "self" and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
function_params = list(sig.parameters.values())
body_param_names = [
p.name
for p in function_params
if (
p.name != "self"
and f"{{{p.name}}}" not in route_param.path
and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
and (
method not in ["get", "delete"]
or (method == "delete" and p.annotation not in [str, float, bool])
)
)
]
# Prepare the request
path_params = {}
query_params = {}
body = {}
body_params = {}
for name, value in bound_args.arguments.items():
if name == "self":
continue
if f"{{{name}}}" in param.path:
if f"{{{name}}}" in route_param.path:
path_params[name] = value
elif isinstance(value, BaseModel):
body = _parse_api_param(value)
elif method in ["get", "delete"]:
elif name not in body_param_names:
query_params[name] = _parse_api_param(value)
elif len(body_params) == 1 and name == body_params[0].name:
elif len(body_param_names) == 1 and name == body_param_names[0]:
# If there's only one body parameter, use its value directly
body = _parse_api_param(value)
body_params = _parse_api_param(value)
else:
body[name] = _parse_api_param(value)
body_params[name] = _parse_api_param(value)
# Format the URL with path parameters
url = url.format(**path_params)
logger.info(
f"Sending request to {url} with method {method}, json={body}, params={query_params}" # noqa
f"Sending request to {url} with method {method}, json={body_params}, params={query_params}" # noqa
)
async with httpx.AsyncClient() as client:
if method in ["get", "delete"]:
response = await getattr(client, method)(url, params=query_params)
else:
response = await getattr(client, method)(
url, json=body, params=query_params
)
response = await client.request(
method=method,
url=url,
params=query_params,
json=None if method == "get" else body_params,
)
logger.info(
f"Received response: status={response.status_code}, content={response.content}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ async def create_permission(
@BaseService.route(
"/api/v1/permissions/bulk",
methods=["put"],
response_model=PermissionResponse,
response_model=list[PermissionResponse],
)
async def update_permission_bulk(
self, permission_ids: list[str], data: PermissionUpdateWithAudit
) -> PermissionResponse:
) -> list[PermissionResponse]:
await self.permission_repository.update_bulk(permission_ids, data)
return await self.permission_repository.get_by_ids(permission_ids)

Expand All @@ -93,11 +93,11 @@ async def update_permission(
@BaseService.route(
"/api/v1/permissions/bulk",
methods=["delete"],
response_model=PermissionResponse,
response_model=list[PermissionResponse],
)
async def delete_permission_bulk(
self, permission_ids: list[str], deleted_by: str
) -> PermissionResponse:
) -> list[PermissionResponse]:
permissions = await self.permission_repository.get_by_ids(permission_ids)
await self.permission_repository.delete_bulk(permission_ids)
return permissions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ async def create_role(
@BaseService.route(
"/api/v1/roles/bulk",
methods=["put"],
response_model=RoleResponse,
response_model=list[RoleResponse],
)
async def update_role_bulk(
self, role_ids: list[str], data: RoleUpdateWithPermissionsAndAudit
) -> RoleResponse:
) -> list[RoleResponse]:
permission_ids = [row.get_permission_ids() for row in data]
data = [row.get_role_update_with_audit() for row in data]
await self.role_repository.update_bulk(role_ids, data)
Expand Down Expand Up @@ -117,11 +117,11 @@ async def update_role(
@BaseService.route(
"/api/v1/roles/bulk",
methods=["delete"],
response_model=RoleResponse,
response_model=list[RoleResponse],
)
async def delete_role_bulk(
self, role_ids: list[str], deleted_by: str
) -> RoleResponse:
) -> list[RoleResponse]:
roles = await self.role_repository.get_by_ids(role_ids)
await self.role_repository.delete_bulk(role_ids)
await self.role_repository.remove_all_permissions(role_ids)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ async def create_user(self, data: UserCreateWithRolesAndAudit) -> UserResponse:
@BaseService.route(
"/api/v1/users/bulk",
methods=["put"],
response_model=UserResponse,
response_model=list[UserResponse],
)
async def update_user_bulk(
self, user_ids: list[str], data: UserUpdateWithRolesAndAudit
) -> UserResponse:
) -> list[UserResponse]:
role_ids = [row.get_role_ids() for row in data]
user_data = [row.get_user_create_with_audit() for row in data]
await self.user_repository.update_bulk(user_ids, user_data)
Expand Down Expand Up @@ -115,11 +115,11 @@ async def update_user(
@BaseService.route(
"/api/v1/users/bulk",
methods=["delete"],
response_model=UserResponse,
response_model=list[UserResponse],
)
async def delete_user_bulk(
self, user_ids: list[str], deleted_by: str
) -> UserResponse:
) -> list[UserResponse]:
roles = await self.user_repository.get_by_ids(user_ids)
await self.user_repository.delete_bulk(user_ids)
await self.user_repository.remove_all_roles(user_ids)
Expand Down
2 changes: 1 addition & 1 deletion src/zrb/task/cmd_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def _exec_action(self, ctx: AnyContext) -> CmdResult:
max_error_line=self._max_error_line,
)
# Check for errors
if return_code != 0:
if return_code > 0:
raise Exception(f"Process {self._name} exited ({return_code})")
ctx.log_info(f"Exit status: {return_code}")
return cmd_result
Expand Down
14 changes: 13 additions & 1 deletion src/zrb/util/cmd/command.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import os
import re
import signal
import sys
from collections.abc import Callable

import psutil

from zrb.cmd.cmd_result import CmdResult


Expand Down Expand Up @@ -92,3 +93,14 @@ async def __read_stream(
stdout = await stdout_task
stderr = await stderr_task
return CmdResult(stdout, stderr), return_code


def kill_pid(pid: int, print_method: Callable[..., None] | None = None):
actual_print_method = print_method if print_method is not None else print
parent = psutil.Process(pid)
children = parent.children(recursive=True)
for child in children:
actual_print_method(f"Killing child process {child.pid}")
child.terminate()
actual_print_method(f"Killing process {pid}")
parent.terminate()
Loading

0 comments on commit f78d8e4

Please sign in to comment.