diff --git a/.github/workflows/_zrb.yml b/.github/workflows/_zrb.yml index 24e9861a..2fb101ea 100644 --- a/.github/workflows/_zrb.yml +++ b/.github/workflows/_zrb.yml @@ -17,7 +17,7 @@ jobs: Run-command: runs-on: ubuntu-latest container: - image: stalchmst/zrb:1.0.0b6 + image: stalchmst/zrb:1.0.0b9 steps: - name: Check out repository code uses: actions/checkout@v3 diff --git a/Dockerfile b/Dockerfile index 14afeb7d..69c96e16 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.10-bookworm +FROM python:3.10-slim-bookworm # Create and set workdir RUN mkdir -p /project diff --git a/docs/notes.md b/docs/notes.md new file mode 100644 index 00000000..6cd25da9 --- /dev/null +++ b/docs/notes.md @@ -0,0 +1,85 @@ +# Create permission bulk + +```json +[ + {"name": "book:create", "description": "create book"}, + {"name": "book:update", "description": "update book"}, + {"name": "book:delete", "description": "delete book"}, + {"name": "book:view", "description": "view book"} +] +``` + +# Create roles bulk + +```json +[ + { + "name": "librarian", + "description": "Full access to manage books", + "permission_names": [ + "book:create", + "book:update", + "book:delete", + "book:view" + ] + }, + { + "name": "assistant-librarian", + "description": "Can create, update, and view books, but not delete them", + "permission_names": [ + "book:create", + "book:update", + "book:view" + ] + }, + { + "name": "viewer", + "description": "Can only view books", + "permission_names": [ + "book:view" + ] + } +] +``` + +# Create user bulk + +```json +[ + { + "username": "john_doe", + "password": "password123", + "role_names": [ + "librarian" + ], + "active": true + }, + { + "username": "jane_smith", + "password": "securePass!2025", + "role_names": [ + "assistant-librarian" + ], + "active": true + + }, + { + "username": "alex_viewer", + "password": "viewOnly@2025", + "role_names": [ + "viewer" + ], + "active": true + + }, + { + "username": "emily_helper", + "password": "strongPass$567", + "role_names": [ + "assistant-librarian", + "viewer" + ], + "active": true + } +] +``` diff --git a/poetry.lock b/poetry.lock index c953fab7..007b91f8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2590,6 +2590,36 @@ files = [ {file = "protobuf-5.29.3.tar.gz", hash = "sha256:5da0f41edaf117bde316404bad1a486cb4ededf8e4a54891296f648e8e076620"}, ] +[[package]] +name = "psutil" +version = "6.1.1" +description = "Cross-platform lib for process and system monitoring in Python." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "psutil-6.1.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8"}, + {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777"}, + {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:8df0178ba8a9e5bc84fed9cfa61d54601b371fbec5c8eebad27575f1e105c0d4"}, + {file = "psutil-6.1.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:1924e659d6c19c647e763e78670a05dbb7feaf44a0e9c94bf9e14dfc6ba50468"}, + {file = "psutil-6.1.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:018aeae2af92d943fdf1da6b58665124897cfc94faa2ca92098838f83e1b1bca"}, + {file = "psutil-6.1.1-cp27-none-win32.whl", hash = "sha256:6d4281f5bbca041e2292be3380ec56a9413b790579b8e593b1784499d0005dac"}, + {file = "psutil-6.1.1-cp27-none-win_amd64.whl", hash = "sha256:c777eb75bb33c47377c9af68f30e9f11bc78e0f07fbf907be4a5d70b2fe5f030"}, + {file = "psutil-6.1.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8"}, + {file = "psutil-6.1.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377"}, + {file = "psutil-6.1.1-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003"}, + {file = "psutil-6.1.1-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160"}, + {file = "psutil-6.1.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3"}, + {file = "psutil-6.1.1-cp36-cp36m-win32.whl", hash = "sha256:384636b1a64b47814437d1173be1427a7c83681b17a450bfc309a1953e329603"}, + {file = "psutil-6.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:8be07491f6ebe1a693f17d4f11e69d0dc1811fa082736500f649f79df7735303"}, + {file = "psutil-6.1.1-cp37-abi3-win32.whl", hash = "sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53"}, + {file = "psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649"}, + {file = "psutil-6.1.1.tar.gz", hash = "sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5"}, +] + +[package.extras] +dev = ["abi3audit", "black", "check-manifest", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pytest-cov", "requests", "rstcheck", "ruff", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "virtualenv", "vulture", "wheel"] +test = ["pytest", "pytest-xdist", "setuptools"] + [[package]] name = "py-rust-stemmers" version = "0.1.3" @@ -3998,4 +4028,4 @@ rag = ["chromadb", "pdfplumber"] [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<4.0.0" -content-hash = "7d510a3ccd6606fd0e13785e2e74aad29fb7f886229540e7f5409de35d34dc05" +content-hash = "af381af9af438ad6024cf06069317cdcd19f760745dc8ef5638a5c0666feefb5" diff --git a/pyproject.toml b/pyproject.toml index fa49d875..0415b99e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zrb" -version = "1.0.0b6" +version = "1.0.0b9" description = "Your Automation Powerhouse" authors = ["Go Frendi Gunawan "] license = "AGPL-3.0-or-later" @@ -52,12 +52,13 @@ python-jose = {extras = ["cryptography"], version = "^3.3.0"} ulid-py = "^1.1.0" pydantic-ai = "^0.0.19" fastembed = "^0.5.1" +psutil = "^6.1.1" [tool.poetry.extras] # poetry install -E rag rag = ["chromadb", "pdfplumber"] -[tool.poetry.dev-dependencies] +[tool.poetry.group.dev.dependencies] alembic = "^1.14.0" # FastApp dependencies flake8 = "~7.1.1" pytest = "~8.3.3" diff --git a/src/zrb/__main__.py b/src/zrb/__main__.py index 0bf2a120..3926851d 100644 --- a/src/zrb/__main__.py +++ b/src/zrb/__main__.py @@ -16,8 +16,11 @@ def serve_cli(): cli.run(sys.argv[1:]) except KeyboardInterrupt: print(stylize_warning("\nStopped"), file=sys.stderr) + sys.exit(1) except RuntimeError as e: if f"{e}".lower() != "event loop is closed": raise e + sys.exit(1) except NodeNotFoundError as e: print(stylize_error(f"{e}"), file=sys.stderr) + sys.exit(1) diff --git a/src/zrb/builtin/git.py b/src/zrb/builtin/git.py index 7a766e91..ac533ef4 100644 --- a/src/zrb/builtin/git.py +++ b/src/zrb/builtin/git.py @@ -57,9 +57,9 @@ ) async def get_git_diff(ctx: AnyContext): ctx.print(stylize_faint("Get directory")) - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) diff = await get_diff( - repo_dir, ctx.input.source, ctx.input.current, log_method=ctx.print + repo_dir, ctx.input.source, ctx.input.current, print_method=ctx.print ) result = [] decorated = [] @@ -88,17 +88,17 @@ async def get_git_diff(ctx: AnyContext): ) async def prune_local_branches(ctx: AnyContext): ctx.print(stylize_faint("Get directory")) - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) ctx.print(stylize_faint("Get existing branches")) - branches = await get_branches(repo_dir, log_method=ctx.print) + branches = await get_branches(repo_dir, print_method=ctx.print) ctx.print(stylize_faint("Get current branch")) - current_branch = await get_current_branch(repo_dir, log_method=ctx.print) + current_branch = await get_current_branch(repo_dir, print_method=ctx.print) for branch in branches: if branch == current_branch or branch == "main" or branch == "master": continue ctx.print(stylize_faint(f"Removing local branch: {branch}")) try: - await delete_branch(repo_dir, branch, log_method=ctx.print) + await delete_branch(repo_dir, branch, print_method=ctx.print) except Exception as e: ctx.log_error(e) @@ -117,11 +117,11 @@ async def prune_local_branches(ctx: AnyContext): ) async def git_commit(ctx: AnyContext): ctx.print(stylize_faint("Get directory")) - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) ctx.print(stylize_faint("Add changes to staging")) - await add(repo_dir, log_method=ctx.print) + await add(repo_dir, print_method=ctx.print) ctx.print(stylize_faint("Commit changes")) - await commit(repo_dir, ctx.input.message, log_method=ctx.print) + await commit(repo_dir, ctx.input.message, print_method=ctx.print) @make_task( @@ -139,12 +139,12 @@ async def git_commit(ctx: AnyContext): ) async def git_pull(ctx: AnyContext): ctx.print(stylize_faint("Get directory")) - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) ctx.print(stylize_faint("Get current branch")) - current_branch = await get_current_branch(repo_dir, log_method=ctx.print) + current_branch = await get_current_branch(repo_dir, print_method=ctx.print) remote = ctx.input.remote ctx.print(stylize_faint(f"Pulling from {remote}/{current_branch}")) - await pull(repo_dir, remote, current_branch, log_method=ctx.print) + await pull(repo_dir, remote, current_branch, print_method=ctx.print) @make_task( @@ -161,9 +161,9 @@ async def git_pull(ctx: AnyContext): alias="push", ) async def git_push(ctx: AnyContext): - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) ctx.print(stylize_faint("Get current branch")) - current_branch = await get_current_branch(repo_dir, log_method=ctx.print) + current_branch = await get_current_branch(repo_dir, print_method=ctx.print) remote = ctx.input.remote ctx.print(stylize_faint(f"Pushing to {remote}/{current_branch}")) - await push(repo_dir, remote, current_branch, log_method=ctx.print) + await push(repo_dir, remote, current_branch, print_method=ctx.print) diff --git a/src/zrb/builtin/git_subtree.py b/src/zrb/builtin/git_subtree.py index 679d49f4..253fa960 100644 --- a/src/zrb/builtin/git_subtree.py +++ b/src/zrb/builtin/git_subtree.py @@ -35,7 +35,7 @@ ) async def git_add_subtree(ctx: AnyContext): ctx.print(stylize_faint("Get directory")) - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) ctx.print(stylize_faint("Add subtree")) await add_subtree( repo_dir=repo_dir, @@ -43,7 +43,7 @@ async def git_add_subtree(ctx: AnyContext): repo_url=ctx.input["repo-url"], branch=ctx.input["repo-branch"], prefix=ctx.input["repo-prefix"], - log_method=ctx.print, + print_method=ctx.print, ) @@ -56,7 +56,7 @@ async def git_add_subtree(ctx: AnyContext): ) async def git_pull_subtree(ctx: AnyContext): ctx.print(stylize_faint("Get directory")) - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) config = load_config(repo_dir) if not config.data: raise ValueError("No subtree config found") @@ -69,7 +69,7 @@ async def git_pull_subtree(ctx: AnyContext): prefix=detail.prefix, repo_url=detail.repo_url, branch=detail.branch, - log_method=ctx.print, + print_method=ctx.print, ) except Exception as e: if first_err is None: @@ -88,7 +88,7 @@ async def git_pull_subtree(ctx: AnyContext): ) async def git_push_subtree(ctx: AnyContext): ctx.print(stylize_faint("Get directory")) - repo_dir = await get_repo_dir(log_method=ctx.print) + repo_dir = await get_repo_dir(print_method=ctx.print) config = load_config(repo_dir) if not config.data: raise ValueError("No subtree config found") @@ -101,7 +101,7 @@ async def git_push_subtree(ctx: AnyContext): prefix=detail.prefix, repo_url=detail.repo_url, branch=detail.branch, - log_method=ctx.print, + print_method=ctx.print, ) except Exception as e: if first_err is None: diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_task.py b/src/zrb/builtin/project/add/fastapp/fastapp_task.py index 81070b55..ab40ce22 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_task.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_task.py @@ -20,11 +20,11 @@ @make_task( - name="validate-create-fastapp", + name="validate-add-fastapp", input=[project_dir_input, app_name_input], retries=0, ) -async def validate_create_fastapp(ctx: AnyContext): +async def validate_add_fastapp(ctx: AnyContext): project_dir = ctx.input.project_dir if not os.path.isdir(project_dir): raise ValueError(f"Project directory not exists: {project_dir}") @@ -39,7 +39,7 @@ async def validate_create_fastapp(ctx: AnyContext): project_dir_input, app_name_input, ], - upstream=validate_create_fastapp, + upstream=validate_add_fastapp, source_path=os.path.join(os.path.dirname(__file__), "fastapp_template"), render_source_path=False, destination_path="{ctx.input.project_dir}", @@ -57,6 +57,7 @@ async def validate_create_fastapp(ctx: AnyContext): "my_app_name": "{to_snake_case(ctx.input.app)}", "MY_APP_NAME": "{to_snake_case(ctx.input.app).upper()}", "my-secure-password": lambda _: get_random_name(), + "my-secret-key": lambda _: get_random_name(), }, ), # Register fastapp's tasks to project's zrb_init (project_dir/zrb_init.py) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/column/add_column_task.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/column/add_column_task.py index 22d40ee0..05616113 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/column/add_column_task.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/column/add_column_task.py @@ -1,14 +1,93 @@ +import os + +from my_app_name._zrb.config import APP_DIR from my_app_name._zrb.format_task import format_my_app_name_code from my_app_name._zrb.group import app_create_group +from my_app_name._zrb.input import ( + existing_entity_input, + new_column_input, + new_column_type_input, +) +from my_app_name._zrb.util import get_existing_schema_names + +from zrb import AnyContext, Task, make_task +from zrb.util.codemod.prepend_property_to_class import prepend_property_to_class +from zrb.util.file import read_file, write_file +from zrb.util.string.conversion import to_pascal_case, to_snake_case + + +@make_task( + name="validate-add-my-app-name-column", + input=existing_entity_input, + retries=0, +) +async def validate_add_my_app_name_column(ctx: AnyContext): + schema_name = ctx.input.entity + if schema_name not in get_existing_schema_names(): + raise ValueError(f"Schema not exist: {schema_name}") + + +@make_task( + name="update-my-app-name-schema", + input=[ + existing_entity_input, + new_column_input, + new_column_type_input, + ], + retries=0, + upstream=validate_add_my_app_name_column, +) +def update_my_app_name_schema(ctx: AnyContext): + snake_entity_name = to_snake_case(ctx.input.entity) + pascal_entity_name = to_pascal_case(ctx.input.entity) + schema_file_path = os.path.join(APP_DIR, "schema", f"{snake_entity_name}.py") + existing_code = read_file(schema_file_path) + snake_column_name = to_snake_case(ctx.input.column) + column_type = ctx.input.type + # Base + new_code = prepend_property_to_class( + original_code=existing_code, + class_name=f"{pascal_entity_name}Base", + property_name=snake_column_name, + annotation=column_type, + default_value=_get_default_value(column_type), + ) + # Update + new_code = prepend_property_to_class( + original_code=new_code, + class_name=f"{pascal_entity_name}Update", + property_name=snake_column_name, + annotation=f"{column_type} | None", + default_value="None", + ) + # Table + new_code = prepend_property_to_class( + original_code=new_code, + class_name=f"{pascal_entity_name}", + property_name=snake_column_name, + annotation=f"{column_type} | None", + default_value="Field(index=False)", + ) + write_file(schema_file_path, new_code) -from zrb import Task add_my_app_name_column = app_create_group.add_task( Task( name="add-my-app-name-column", description="๐Ÿ“Š Create new column on an entity", + upstream=update_my_app_name_schema, successor=format_my_app_name_code, retries=0, ), alias="column", ) + + +def _get_default_value(data_type: str) -> str: + if data_type == "str": + return '""' + if data_type in ("int", "float"): + return "0" + if data_type == "bool": + return "True" + return "None" diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_task.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_task.py index 18fe3e84..acb1b772 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_task.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_task.py @@ -48,7 +48,7 @@ @make_task( - name="validate-create-my-app-name-entity", + name="validate-add-my-app-name-entity", input=[ existing_module_input, new_entity_input, @@ -57,7 +57,7 @@ ], retries=0, ) -async def validate_create_my_app_name_entity(ctx: AnyContext): +async def validate_add_my_app_name_entity(ctx: AnyContext): module_name = to_snake_case(ctx.input.module) if module_name not in get_existing_module_names(): raise ValueError(f"Module not exist: {module_name}") @@ -88,6 +88,7 @@ async def validate_create_my_app_name_entity(ctx: AnyContext): match=is_in_app_schema_dir, transform={ "MyEntity": "{to_pascal_case(ctx.input.entity)}", + "my_entities": "{to_snake_case(ctx.input.plural)}", "my_column": "{to_snake_case(ctx.input.column)}", }, ), @@ -147,7 +148,7 @@ async def validate_create_my_app_name_entity(ctx: AnyContext): ), ], retries=0, - upstream=validate_create_my_app_name_entity, + upstream=validate_add_my_app_name_entity, ) create_my_app_name_entity_migration = CmdTask( diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service.py index 26c35b62..9f05379f 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service.py @@ -67,11 +67,11 @@ async def create_my_entity(self, data: MyEntityCreateWithAudit) -> MyEntityRespo @BaseService.route( "/api/v1/my-entities/bulk", methods=["put"], - response_model=MyEntityResponse, + response_model=list[MyEntityResponse], ) async def update_my_entity_bulk( self, my_entity_ids: list[str], data: MyEntityUpdateWithAudit - ) -> MyEntityResponse: + ) -> list[MyEntityResponse]: await self.my_entity_repository.update_bulk(my_entity_ids, data) return await self.my_entity_repository.get_by_ids(my_entity_ids) @@ -89,11 +89,11 @@ async def update_my_entity( @BaseService.route( "/api/v1/my-entities/bulk", methods=["delete"], - response_model=MyEntityResponse, + response_model=list[MyEntityResponse], ) async def delete_my_entity_bulk( self, my_entity_ids: list[str], deleted_by: str - ) -> MyEntityResponse: + ) -> list[MyEntityResponse]: my_entities = await self.my_entity_repository.get_by_ids(my_entity_ids) await self.my_entity_repository.delete_bulk(my_entity_ids) return my_entities @@ -106,6 +106,6 @@ async def delete_my_entity_bulk( async def delete_my_entity( self, my_entity_id: str, deleted_by: str ) -> MyEntityResponse: - my_entity = await self.my_entity_repository.get_by_id(my_entity.id) + my_entity = await self.my_entity_repository.get_by_id(my_entity_id) await self.my_entity_repository.delete(my_entity_id) return my_entity diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/schema/my_entity.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/schema/my_entity.py index a3fc6e5f..e4e2c2ba 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/schema/my_entity.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/schema/my_entity.py @@ -39,6 +39,7 @@ class MultipleMyEntityResponse(BaseModel): class MyEntity(SQLModel, table=True): + __tablename__ = "my_entities" id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) created_at: datetime.datetime = Field(index=True) created_by: str = Field(index=True) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py index 25de90fd..2fde78e1 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py @@ -3,6 +3,7 @@ @app.get("/api/v1/my-entities", response_model=MultipleMyEntityResponse) async def get_my_entities( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], page: int = 1, page_size: int = 10, sort: str | None = None, @@ -14,7 +15,10 @@ async def get_my_entities( @app.get("/api/v1/my-entities/{my_entity_id}", response_model=MyEntityResponse) -async def get_my_entity_by_id(my_entity_id: str) -> MyEntityResponse: +async def get_my_entity_by_id( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + my_entity_id: str, +) -> MyEntityResponse: return await my_module_client.get_my_entity_by_id(my_entity_id) @@ -22,9 +26,12 @@ async def get_my_entity_by_id(my_entity_id: str) -> MyEntityResponse: "/api/v1/my-entities/bulk", response_model=list[MyEntityResponse], ) -async def create_my_entity_bulk(data: list[MyEntityCreate]): +async def create_my_entity_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: list[MyEntityCreate], +) -> list[MyEntityResponse]: return await my_module_client.create_my_entity_bulk( - [row.with_audit(created_by="system") for row in data] + [row.with_audit(created_by=current_user.id) for row in data] ) @@ -32,17 +39,26 @@ async def create_my_entity_bulk(data: list[MyEntityCreate]): "/api/v1/my-entities", response_model=MyEntityResponse, ) -async def create_my_entity(data: MyEntityCreate): - return await my_module_client.create_my_entity(data.with_audit(created_by="system")) +async def create_my_entity( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: MyEntityCreate, +) -> MyEntityResponse: + return await my_module_client.create_my_entity( + data.with_audit(created_by=current_user.id) + ) @app.put( "/api/v1/my-entities/bulk", response_model=list[MyEntityResponse], ) -async def update_my_entity_bulk(my_entity_ids: list[str], data: MyEntityUpdate): +async def update_my_entity_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + my_entity_ids: list[str], + data: MyEntityUpdate, +) -> list[MyEntityResponse]: return await my_module_client.update_my_entity_bulk( - my_entity_ids, data.with_audit(updated_by="system") + my_entity_ids, data.with_audit(updated_by=current_user.id) ) @@ -50,9 +66,13 @@ async def update_my_entity_bulk(my_entity_ids: list[str], data: MyEntityUpdate): "/api/v1/my-entities/{my_entity_id}", response_model=MyEntityResponse, ) -async def update_my_entity(my_entity_id: str, data: MyEntityUpdate): +async def update_my_entity( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + my_entity_id: str, + data: MyEntityUpdate, +) -> MyEntityResponse: return await my_module_client.update_my_entity( - my_entity_id, data.with_audit(updated_by="system") + my_entity_id, data.with_audit(updated_by=current_user.id) ) @@ -60,9 +80,12 @@ async def update_my_entity(my_entity_id: str, data: MyEntityUpdate): "/api/v1/my-entities/bulk", response_model=list[MyEntityResponse], ) -async def delete_my_entity_bulk(my_entity_ids: list[str]): +async def delete_my_entity_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + my_entity_ids: list[str], +) -> list[MyEntityResponse]: return await my_module_client.delete_my_entity_bulk( - my_entity_ids, deleted_by="system" + my_entity_ids, deleted_by=current_user.id ) @@ -70,5 +93,10 @@ async def delete_my_entity_bulk(my_entity_ids: list[str]): "/api/v1/my-entities/{my_entity_id}", response_model=MyEntityResponse, ) -async def delete_my_entity(my_entity_id: str): - return await my_module_client.delete_my_entity(my_entity_id, deleted_by="system") +async def delete_my_entity( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + my_entity_id: str, +) -> MyEntityResponse: + return await my_module_client.delete_my_entity( + my_entity_id, deleted_by=current_user.id + ) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/input.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/input.py index 774f2903..3fa382a4 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/input.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/input.py @@ -3,6 +3,14 @@ from zrb import OptionInput, StrInput from zrb.util.string.conversion import pluralize +run_env_input = OptionInput( + name="env", + description="Running environment", + prompt="Running Environment", + options=["dev", "prod"], + default_str="prod", +) + new_module_input = StrInput( name="module", description="Module name", prompt="New module name" ) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_task.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_task.py index 404565d8..a15ace03 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_task.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_task.py @@ -24,11 +24,11 @@ @make_task( - name="validate-create-my-app-name-module", + name="validate-add-my-app-name-module", input=new_module_input, retries=0, ) -async def validate_create_my_app_name_module(ctx: AnyContext): +async def validate_add_my_app_name_module(ctx: AnyContext): if ctx.input.module in get_existing_module_names(): raise ValueError(f"Module already exists: {ctx.input.module}") @@ -93,7 +93,7 @@ async def validate_create_my_app_name_module(ctx: AnyContext): ), ], retries=0, - upstream=validate_create_my_app_name_module, + upstream=validate_add_my_app_name_module, ) add_my_app_name_module = app_create_group.add_task( diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/gateway/subroute/my_module.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/gateway/subroute/my_module.py index 87ef09db..f00280a3 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/gateway/subroute/my_module.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/gateway/subroute/my_module.py @@ -1,4 +1,8 @@ -from fastapi import FastAPI +from typing import Annotated + +from fastapi import Depends, FastAPI +from my_app_name.module.gateway.util.auth import get_current_user +from my_app_name.schema.user import AuthUserResponse def serve_my_module_route(app: FastAPI): diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task.py index f6013649..ecf23e3c 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task.py @@ -9,8 +9,14 @@ app_migrate_group, app_run_group, ) +from my_app_name._zrb.input import run_env_input from my_app_name._zrb.module.add_module_task import add_my_app_name_module -from my_app_name._zrb.util import create_migration, migrate_module, run_microservice +from my_app_name._zrb.task_util import ( + create_migration, + migrate_module, + run_microservice, + run_my_app_name, +) from my_app_name._zrb.venv_task import prepare_venv from zrb import CmdTask, EnvFile, EnvMap, Task @@ -52,6 +58,7 @@ CmdTask( name="run-monolith-my-app-name", description="๐Ÿ—ฟ Run My App Name as a monolith", + input=run_env_input, env=[ EnvFile(path=os.path.join(APP_DIR, "template.env")), EnvMap(vars=MONOLITH_ENV_VARS), @@ -59,7 +66,7 @@ cwd=APP_DIR, cmd=[ ACTIVATE_VENV_SCRIPT, - 'fastapi dev main.py --port "${MY_APP_NAME_PORT}"', + run_my_app_name, ], render_cmd=False, retries=2, diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task_util.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task_util.py new file mode 100644 index 00000000..89c320e6 --- /dev/null +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task_util.py @@ -0,0 +1,100 @@ +import os + +from my_app_name._zrb.config import ( + ACTIVATE_VENV_SCRIPT, + APP_DIR, + MICROSERVICES_ENV_VARS, + MONOLITH_ENV_VARS, +) +from my_app_name._zrb.input import run_env_input +from my_app_name._zrb.util import ( + cd_module_script, + run_my_app_name, + set_create_migration_db_url_env, + set_env, + set_module_env, +) + +from zrb import Cmd, CmdTask, EnvFile, EnvMap, StrInput, Task +from zrb.util.string.conversion import to_snake_case + + +def create_migration(name: str, module: str) -> Task: + return CmdTask( + name=f"create-my-app-name-{name}-migration", + description=f"๐Ÿงฉ Create My App Name {name.capitalize()} DB migration", + input=StrInput( + name="message", + description="Migration message", + prompt="Migration message", + allow_empty=False, + ), + env=EnvFile(path=os.path.join(APP_DIR, "template.env")), + cwd=APP_DIR, + cmd=[ + ACTIVATE_VENV_SCRIPT, + set_create_migration_db_url_env(module), + set_module_env(module), + cd_module_script(module), + "alembic upgrade head", + Cmd( + "alembic revision --autogenerate -m {double_quote(ctx.input.message)}", + auto_render=True, + ), + ], + render_cmd=False, + retries=2, + ) + + +def migrate_module(name: str, module: str, as_microservices: bool) -> Task: + env_vars = ( + dict(MICROSERVICES_ENV_VARS) if as_microservices else dict(MONOLITH_ENV_VARS) + ) + if as_microservices: + env_vars["MY_APP_NAME_MODULES"] = to_snake_case(module) + return CmdTask( + name=( + f"migrate-my-app-name-{name}" + if as_microservices + else f"migrate-{name}-on-monolith" + ), + description=f"๐Ÿงฉ Run My App Name {name.capitalize()} DB migration", + env=[ + EnvFile(path=os.path.join(APP_DIR, "template.env")), + EnvMap(vars=env_vars), + ], + cwd=APP_DIR, + cmd=[ + ACTIVATE_VENV_SCRIPT, + cd_module_script(module), + "alembic upgrade head", + ], + render_cmd=False, + retries=2, + ) + + +def run_microservice(name: str, port: int, module: str) -> Task: + return CmdTask( + name=f"run-my-app-name-{name}", + description=f"๐Ÿงฉ Run My App Name {name.capitalize()}", + input=run_env_input, + env=[ + EnvFile(path=os.path.join(APP_DIR, "template.env")), + EnvMap( + vars={ + **MICROSERVICES_ENV_VARS, + } + ), + ], + cwd=APP_DIR, + cmd=[ + ACTIVATE_VENV_SCRIPT, + set_env("MY_APP_NAME_MODULES", module), + set_env("MY_APP_NAME_PORT", f"{port}"), + run_my_app_name, + ], + render_cmd=False, + retries=2, + ) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/util.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/util.py index ec69faad..8131df53 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/util.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/util.py @@ -1,95 +1,15 @@ import os import platform -from my_app_name._zrb.config import ( - ACTIVATE_VENV_SCRIPT, - APP_DIR, - MICROSERVICES_ENV_VARS, - MONOLITH_ENV_VARS, -) - -from zrb import Cmd, CmdTask, EnvFile, EnvMap, StrInput, Task -from zrb.util.string.conversion import double_quote, to_snake_case - - -def create_migration(name: str, module: str) -> Task: - return CmdTask( - name=f"create-my-app-name-{name}-migration", - description=f"๐Ÿงฉ Create My App Name {name.capitalize()} DB migration", - input=StrInput( - name="message", - description="Migration message", - prompt="Migration message", - allow_empty=False, - ), - env=EnvFile(path=os.path.join(APP_DIR, "template.env")), - cwd=APP_DIR, - cmd=[ - ACTIVATE_VENV_SCRIPT, - set_create_migration_db_url_env(module), - set_module_env(module), - cd_module_script(module), - "alembic upgrade head", - Cmd( - "alembic revision --autogenerate -m {double_quote(ctx.input.message)}", - auto_render=True, - ), - ], - render_cmd=False, - retries=2, - ) - +from my_app_name._zrb.config import APP_DIR -def migrate_module(name: str, module: str, as_microservices: bool) -> Task: - env_vars = ( - dict(MICROSERVICES_ENV_VARS) if as_microservices else dict(MONOLITH_ENV_VARS) - ) - if as_microservices: - env_vars["MY_APP_NAME_MODULES"] = to_snake_case(module) - return CmdTask( - name=( - f"migrate-my-app-name-{name}" - if as_microservices - else f"migrate-{name}-on-monolith" - ), - description=f"๐Ÿงฉ Run My App Name {name.capitalize()} DB migration", - env=[ - EnvFile(path=os.path.join(APP_DIR, "template.env")), - EnvMap(vars=env_vars), - ], - cwd=APP_DIR, - cmd=[ - ACTIVATE_VENV_SCRIPT, - cd_module_script(module), - "alembic upgrade head", - ], - render_cmd=False, - retries=2, - ) +from zrb import AnyContext +from zrb.util.string.conversion import double_quote, to_snake_case -def run_microservice(name: str, port: int, module: str) -> Task: - return CmdTask( - name=f"run-my-app-name-{name}", - description=f"๐Ÿงฉ Run My App Name {name.capitalize()}", - env=[ - EnvFile(path=os.path.join(APP_DIR, "template.env")), - EnvMap( - vars={ - **MICROSERVICES_ENV_VARS, - } - ), - ], - cwd=APP_DIR, - cmd=[ - ACTIVATE_VENV_SCRIPT, - set_env("MY_APP_NAME_MODULES", module), - set_env("MY_APP_NAME_PORT", f"{port}"), - 'fastapi dev main.py --port "${MY_APP_NAME_PORT}"', - ], - render_cmd=False, - retries=2, - ) +def run_my_app_name(ctx: AnyContext) -> str: + subcommand = "dev" if ctx.input.env == "dev" else "run" + return f'fastapi {subcommand} main.py --port "${{MY_APP_NAME_PORT}}"' def get_existing_module_names() -> list[str]: diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py index e8ff0dab..6a9fcb6b 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py @@ -122,9 +122,19 @@ async def get_by_id(self, id: str) -> ResponseModel: return self._ensure_one(rows) async def get_by_ids(self, id_list: list[str]) -> list[ResponseModel]: - return await self._select_to_response( + rows = await self._select_to_response( lambda q: q.where(self.db_model.id.in_(id_list)) ) + # raise error if any id not in id_list + existing_id_list = [row.id for row in rows] + inexist_id_list = [id for id in id_list if id not in existing_id_list] + if len(inexist_id_list) > 0: + raise NotFoundError( + f"{self.entity_name} not found, inexist ids: {', '.join(inexist_id_list)}" + ) + # sort rows + row_dict = {row.id: row for row in rows} + return [row_dict[id] for id in id_list] async def count(self, filter: str | None = None) -> int: count_statement = select(func.count(1)).select_from(self.db_model) @@ -184,21 +194,22 @@ async def create(self, data: CreateModel) -> DBModel: async def create_bulk(self, data_list: list[CreateModel]) -> list[DBModel]: now = datetime.datetime.now(datetime.timezone.utc) - data_dicts = [ + data_dict_list = [ self._model_to_data_dict(data, created_at=now, id=ulid.new().str) for data in data_list ] + id_list = [data_dict["id"] for data_dict in data_dict_list] async with self._session_scope() as session: await self._execute_statement( - session, insert(self.db_model).values(data_dicts) + session, insert(self.db_model).values(data_dict_list) ) - id_list = [d["id"] for d in data_dicts] statement = select(self.db_model).where(self.db_model.id.in_(id_list)) result = await self._execute_statement(session, statement) - return [ - self.db_model(**entity.model_dump()) + row_dict = { + entity.id: self.db_model(**entity.model_dump()) for entity in result.scalars().all() - ] + } + return [row_dict[id] for id in id_list] async def delete(self, id: str) -> DBModel: async with self._session_scope() as session: @@ -220,11 +231,15 @@ async def delete_bulk(self, id_list: list[str]) -> list[DBModel]: await self._execute_statement( session, delete(self.db_model).where(self.db_model.id.in_(id_list)) ) - return [self.db_model(**entity.model_dump()) for entity in entities] + row_dict = { + entity.id: self.db_model(**entity.model_dump()) for entity in entities + } + return [row_dict[id] for id in id_list] async def update(self, id: str, data: UpdateModel) -> DBModel: now = datetime.datetime.now(datetime.timezone.utc) update_data = self._model_to_data_dict(data, updated_at=now) + update_data = {k: v for k, v in update_data.items() if v is not None} async with self._session_scope() as session: statement = ( update(self.db_model) @@ -256,7 +271,8 @@ async def update_bulk(self, id_list: list[str], data: UpdateModel) -> list[DBMod result = await self._execute_statement( session, select(self.db_model).where(self.db_model.id.in_(id_list)) ) - return [ - self.db_model(**entity.model_dump()) + row_dict = { + entity.id: self.db_model(**entity.model_dump()) for entity in result.scalars().all() - ] + } + return [row_dict[id] for id in id_list] diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_service.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_service.py index a346440c..45dd1884 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_service.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_service.py @@ -2,10 +2,10 @@ from enum import Enum from functools import partial from logging import Logger -from typing import Any, Callable, Sequence +from typing import Any, Callable, Union, get_args, get_origin import httpx -from fastapi import APIRouter, Depends, params +from fastapi import APIRouter, Depends from my_app_name.common.error import ClientAPIError from pydantic import BaseModel @@ -56,7 +56,6 @@ def route( response_model: Any = None, status_code: int | None = None, tags: list[str | Enum] | None = None, - dependencies: Sequence[params.Depends] | None = None, summary: str | None = None, description: str = None, deprecated: bool | None = None, @@ -160,56 +159,24 @@ async def client_method(self, *args, **kwargs): return client_method -def _create_api_client_method(logger: Logger, param: RouteParam, base_url: str): +def _create_api_client_method(logger: Logger, route_param: RouteParam, base_url: str): async def client_method(*args, **kwargs): - url = base_url + param.path - method = ( - param.methods[0].lower() - if isinstance(param.methods, list) - else param.methods.lower() + url = base_url + route_param.path + method = _get_api_client_method(route_param) + body_param_names = _get_api_client_body_param_names(route_param, method) + path_params, query_params, body_params = _create_api_client_request_params( + route_param, body_param_names, args, kwargs ) - # Get the signature of the original function - sig = inspect.signature(param.func) - # Bind the arguments to the signature - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - # Analyze parameters - params = list(sig.parameters.values()) - body_params = [ - p - for p in params - if p.name != "self" and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ] - # Prepare the request - path_params = {} - query_params = {} - body = {} - for name, value in bound_args.arguments.items(): - if name == "self": - continue - if f"{{{name}}}" in param.path: - path_params[name] = value - elif isinstance(value, BaseModel): - body = _parse_api_param(value) - elif method in ["get", "delete"]: - query_params[name] = _parse_api_param(value) - elif len(body_params) == 1 and name == body_params[0].name: - # If there's only one body parameter, use its value directly - body = _parse_api_param(value) - else: - body[name] = _parse_api_param(value) # Format the URL with path parameters url = url.format(**path_params) + json_body_params = None if method == "get" else body_params logger.info( - f"Sending request to {url} with method {method}, json={body}, params={query_params}" # noqa + f"Sending request to {url} with method {method}, json={json_body_params}, params={query_params}" # noqa ) async with httpx.AsyncClient() as client: - if method in ["get", "delete"]: - response = await getattr(client, method)(url, params=query_params) - else: - response = await getattr(client, method)( - url, json=body, params=query_params - ) + response = await client.request( + method=method, url=url, params=query_params, json=json_body_params + ) logger.info( f"Received response: status={response.status_code}, content={response.content}" ) @@ -220,17 +187,139 @@ async def client_method(*args, **kwargs): else response.text ) raise ClientAPIError(response.status_code, error_detail) - return response.json() + return _parse_api_client_response(logger, route_param, response) return client_method -def _parse_api_param(data: Any) -> Any: +def _parse_api_client_response( + logger: Logger, route_param: RouteParam, response: Any +) -> Any: + sig = inspect.signature(route_param.func) + try: + response_data = response.json() + except Exception: + logger.warning("Failed to parse JSON") + return None + return_annotation = sig.return_annotation # e.g., list[User] + if return_annotation is inspect.Signature.empty: + logger.warning("No return annotation detected, return value as is") + return response_data # No return type specified, return raw JSON + origin = get_origin(return_annotation) # e.g., list + args = get_args(return_annotation) # e.g., (User,) + try: + if origin is None: # Not a generic type, so check it directly + if inspect.isclass(return_annotation) and issubclass( + return_annotation, BaseModel + ): + if response_data: + return return_annotation.model_validate(response_data) + return None + elif origin in {list, set, tuple} and args: + model_type = args[0] + if inspect.isclass(model_type) and issubclass(model_type, BaseModel): + if isinstance(response_data, list): + return [model_type.model_validate(item) for item in response_data] + elif isinstance(response_data, tuple): + return tuple( + model_type.model_validate(item) for item in response_data + ) + elif isinstance(response_data, set): + return {model_type.model_validate(item) for item in response_data} + return None + elif origin is Union and len(args) == 2 and type(None) in args: + model_type = next( + ( + arg + for arg in args + if inspect.isclass(arg) and issubclass(arg, BaseModel) + ), + None, + ) + if response_data and model_type: + return model_type.model_validate(response_data) + return None + elif origin is dict and len(args) == 2: + key_type, value_type = args + if inspect.isclass(value_type) and issubclass(value_type, BaseModel): + if inspect(response_data, dict): + return { + k: value_type.model_validate(v) + for k, v in response_data.items() + } + return None + return response_data + except Exception: + logger.warning( + "Return annotation detected, but parsing error, return value as is" + ) + return response_data + + +def _create_api_client_request_params( + route_param: RouteParam, + body_param_names: list[str], + args: list[Any], + kwargs: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + # Get the signature of the original function + sig = inspect.signature(route_param.func) + # Bind the arguments to the signature + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + # Prepare the request + path_params = {} + query_params = {} + body_params = {} + for name, value in bound_args.arguments.items(): + if name == "self": + continue + if f"{{{name}}}" in route_param.path: + path_params[name] = value + elif name not in body_param_names: + query_params[name] = _parse_api_client_param(value) + elif len(body_param_names) == 1 and name == body_param_names[0]: + # If there's only one body parameter, use its value directly + body_params = _parse_api_client_param(value) + else: + body_params[name] = _parse_api_client_param(value) + return path_params, query_params, body_params + + +def _parse_api_client_param(data: Any) -> Any: if isinstance(data, BaseModel): return data.model_dump() - elif isinstance(data, list): - return [_parse_api_param(item) for item in data] elif isinstance(data, dict): - return {key: _parse_api_param(value) for key, value in data.items()} + return {key: _parse_api_client_param(value) for key, value in data.items()} + elif isinstance(data, list): + return [_parse_api_client_param(item) for item in data] + elif isinstance(data, tuple): + return tuple(_parse_api_client_param(item) for item in data) + elif isinstance(data, set): + return {_parse_api_client_param(item) for item in data} else: return data + + +def _get_api_client_method(route_param: RouteParam) -> str: + if isinstance(route_param.methods, list): + return route_param.methods[0].lower() + return route_param.methods.lower() + + +def _get_api_client_body_param_names(route_param: RouteParam, method: str): + sig = inspect.signature(route_param.func) + function_params = list(sig.parameters.values()) + return [ + p.name + for p in function_params + if ( + p.name != "self" + and f"{{{p.name}}}" not in route_param.path + and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and ( + method not in ["get", "delete"] + or (method == "delete" and p.annotation not in [str, float, bool]) + ) + ) + ] diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py index 421b1c72..401095b7 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py @@ -8,11 +8,26 @@ def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None: super().__init__(404, {"message": message}, headers) +class ForbiddenError(HTTPException): + def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None: + super().__init__(403, {"message": message}, headers) + + +class UnauthorizedError(HTTPException): + def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None: + super().__init__(401, {"message": message}, headers) + + class InvalidValueError(HTTPException): def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None: super().__init__(422, {"message": message}, headers) +class InternalServerError(HTTPException): + def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None: + super().__init__(500, {"message": message}, headers) + + class ClientAPIError(HTTPException): def __init__( self, status_code: int, message: str, headers: Dict[str, str] | None = None diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py index d8161b2e..65058b7b 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py @@ -1,5 +1,8 @@ import os +TRUE_STRS = ["true", "1", "yes", "y", "active", "on"] +FALSE_STRS = ["false", "0", "no", "n", "inactive", "off"] + APP_PATH = os.path.dirname(__file__) APP_VERSION = "0.1.0" @@ -51,16 +54,31 @@ "MY_APP_NAME_AUTH_SUPER_USER_PASSWORD", "my-secure-password" ) APP_AUTH_GUEST_USER = os.getenv("MY_APP_NAME_AUTH_GUEST_USER", "user") -APP_AUTH_GUEST_USER_PERMISSIONS = ( +APP_AUTH_GUEST_USER_PERMISSIONS = [ permission_name.strip() for permission_name in os.getenv( "MY_APP_NAME_AUTH_GUEST_USER_PERMISSIONS", "" ).split(",") if permission_name.strip() != "" +] +APP_AUTH_MAX_PARALLEL_SESSION = int( + os.getenv("MY_APP_NAME_AUTH_MAX_PARALLEL_SESSION", "1") +) +APP_AUTH_ACCESS_TOKEN_EXPIRE_MINUTES = int( + os.getenv("MY_APP_NAME_AUTH_ACCESS_TOKEN_EXPIRE_MINUTES", "30") +) +APP_AUTH_REFRESH_TOKEN_EXPIRE_MINUTES = int( + os.getenv("MY_APP_NAME_AUTH_REFRESH_TOKEN_EXPIRE_MINUTES", "1440") +) +APP_AUTH_ACCESS_TOKEN_COOKIE_NAME = os.getenv( + "MY_APP_NAME_AUTH_ACCESS_TOKEN_COOKIE_NAME", "access_token" +) +APP_AUTH_REFRESH_TOKEN_COOKIE_NAME = os.getenv( + "MY_APP_NAME_AUTH_REFRESH_TOKEN_COOKIE_NAME", "refresh_token" ) -APP_MAX_PARALLEL_SESSION = int(os.getenv("MY_APP_NAME_MAX_PARALLEL_SESSION", "1")) -APP_SESSION_EXPIRE_MINUTES = int( - os.getenv("MY_APP_NAME_SESSION_EXPIRE_MINUTES", "1440") +APP_AUTH_SECRET_KEY = os.getenv("MY_APP_NAME_AUTH_SECRET_KEY", "my-secret-key") +APP_AUTH_PRIORITIZE_NEW_SESSION = ( + os.getenv("MY_APP_NAME_AUTH_PRIORITIZE_NEW_SESSION", "1").lower() in TRUE_STRS ) APP_AUTH_BASE_URL = os.getenv("MY_APP_NAME_AUTH_BASE_URL", "http://localhost:3001") diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/migrate.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/migrate.py deleted file mode 100644 index c5560763..00000000 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/migrate.py +++ /dev/null @@ -1,3 +0,0 @@ -from my_app_name.module.auth import migration as auth_migration - -assert auth_migration diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py index f1fc9e10..af9827da 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py @@ -13,15 +13,36 @@ RoleUpdateWithPermissionsAndAudit, ) from my_app_name.schema.user import ( + AuthUserResponse, MultipleUserResponse, UserCreateWithRolesAndAudit, + UserCredentials, UserResponse, + UserSessionResponse, UserUpdateWithRolesAndAudit, ) class AuthClient(ABC): + @abstractmethod + async def get_current_user(self, access_token: str) -> AuthUserResponse: + """Get current user based on access token""" + + @abstractmethod + async def create_user_session( + self, credential: UserCredentials + ) -> UserSessionResponse: + """Create new user session""" + + @abstractmethod + async def update_user_session(self, refresh_token: str) -> UserSessionResponse: + """Update user session""" + + @abstractmethod + async def delete_user_session(self, refresh_token: str) -> UserSessionResponse: + """Delete user session""" + # Permission related methods @abstractmethod diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py index ea7c7644..8c3c93b3 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py @@ -22,7 +22,7 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table( - "permission", + "permissions", sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=True), sa.Column("created_by", sqlmodel.sql.sqltypes.AutoString(), nullable=True), @@ -32,22 +32,22 @@ def upgrade() -> None: sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_permission_name"), "permission", ["name"], unique=True) + op.create_index(op.f("ix_permission_name"), "permissions", ["name"], unique=True) op.create_index( - op.f("ix_permission_created_at"), "permission", ["created_at"], unique=False + op.f("ix_permission_created_at"), "permissions", ["created_at"], unique=False ) op.create_index( - op.f("ix_permission_created_by"), "permission", ["created_by"], unique=False + op.f("ix_permission_created_by"), "permissions", ["created_by"], unique=False ) op.create_index( - op.f("ix_permission_updated_at"), "permission", ["updated_at"], unique=False + op.f("ix_permission_updated_at"), "permissions", ["updated_at"], unique=False ) op.create_index( - op.f("ix_permission_updated_by"), "permission", ["updated_by"], unique=False + op.f("ix_permission_updated_by"), "permissions", ["updated_by"], unique=False ) op.create_table( - "role", + "roles", sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=True), sa.Column("created_by", sqlmodel.sql.sqltypes.AutoString(), nullable=True), @@ -57,14 +57,14 @@ def upgrade() -> None: sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_role_name"), "role", ["name"], unique=True) - op.create_index(op.f("ix_role_created_at"), "role", ["created_at"], unique=False) - op.create_index(op.f("ix_role_created_by"), "role", ["created_by"], unique=False) - op.create_index(op.f("ix_role_updated_at"), "role", ["updated_at"], unique=False) - op.create_index(op.f("ix_role_updated_by"), "role", ["updated_by"], unique=False) + op.create_index(op.f("ix_role_name"), "roles", ["name"], unique=True) + op.create_index(op.f("ix_role_created_at"), "roles", ["created_at"], unique=False) + op.create_index(op.f("ix_role_created_by"), "roles", ["created_by"], unique=False) + op.create_index(op.f("ix_role_updated_at"), "roles", ["updated_at"], unique=False) + op.create_index(op.f("ix_role_updated_by"), "roles", ["updated_by"], unique=False) op.create_table( - "rolepermission", + "role_permissions", sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("role_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("permission_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), @@ -73,34 +73,39 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), ) op.create_index( - op.f("ix_rolepermission_permission_id"), - "rolepermission", + op.f("ix_role_permissions_permission_id"), + "role_permissions", ["permission_id"], unique=False, ) op.create_index( - op.f("ix_rolepermission_role_id"), "rolepermission", ["role_id"], unique=False + op.f("ix_role_permissions_role_id"), + "role_permissions", + ["role_id"], + unique=False, ) op.create_table( - "user", + "users", sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("password", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("active", sa.Boolean(), nullable=False), sa.Column("created_at", sa.DateTime(), nullable=False), sa.Column("created_by", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("updated_at", sa.DateTime(), nullable=True), sa.Column("updated_by", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_user_username"), "user", ["username"], unique=True) - op.create_index(op.f("ix_user_created_at"), "user", ["created_at"], unique=False) - op.create_index(op.f("ix_user_created_by"), "user", ["created_by"], unique=False) - op.create_index(op.f("ix_user_updated_at"), "user", ["updated_at"], unique=False) - op.create_index(op.f("ix_user_updated_by"), "user", ["updated_by"], unique=False) + op.create_index(op.f("ix_user_username"), "users", ["username"], unique=True) + op.create_index(op.f("ix_user_active"), "users", ["active"], unique=False) + op.create_index(op.f("ix_user_created_at"), "users", ["created_at"], unique=False) + op.create_index(op.f("ix_user_created_by"), "users", ["created_by"], unique=False) + op.create_index(op.f("ix_user_updated_at"), "users", ["updated_at"], unique=False) + op.create_index(op.f("ix_user_updated_by"), "users", ["updated_by"], unique=False) op.create_table( - "userrole", + "user_roles", sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("user_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("role_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), @@ -108,53 +113,90 @@ def upgrade() -> None: sa.Column("created_by", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_userrole_role_id"), "userrole", ["role_id"], unique=False) - op.create_index(op.f("ix_userrole_user_id"), "userrole", ["user_id"], unique=False) + op.create_index( + op.f("ix_user_roles_role_id"), "user_roles", ["role_id"], unique=False + ) + op.create_index( + op.f("ix_user_roles_user_id"), "user_roles", ["user_id"], unique=False + ) op.create_table( - "session", + "user_sessions", sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("user_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column("token", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("access_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("refresh_token", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("access_token_expired_at", sa.DateTime(), nullable=False), + sa.Column("refresh_token_expired_at", sa.DateTime(), nullable=False), sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f("ix_session_user_id"), "session", ["user_id"], unique=False) - op.create_index(op.f("ix_session_token"), "session", ["token"], unique=True) + op.create_index( + op.f("ix_user_session_user_id"), "user_sessions", ["user_id"], unique=False + ) + op.create_index( + op.f("ix_user_session_token"), "user_sessions", ["access_token"], unique=True + ) + op.create_index( + op.f("ix_user_session_refresh_token"), + "user_sessions", + ["refresh_token"], + unique=True, + ) + op.create_index( + op.f("ix_user_session_access_token_expired_at"), + "user_sessions", + ["access_token_expired_at"], + ) + op.create_index( + op.f("ix_user_session_refresh_token_expired_at"), + "user_sessions", + ["refresh_token_expired_at"], + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f("ix_session_token"), table_name="session") - op.drop_index(op.f("ix_session_user_id"), table_name="session") - op.drop_table("session") - - op.drop_index(op.f("ix_userrole_user_id"), table_name="userrole") - op.drop_index(op.f("ix_userrole_role_id"), table_name="userrole") - op.drop_table("userrole") - - op.drop_index(op.f("ix_user_username"), table_name="user") - op.drop_index(op.f("ix_user_updated_by"), table_name="user") - op.drop_index(op.f("ix_user_updated_at"), table_name="user") - op.drop_index(op.f("ix_user_created_by"), table_name="user") - op.drop_index(op.f("ix_user_created_at"), table_name="user") - op.drop_table("user") - - op.drop_index(op.f("ix_rolepermission_role_id"), table_name="rolepermission") - op.drop_index(op.f("ix_rolepermission_permission_id"), table_name="rolepermission") - op.drop_table("rolepermission") - - op.drop_index(op.f("ix_role_name"), table_name="role") - op.drop_index(op.f("ix_role_updated_by"), table_name="role") - op.drop_index(op.f("ix_role_updated_at"), table_name="role") - op.drop_index(op.f("ix_role_created_by"), table_name="role") - op.drop_index(op.f("ix_role_created_at"), table_name="role") - op.drop_table("role") - - op.drop_index(op.f("ix_permission_updated_by"), table_name="permission") - op.drop_index(op.f("ix_permission_updated_at"), table_name="permission") - op.drop_index(op.f("ix_permission_created_by"), table_name="permission") - op.drop_index(op.f("ix_permission_created_at"), table_name="permission") - op.drop_index(op.f("ix_permission_name"), table_name="permission") - op.drop_table("permission") + op.drop_index(op.f("ix_user_session_user_id"), table_name="user_sessions") + op.drop_index(op.f("ix_user_session_access_token"), table_name="user_sessions") + op.drop_index(op.f("ix_user_session_refresh_token"), table_name="user_sessions") + op.drop_index( + op.f("ix_user_session_access_token_expired_at"), table_name="user_sessions" + ) + op.drop_index( + op.f("ix_user_session_refresh_token_expired_at"), table_name="user_sessions" + ) + op.drop_table("user_sessions") + + op.drop_index(op.f("ix_user_roles_user_id"), table_name="user_roles") + op.drop_index(op.f("ix_user_roles_role_id"), table_name="user_roles") + op.drop_table("user_roles") + + op.drop_index(op.f("ix_user_username"), table_name="users") + op.drop_index(op.f("ix_user_active"), table_name="users") + op.drop_index(op.f("ix_user_updated_by"), table_name="users") + op.drop_index(op.f("ix_user_updated_at"), table_name="users") + op.drop_index(op.f("ix_user_created_by"), table_name="users") + op.drop_index(op.f("ix_user_created_at"), table_name="users") + op.drop_table("users") + + op.drop_index(op.f("ix_role_permissions_role_id"), table_name="role_permissions") + op.drop_index( + op.f("ix_role_permissions_permission_id"), table_name="role_permissions" + ) + op.drop_table("role_permissions") + + op.drop_index(op.f("ix_role_name"), table_name="roles") + op.drop_index(op.f("ix_role_updated_by"), table_name="roles") + op.drop_index(op.f("ix_role_updated_at"), table_name="roles") + op.drop_index(op.f("ix_role_created_by"), table_name="roles") + op.drop_index(op.f("ix_role_created_at"), table_name="roles") + op.drop_table("roles") + + op.drop_index(op.f("ix_permission_updated_by"), table_name="permissions") + op.drop_index(op.f("ix_permission_updated_at"), table_name="permissions") + op.drop_index(op.f("ix_permission_created_by"), table_name="permissions") + op.drop_index(op.f("ix_permission_created_at"), table_name="permissions") + op.drop_index(op.f("ix_permission_name"), table_name="permissions") + op.drop_table("permissions") # ### end Alembic commands ### diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration_metadata.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration_metadata.py index 79e5c616..65939f75 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration_metadata.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration_metadata.py @@ -1,7 +1,6 @@ from my_app_name.schema.permission import Permission from my_app_name.schema.role import Role, RolePermission -from my_app_name.schema.session import Session -from my_app_name.schema.user import User, UserRole +from my_app_name.schema.user import User, UserRole, UserSession from sqlalchemy import MetaData metadata = MetaData() @@ -19,5 +18,5 @@ UserRole.metadata = metadata UserRole.__table__.tometadata(metadata) -Session.metadata = metadata -Session.__table__.tometadata(metadata) +UserSession.metadata = metadata +UserSession.__table__.tometadata(metadata) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/route.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/route.py index 06d73d31..15b85a65 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/route.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/route.py @@ -5,10 +5,23 @@ from my_app_name.module.auth.service.permission.permission_service_factory import ( permission_service, ) +from my_app_name.module.auth.service.role.role_service_factory import role_service from my_app_name.module.auth.service.user.user_service_factory import user_service -def serve_health_check(app: FastAPI): +def serve_route(app: FastAPI): + if APP_MODE != "microservices" or "auth" not in APP_MODULES: + return + if APP_MAIN_MODULE == "auth": + _serve_health_check(app) + _serve_readiness_check(app) + + permission_service.serve_route(app) + role_service.serve_route(app) + user_service.serve_route(app) + + +def _serve_health_check(app: FastAPI): @app.api_route("/health", methods=["GET", "HEAD"], response_model=BasicResponse) async def health(): """ @@ -17,7 +30,7 @@ async def health(): return BasicResponse(message="ok") -def serve_readiness_check(app: FastAPI): +def _serve_readiness_check(app: FastAPI): @app.api_route("/readiness", methods=["GET", "HEAD"], response_model=BasicResponse) async def readiness(): """ @@ -26,16 +39,4 @@ async def readiness(): return BasicResponse(message="ok") -def serve_route(app: FastAPI): - if APP_MODE != "microservices" or "auth" not in APP_MODULES: - return - if APP_MAIN_MODULE == "auth": - serve_health_check(app) - serve_readiness_check(app) - - # Serve user endpoints for APIClient - user_service.serve_route(app) - permission_service.serve_route(app) - - serve_route(app) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service.py index 6b223f6c..5f2e8a0b 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service.py @@ -71,11 +71,11 @@ async def create_permission( @BaseService.route( "/api/v1/permissions/bulk", methods=["put"], - response_model=PermissionResponse, + response_model=list[PermissionResponse], ) async def update_permission_bulk( self, permission_ids: list[str], data: PermissionUpdateWithAudit - ) -> PermissionResponse: + ) -> list[PermissionResponse]: await self.permission_repository.update_bulk(permission_ids, data) return await self.permission_repository.get_by_ids(permission_ids) @@ -93,11 +93,11 @@ async def update_permission( @BaseService.route( "/api/v1/permissions/bulk", methods=["delete"], - response_model=PermissionResponse, + response_model=list[PermissionResponse], ) async def delete_permission_bulk( self, permission_ids: list[str], deleted_by: str - ) -> PermissionResponse: + ) -> list[PermissionResponse]: permissions = await self.permission_repository.get_by_ids(permission_ids) await self.permission_repository.delete_bulk(permission_ids) return permissions diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py index 2ea79792..6031845d 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py @@ -54,28 +54,47 @@ def _rows_to_responses(self, rows: list[tuple[Any, ...]]) -> list[RoleResponse]: and permission.id not in role_permission_map[role.id] ): role_permission_map[role.id].append(permission.id) - role_map[role.id]["permissions"].append(permission.model_dump()) + role_map[role.id]["permissions"].append(permission) return [ - RoleResponse(**data["role"].model_dump(), permissions=data["permissions"]) + RoleResponse( + **data["role"].model_dump(), + permission_names=[ + permission.name for permission in data["permissions"] + ], + ) for data in role_map.values() ] async def add_permissions(self, data: dict[str, list[str]], created_by: str): now = datetime.datetime.now(datetime.timezone.utc) + # get mapping from perrmission names to permission ids + all_permission_names = { + name for permission_names in data.values() for name in permission_names + } + async with self._session_scope() as session: + result = await self._execute_statement( + session, + select(Permission.id, Permission.name).where( + Permission.name.in_(all_permission_names) + ), + ) + permission_mapping = {row.name: row.id for row in result} + # Assemble data dict data_dict_list: list[dict[str, Any]] = [] - for role_id, permission_ids in data.items(): - for permission_id in permission_ids: + for role_id, permission_names in data.items(): + for permission_name in permission_names: data_dict_list.append( self._model_to_data_dict( RolePermission( id=ulid.new().str, role_id=role_id, - permission_id=permission_id, + permission_id=permission_mapping.get(permission_name), created_at=now, created_by=created_by, ) ) ) + # Insert rolePermissions async with self._session_scope() as session: await self._execute_statement( session, insert(RolePermission).values(data_dict_list) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py index da7e335d..a229a33d 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py @@ -50,13 +50,13 @@ async def get_roles( async def create_role_bulk( self, data: list[RoleCreateWithPermissionsAndAudit] ) -> list[RoleResponse]: - permission_ids = [row.get_permission_ids() for row in data] + permission_names = [row.get_permission_names() for row in data] data = [row.get_role_create_with_audit() for row in data] roles = await self.role_repository.create_bulk(data) if len(roles) > 0: created_by = roles[0].created_by await self.role_repository.add_permissions( - data={role.id: permission_ids[i] for i, role in enumerate(roles)}, + data={role.id: permission_names[i] for i, role in enumerate(roles)}, created_by=created_by, ) return await self.role_repository.get_by_ids([role.id for role in roles]) @@ -69,30 +69,32 @@ async def create_role_bulk( async def create_role( self, data: RoleCreateWithPermissionsAndAudit ) -> RoleResponse: - permission_ids = data.get_permission_ids() + permission_names = data.get_permission_names() data = data.get_role_create_with_audit() role = await self.role_repository.create(data) await self.role_repository.add_permissions( - data={role.id: permission_ids}, created_by=role.created_by + data={role.id: permission_names}, created_by=role.created_by ) return await self.role_repository.get_by_id(role.id) @BaseService.route( "/api/v1/roles/bulk", methods=["put"], - response_model=RoleResponse, + response_model=list[RoleResponse], ) async def update_role_bulk( self, role_ids: list[str], data: RoleUpdateWithPermissionsAndAudit - ) -> RoleResponse: - permission_ids = [row.get_permission_ids() for row in data] + ) -> list[RoleResponse]: + permission_names = [row.get_permission_names() for row in data] data = [row.get_role_update_with_audit() for row in data] await self.role_repository.update_bulk(role_ids, data) if len(role_ids) > 0: updated_by = data[0].updated_by await self.role_repository.remove_all_permissions(role_ids) await self.role_repository.add_permissions( - data={role_id: permission_ids[i] for i, role_id in enumerate(role_ids)}, + data={ + role_id: permission_names[i] for i, role_id in enumerate(role_ids) + }, created_by=updated_by, ) return await self.role_repository.get_by_ids(role_ids) @@ -105,23 +107,23 @@ async def update_role_bulk( async def update_role( self, role_id: str, data: RoleUpdateWithPermissionsAndAudit ) -> RoleResponse: - permission_ids = data.get_permission_ids() + permission_names = data.get_permission_names() role_data = data.get_role_update_with_audit() await self.role_repository.update(role_id, role_data) await self.role_repository.remove_all_permissions([role_id]) await self.role_repository.add_permissions( - data={role_id: permission_ids}, created_by=role_data.updated_by + data={role_id: permission_names}, created_by=role_data.updated_by ) return await self.role_repository.get_by_id(role_id) @BaseService.route( "/api/v1/roles/bulk", methods=["delete"], - response_model=RoleResponse, + response_model=list[RoleResponse], ) async def delete_role_bulk( self, role_ids: list[str], deleted_by: str - ) -> RoleResponse: + ) -> list[RoleResponse]: roles = await self.role_repository.get_by_ids(role_ids) await self.role_repository.delete_bulk(role_ids) await self.role_repository.remove_all_permissions(role_ids) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py index 26a274fa..4077bad8 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py @@ -1,35 +1,27 @@ import datetime -from typing import Any, Callable +from typing import Any import ulid from my_app_name.common.base_db_repository import BaseDBRepository -from my_app_name.common.error import NotFoundError -from my_app_name.config import ( - APP_AUTH_GUEST_USER, - APP_AUTH_GUEST_USER_PERMISSIONS, - APP_AUTH_SUPER_USER, - APP_AUTH_SUPER_USER_PASSWORD, - APP_MAX_PARALLEL_SESSION, - APP_SESSION_EXPIRE_MINUTES, -) +from my_app_name.common.error import NotFoundError, UnauthorizedError from my_app_name.module.auth.service.user.repository.user_repository import ( UserRepository, ) from my_app_name.schema.permission import Permission from my_app_name.schema.role import Role, RolePermission -from my_app_name.schema.session import Session, SessionResponse from my_app_name.schema.user import ( User, UserCreateWithAudit, UserResponse, UserRole, + UserSession, + UserSessionResponse, + UserTokenData, UserUpdateWithAudit, ) from passlib.context import CryptContext -from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import AsyncEngine -from sqlalchemy.sql import ClauseElement, ColumnElement, Select -from sqlmodel import SQLModel, delete, insert, select +from sqlalchemy.sql import Select +from sqlmodel import delete, insert, select, update # Password hashing context pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -39,6 +31,11 @@ def hash_password(password: str) -> str: return pwd_context.hash(password) +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verifies if a password matches the stored hash.""" + return pwd_context.verify(plain_password, hashed_password) + + class UserDBRepository( BaseDBRepository[User, UserResponse, UserCreateWithAudit, UserUpdateWithAudit], UserRepository, @@ -50,86 +47,65 @@ class UserDBRepository( entity_name = "user" column_preprocessors = {"password": hash_password} - def __init__( - self, - engine: Engine | AsyncEngine, - super_user_username: str = APP_AUTH_SUPER_USER, - super_user_password: str = APP_AUTH_SUPER_USER_PASSWORD, - guest_user_username: str = APP_AUTH_GUEST_USER, - guest_user_password: str = APP_AUTH_SUPER_USER_PASSWORD, - guest_user_permission_names: list[str] = APP_AUTH_GUEST_USER_PERMISSIONS, - max_parallel_session: int = APP_MAX_PARALLEL_SESSION, - session_expire_minutes: int = APP_SESSION_EXPIRE_MINUTES, - filter_param_parser: ( - Callable[[SQLModel, str], list[ClauseElement]] | None - ) = None, - sort_param_parser: Callable[[SQLModel, str], list[ColumnElement]] | None = None, - ): - super().__init__( - engine=engine, - filter_param_parser=filter_param_parser, - sort_param_parser=sort_param_parser, - ) - self._super_user_username = super_user_username - self._super_user_passwored = super_user_password - self._guest_user_username = guest_user_username - self._guest_user_password = guest_user_password - self._guest_user_permission_names = guest_user_permission_names - self._max_parallel_session = max_parallel_session - self._session_expire_minutes = session_expire_minutes - self._super_user: User | None = None - self._guest_user: User | None = None - def _select(self) -> Select: return ( - select(User, Role, Permission, Session) + select(User, Role, Permission) .join(UserRole, UserRole.user_id == User.id, isouter=True) .join(Role, Role.id == UserRole.role_id, isouter=True) .join(RolePermission, RolePermission.role_id == Role.id, isouter=True) .join( Permission, Permission.id == RolePermission.permission_id, isouter=True ) - .join(Session, Session.user_id == User.id) ) def _rows_to_responses(self, rows: list[tuple[Any, ...]]) -> list[UserResponse]: user_map: dict[str, dict[str, Any]] = {} user_role_map: dict[str, list[str]] = {} user_permission_map: dict[str, list[str]] = {} - for user, role, permission, _ in rows: + for user, role, permission in rows: if user.id not in user_map: user_map[user.id] = {"user": user, "roles": [], "permissions": []} user_role_map[user.id] = [] user_permission_map[user.id] = [] if role is not None and role.id not in user_role_map[user.id]: user_role_map[user.id].append(role.id) - user_map[user.id]["roles"].append(role.model_dump()) + user_map[user.id]["roles"].append(role) if ( permission is not None and permission.id not in user_permission_map[user.id] ): user_permission_map[user.id].append(permission.id) - user_map[user.id]["permissions"].append(permission.model_dump()) + user_map[user.id]["permissions"].append(permission) return [ UserResponse( **data["user"].model_dump(), - roles=list(data["roles"]), - permissions=list(data["permissions"]), + role_names=[role.name for role in data["roles"]], + permission_names=[ + permission.name for permission in data["permissions"] + ], ) for data in user_map.values() ] async def add_roles(self, data: dict[str, list[str]], created_by: str): now = datetime.datetime.now(datetime.timezone.utc) + # get mapping from role names to role ids + all_role_names = {name for role_names in data.values() for name in role_names} + async with self._session_scope() as session: + result = await self._execute_statement( + session, select(Role.id, Role.name).where(Role.name.in_(all_role_names)) + ) + role_mapping = {row.name: row.id for row in result} + # Assemble data dict data_dict_list: list[dict[str, Any]] = [] - for user_id, role_ids in data.items(): - for role_id in role_ids: + for user_id, role_names in data.items(): + for role_name in role_names: data_dict_list.append( self._model_to_data_dict( UserRole( id=ulid.new().str, user_id=user_id, - role_id=role_id, + role_id=role_mapping.get(role_name), created_at=now, created_by=created_by, ) @@ -148,65 +124,123 @@ async def remove_all_roles(self, user_ids: list[str] = []): ) async def get_by_credentials(self, username: str, password: str) -> UserResponse: - rows = await self._select_to_response( - lambda q: q.where( - User.username == username, User.password == hash_password(password) + async with self._session_scope() as session: + result = await self._execute_statement( + session, select(User).where(User.username == username, User.active) ) - ) - return self._ensure_one(rows) + user = result.scalar_one_or_none() + if user is None or not verify_password(password, user.password): + raise UnauthorizedError("Invalid username or password") + return await self.get_by_id(user.id) - async def get_by_token(self, token: str) -> UserResponse: - rows = await self._select_tor_response( - lambda q: q.where(Session.token == token) - ) - return self._ensure_one(rows) - - async def add_token(self, user_id: str, token: str): + async def delete_expired_user_sessions(self, user_id: str): + now = datetime.datetime.now(datetime.timezone.utc) async with self._session_scope() as session: await self._execute_statement( session, - insert(Session).values( - { - "id": ulid.new().str, - "user_id": user_id, - "token": token, - "created_by": "system", - "created_at": datetime.datetime.now(datetime.timezone.utc), - } + delete(UserSession).where( + UserSession.user_id == user_id, + UserSession.refresh_token_expired_at < now, ), ) - async def remove_token(self, user_id: str, token: str): + async def get_active_user_sessions(self, user_id: str) -> list[UserSessionResponse]: + now = datetime.datetime.now(datetime.timezone.utc) async with self._session_scope() as session: - await self._execute_statement( + result = await self._execute_statement( session, - delete(Session).where( - Session.token == token, Session.user_id == user_id + select(UserSession).where( + UserSession.user_id == user_id, + UserSession.refresh_token_expired_at > now, ), ) + return [self._user_session_to_response(row[0]) for row in result.all()] - async def get_sessions(self, user_id: str) -> list[SessionResponse]: + async def get_user_session_by_access_token( + self, access_token: str + ) -> UserSessionResponse: + now = datetime.datetime.now(datetime.timezone.utc) + async with self._session_scope() as session: + result = await self._execute_statement( + session, + select(UserSession).where( + UserSession.access_token == access_token, + UserSession.access_token_expired_at > now, + ), + ) + user_session = result.scalar_one_or_none() + if user_session is None: + raise NotFoundError("User session not found") + return self._user_session_to_response(user_session) + + async def get_user_session_by_refresh_token( + self, refresh_token: str + ) -> UserSessionResponse: + now = datetime.datetime.now(datetime.timezone.utc) async with self._session_scope() as session: - statement = select(Session).where(Session.user_id == user_id) - result = await self._execute_statement(session, statement) - return [ - SessionResponse(**session.model_dump()) - for session in result.scalars().all() - ] - - async def remove_session(self, user_id: str, session_id: str) -> SessionResponse: + result = await self._execute_statement( + session, + select(UserSession).where( + UserSession.refresh_token == refresh_token, + UserSession.refresh_token_expired_at > now, + ), + ) + user_session = result.scalar_one_or_none() + if user_session is None: + raise NotFoundError("User session not found") + return self._user_session_to_response(user_session) + + async def create_user_session( + self, user_id: str, token_data: UserTokenData + ) -> UserSessionResponse: + data_dict = self._model_to_data_dict( + token_data, user_id=user_id, id=ulid.new().str + ) async with self._session_scope() as session: - statement = select(Session).where( - Session.user_id == user_id, Session.id == session_id + await self._execute_statement( + session, insert(UserSession).values(**data_dict) + ) + result = await self._execute_statement( + session, select(UserSession).where(UserSession.id == data_dict["id"]) ) - result = await self._execute_statement(session, statement) - session = result.scalar_one_or_none() - if not session: - raise NotFoundError(f"{self.entity_name} not found") + user_session = result.scalar_one_or_none() + if user_session is None: + raise NotFoundError("User session not found after created") + return self._user_session_to_response(user_session) + + async def update_user_session( + self, user_id: str, session_id: str, token_data: UserTokenData + ) -> UserSessionResponse: + data_dict = self._model_to_data_dict(token_data, user_id=user_id) + async with self._session_scope() as session: await self._execute_statement( session, - delete(Session).where( - Session.id == session_id, Session.user_id == user_id + ( + update(UserSession) + .where(UserSession.id == session_id) + .values(**data_dict) ), ) - return SessionResponse(**session.model_dump()) + result = await self._execute_statement( + session, select(UserSession).where(UserSession.id == session_id) + ) + user_session = result.scalar_one_or_none() + if user_session is None: + raise NotFoundError("User session not found after created") + return self._user_session_to_response(user_session) + + async def delete_user_sessions(self, session_ids: list[str]): + async with self._session_scope() as session: + await self._execute_statement( + session, delete(UserSession).where(UserSession.id.in_(session_ids)) + ) + + def _user_session_to_response( + self, user_session: UserSession + ) -> UserSessionResponse: + return UserSessionResponse( + id=user_session.id, + user_id=user_session.user_id, + access_token_expired_at=user_session.access_token_expired_at, + refresh_token_expired_at=user_session.refresh_token_expired_at, + ) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py index ef97df96..49600dda 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod -from my_app_name.schema.session import SessionResponse from my_app_name.schema.user import ( User, UserCreateWithAudit, UserResponse, + UserSessionResponse, + UserTokenData, UserUpdateWithAudit, ) @@ -72,21 +73,37 @@ async def get_by_credentials(self, username: str, password: str) -> UserResponse """Get user by credential""" @abstractmethod - async def get_by_token(self, token: str) -> UserResponse: - """Get user by token""" + async def get_active_user_sessions(self, user_id: str) -> list[UserSessionResponse]: + """Get user sessions""" @abstractmethod - async def add_token(self, user_id: str, token: str): - """Add token to user""" + async def get_user_session_by_access_token( + self, access_token: str + ) -> UserSessionResponse: + """Get user session by access token""" @abstractmethod - async def remove_token(self, user_id: str, token: str): - """Remove token from user""" + async def get_user_session_by_refresh_token( + self, refresh_token: str + ) -> UserSessionResponse: + """Get user session by refresh token""" @abstractmethod - async def get_sessions(self, user_id: str) -> list[SessionResponse]: - """Get sessions""" + async def create_user_session( + self, user_id: str, token_data: UserTokenData + ) -> UserSessionResponse: + """Create new user session""" @abstractmethod - async def remove_session(self, user_id: str, session_id: str) -> SessionResponse: - """Remove a session""" + async def update_user_session( + self, user_id: str, session_id: str, token_data: UserTokenData + ) -> UserSessionResponse: + """Update user session""" + + @abstractmethod + async def delete_expired_user_sessions(self, user_id: str): + """Delete expired user sessions""" + + @abstractmethod + async def delete_user_sessions(self, session_ids: list[str]): + """Delete user session""" diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py index 83350186..77bdb051 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py @@ -1,22 +1,103 @@ +import datetime from logging import Logger +from jose import jwt from my_app_name.common.base_service import BaseService +from my_app_name.common.error import ForbiddenError, NotFoundError from my_app_name.module.auth.service.user.repository.user_repository import ( UserRepository, ) from my_app_name.schema.user import ( + AuthUserResponse, MultipleUserResponse, UserCreateWithRolesAndAudit, + UserCredentials, UserResponse, + UserSessionResponse, + UserTokenData, UserUpdateWithRolesAndAudit, ) +from pydantic import BaseModel + + +class UserServiceConfig(BaseModel): + super_user: str + super_user_password: str + guest_user: str = "guest" + guest_user_permissions: list[str] = [] + max_parallel_session: int = 1 + access_token_expire_minutes: int = 30 + refresh_token_expire_minutes: int = 1440 + secret_key: str = "my-secret-key" + prioritize_new_session: bool = True class UserService(BaseService): - def __init__(self, logger: Logger, user_repository: UserRepository): + def __init__( + self, logger: Logger, user_repository: UserRepository, config: UserServiceConfig + ): super().__init__(logger) self.user_repository = user_repository + self.config = config + + @BaseService.route( + "/api/v1/current-user", + methods=["get"], + response_model=AuthUserResponse, + ) + async def get_current_user(self, access_token: str | None) -> AuthUserResponse: + return await self._get_auth_user_by_access_token(access_token) + + @BaseService.route( + "/api/v1/user-sessions", + methods=["post"], + response_model=UserSessionResponse, + ) + async def create_user_session( + self, credentials: UserCredentials + ) -> UserSessionResponse: + current_user = await self._get_user_by_credentials(credentials) + await self.user_repository.delete_expired_user_sessions(current_user.id) + user_sessions = await self.user_repository.get_active_user_sessions( + current_user.id + ) + user_session_count = len(user_sessions) + if user_session_count >= self.config.max_parallel_session: + await self._handle_excess_sessions(user_sessions) + token_data = self._create_user_token_data(current_user.username) + return await self.user_repository.create_user_session( + user_id=current_user.id, token_data=token_data + ) + + @BaseService.route( + "/api/v1/user-sessions", + methods=["put"], + response_model=UserSessionResponse, + ) + async def update_user_session(self, refresh_token: str) -> UserSessionResponse: + current_user = await self._get_auth_user_by_refresh_token(refresh_token) + current_user_session = ( + await self.user_respository.get_user_sesion_by_refresh_token(refresh_token) + ) + token_data = self._create_user_token_data(current_user.username) + return await self.user_repository.update_user_session( + user_id=current_user.id, + session_id=current_user_session.id, + token_data=token_data, + ) + + @BaseService.route( + "/api/v1/user-sessions", + methods=["delete"], + response_model=UserSessionResponse, + ) + async def delete_user_session(self, refresh_token: str) -> UserSessionResponse: + current_user_session = ( + await self.user_respository.get_user_sesion_by_refresh_token(refresh_token) + ) + await self.user_repository.delete_user_sessions([current_user_session.id]) + return current_user_session @BaseService.route( "/api/v1/users/{user_id}", @@ -50,13 +131,13 @@ async def get_users( async def create_user_bulk( self, data: list[UserCreateWithRolesAndAudit] ) -> list[UserResponse]: - role_ids = [row.get_role_ids() for row in data] + role_names = [row.get_role_names() for row in data] data = [row.get_user_create_with_audit() for row in data] users = await self.user_repository.create_bulk(data) if len(users) > 0: created_by = users[0].created_by await self.user_repository.add_roles( - data={user.id: role_ids[i] for i, user in enumerate(data)}, + data={user.id: role_names[i] for i, user in enumerate(users)}, created_by=created_by, ) return await self.user_repository.get_by_ids([user.id for user in users]) @@ -67,30 +148,30 @@ async def create_user_bulk( response_model=UserResponse, ) async def create_user(self, data: UserCreateWithRolesAndAudit) -> UserResponse: - role_ids = data.get_role_ids() + role_names = data.get_role_names() data = data.get_user_create_with_audit() user = await self.user_repository.create(data) await self.user_repository.add_roles( - data={user.id: role_ids}, created_by=user.created_by + data={user.id: role_names}, created_by=user.created_by ) return await self.user_repository.get_by_id(user.id) @BaseService.route( "/api/v1/users/bulk", methods=["put"], - response_model=UserResponse, + response_model=list[UserResponse], ) async def update_user_bulk( self, user_ids: list[str], data: UserUpdateWithRolesAndAudit - ) -> UserResponse: - role_ids = [row.get_role_ids() for row in data] + ) -> list[UserResponse]: + role_names = [row.get_role_names() for row in data] user_data = [row.get_user_create_with_audit() for row in data] await self.user_repository.update_bulk(user_ids, user_data) if len(user_ids) > 0: updated_by = user_data[0].updated_by await self.user_repository.remove_all_roles(user_ids) await self.user_repository.add_roles( - data={user_id: role_ids[i] for i, user_id in enumerate(user_ids)}, + data={user_id: role_names[i] for i, user_id in enumerate(user_ids)}, updated_by=updated_by, ) return await self.user_repository.get_by_ids(user_ids) @@ -103,23 +184,23 @@ async def update_user_bulk( async def update_user( self, user_id: str, data: UserUpdateWithRolesAndAudit ) -> UserResponse: - role_ids = data.get_role_ids() + role_names = data.get_role_names() user_data = data.get_user_update_with_audit() await self.user_repository.update(user_id, user_data) await self.user_repository.remove_all_roles([user_id]) await self.user_repository.add_roles( - data={user_id: role_ids}, created_by=user_data.updated_by + data={user_id: role_names}, created_by=user_data.updated_by ) return await self.user_repository.get_by_id(user_id) @BaseService.route( "/api/v1/users/bulk", methods=["delete"], - response_model=UserResponse, + response_model=list[UserResponse], ) async def delete_user_bulk( self, user_ids: list[str], deleted_by: str - ) -> UserResponse: + ) -> list[UserResponse]: roles = await self.user_repository.get_by_ids(user_ids) await self.user_repository.delete_bulk(user_ids) await self.user_repository.remove_all_roles(user_ids) @@ -135,3 +216,129 @@ async def delete_user(self, user_id: str, deleted_by: str) -> UserResponse: await self.user_repository.delete(user_id) await self.user_repository.remove_all_roles([user_id]) return user + + async def _get_auth_user_by_refresh_token( + self, refresh_token: str + ) -> AuthUserResponse: + if refresh_token is None or refresh_token == "": + raise NotFoundError("User not found") + user_session = await self.user_repository.get_user_session_by_refresh_token( + refresh_token + ) + user_id = user_session.user_id + if user_id == self.config.super_user: + return self._get_super_user() + user = await self.user_repository.get_by_id(user_id) + return self._to_auth_user_response(user) + + async def _get_auth_user_by_access_token( + self, access_token: str | None + ) -> AuthUserResponse: + if access_token is None or access_token == "": + return self._get_guest_user() + user_session = await self.user_repository.get_user_session_by_access_token( + access_token + ) + user_id = user_session.user_id + if user_id == self.config.super_user: + return self._get_super_user() + try: + user = await self.user_repository.get_by_id(user_id) + return self._to_auth_user_response(user) + except NotFoundError: + return self._get_guest_user() + + async def _get_user_by_credentials( + self, credentials: UserCredentials + ) -> AuthUserResponse: + if ( + credentials.username == self.config.super_user + and credentials.password == self.config.super_user_password + ): + return self._get_super_user() + user = await self.user_repository.get_by_credentials( + username=credentials.username, + password=credentials.password, + ) + return self._to_auth_user_response(user) + + def _to_auth_user_response(self, user_response: UserResponse) -> AuthUserResponse: + return AuthUserResponse( + **user_response.model_dump(), is_guest=False, is_super_user=False + ) + + def _get_guest_user(self): + return AuthUserResponse( + id=self.config.guest_user, + username=self.config.guest_user, + active=True, + role_names=[], + permission_names=self.config.guest_user_permissions, + is_guest=True, + is_super_user=False, + ) + + def _get_super_user(self): + return AuthUserResponse( + id=self.config.super_user, + username=self.config.super_user, + active=True, + role_names=[], + permission_names=[], + is_guest=False, + is_super_user=True, + ) + + async def _handle_excess_sessions(self, active_sessions: list[UserSessionResponse]): + """Handles excess user sessions by deleting the oldest if necessary.""" + if not self.config.prioritize_new_session: + raise ForbiddenError("No additional session allowed") + # Sort sessions by expiration and remove the oldest ones + sessions_to_delete = sorted( + active_sessions, key=lambda s: s.refresh_token_expired_at + ) + excess_count = len(active_sessions) + 1 - self.config.max_parallel_session + await self.user_repository.delete_user_sessions( + [session.id for session in sessions_to_delete[:excess_count]] + ) + + def _create_user_token_data(self, username: str) -> UserTokenData: + now = datetime.datetime.now(datetime.timezone.utc) + access_token_expire_at = now + datetime.timedelta( + minutes=self.config.access_token_expire_minutes + ) + refresh_token_expire_at = now + datetime.timedelta( + minutes=self.config.refresh_token_expire_minutes + ) + return UserTokenData( + access_token=self._generate_access_token( + username=username, + expire_at=access_token_expire_at, + ), + refresh_token=self._generate_refresh_token( + username=username, + expire_at=refresh_token_expire_at, + ), + access_token_expired_at=access_token_expire_at, + refresh_token_expired_at=refresh_token_expire_at, + ) + + def _generate_access_token( + self, username: str, expire_at: datetime.datetime + ) -> str: + return self._generate_user_token( + username=username, expire_at=expire_at, token_type="access" + ) + + def _generate_refresh_token( + self, username: str, expire_at: datetime.datetime + ) -> str: + return self._generate_user_token( + username=username, expire_at=expire_at, token_type="refresh" + ) + + def _generate_user_token( + self, username: str, expire_at: datetime.datetime, token_type: str + ) -> str: + to_encode = {"sub": username, "exp": expire_at, "type": token_type} + return jwt.encode(to_encode, self.config.secret_key) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service_factory.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service_factory.py index ea848ed1..a93198d2 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service_factory.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service_factory.py @@ -1,7 +1,35 @@ from my_app_name.common.logger_factory import logger +from my_app_name.config import ( + APP_AUTH_ACCESS_TOKEN_EXPIRE_MINUTES, + APP_AUTH_GUEST_USER, + APP_AUTH_GUEST_USER_PERMISSIONS, + APP_AUTH_MAX_PARALLEL_SESSION, + APP_AUTH_PRIORITIZE_NEW_SESSION, + APP_AUTH_REFRESH_TOKEN_EXPIRE_MINUTES, + APP_AUTH_SECRET_KEY, + APP_AUTH_SUPER_USER, + APP_AUTH_SUPER_USER_PASSWORD, +) from my_app_name.module.auth.service.user.repository.user_repository_factory import ( user_repository, ) -from my_app_name.module.auth.service.user.user_service import UserService +from my_app_name.module.auth.service.user.user_service import ( + UserService, + UserServiceConfig, +) -user_service = UserService(logger, user_repository=user_repository) +user_service = UserService( + logger, + user_repository=user_repository, + config=UserServiceConfig( + super_user=APP_AUTH_SUPER_USER, + super_user_password=APP_AUTH_SUPER_USER_PASSWORD, + guest_user=APP_AUTH_GUEST_USER, + guest_user_permissions=APP_AUTH_GUEST_USER_PERMISSIONS, + max_parallel_session=APP_AUTH_MAX_PARALLEL_SESSION, + access_token_expire_minutes=APP_AUTH_ACCESS_TOKEN_EXPIRE_MINUTES, + refresh_token_expire_minutes=APP_AUTH_REFRESH_TOKEN_EXPIRE_MINUTES, + secret_key=APP_AUTH_SECRET_KEY, + prioritize_new_session=APP_AUTH_PRIORITIZE_NEW_SESSION, + ), +) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py index 5e9b4e71..6b65371b 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py @@ -1,5 +1,9 @@ -from fastapi import FastAPI +from typing import Annotated + +from fastapi import Depends, FastAPI, Response +from fastapi.security import OAuth2PasswordRequestForm from my_app_name.module.auth.client.auth_client_factory import auth_client +from my_app_name.module.gateway.util.auth import get_current_user from my_app_name.schema.permission import ( MultiplePermissionResponse, PermissionCreate, @@ -13,19 +17,43 @@ RoleUpdateWithPermissions, ) from my_app_name.schema.user import ( + AuthUserResponse, MultipleUserResponse, UserCreateWithRoles, + UserCredentials, UserResponse, + UserSessionResponse, UserUpdateWithRoles, ) def serve_auth_route(app: FastAPI): + @app.post("/api/v1/user-sessions", response_model=UserSessionResponse) + async def create_user_session( + response: Response, form_data: Annotated[OAuth2PasswordRequestForm, Depends()] + ) -> UserSessionResponse: + user_session = await auth_client.create_user_session( + UserCredentials( + username=form_data.username, + password=form_data.password, + ) + ) + return user_session + + @app.put("/api/v1/user-sessions", response_model=UserSessionResponse) + async def update_user_session(refresh_token: str) -> UserSessionResponse: + return await auth_client.update_user_session(refresh_token) + + @app.delete("/api/v1/user-sessions", response_model=UserSessionResponse) + async def delete_user_session(refresh_token: str) -> UserSessionResponse: + return await auth_client.delete_user_session(refresh_token) + # Permission routes @app.get("/api/v1/permissions", response_model=MultiplePermissionResponse) async def get_permissions( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], page: int = 1, page_size: int = 10, sort: str | None = None, @@ -36,61 +64,91 @@ async def get_permissions( ) @app.get("/api/v1/permissions/{permission_id}", response_model=PermissionResponse) - async def get_permission_by_id(permission_id: str) -> PermissionResponse: + async def get_permission_by_id( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + permission_id: str, + ) -> PermissionResponse: return await auth_client.get_permission_by_id(permission_id) @app.post( "/api/v1/permissions/bulk", response_model=list[PermissionResponse], ) - async def create_permission_bulk(data: list[PermissionCreate]): + async def create_permission_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: list[PermissionCreate], + ) -> list[PermissionResponse]: return await auth_client.create_permission_bulk( - [row.with_audit(created_by="system") for row in data] + [row.with_audit(created_by=current_user.id) for row in data] ) @app.post( "/api/v1/permissions", response_model=PermissionResponse, ) - async def create_permission(data: PermissionCreate): - return await auth_client.create_permission(data.with_audit(created_by="system")) + async def create_permission( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: PermissionCreate, + ) -> PermissionResponse: + return await auth_client.create_permission( + data.with_audit(created_by=current_user.id) + ) @app.put( "/api/v1/permissions/bulk", response_model=list[PermissionResponse], ) - async def update_permission_bulk(permission_ids: list[str], data: PermissionUpdate): + async def update_permission_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + permission_ids: list[str], + data: PermissionUpdate, + ) -> list[PermissionResponse]: return await auth_client.update_permission_bulk( - permission_ids, data.with_audit(updated_by="system") + permission_ids, data.with_audit(updated_by=current_user.id) ) @app.put( "/api/v1/permissions/{permission_id}", response_model=PermissionResponse, ) - async def update_permission(permission_id: str, data: PermissionUpdate): - return await auth_client.update_permission(data.with_audit(updated_by="system")) + async def update_permission( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + permission_id: str, + data: PermissionUpdate, + ) -> PermissionResponse: + return await auth_client.update_permission( + data.with_audit(updated_by=current_user.id) + ) @app.delete( "/api/v1/permissions/bulk", response_model=list[PermissionResponse], ) - async def delete_permission_bulk(permission_ids: list[str]): + async def delete_permission_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + permission_ids: list[str], + ) -> list[PermissionResponse]: return await auth_client.delete_permission_bulk( - permission_ids, deleted_by="system" + permission_ids, deleted_by=current_user.id ) @app.delete( "/api/v1/permissions/{permission_id}", response_model=PermissionResponse, ) - async def delete_permission(permission_id: str): - return await auth_client.delete_permission(permission_id, deleted_by="system") + async def delete_permission( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + permission_id: str, + ) -> PermissionResponse: + return await auth_client.delete_permission( + permission_id, deleted_by=current_user.id + ) # Role routes @app.get("/api/v1/roles", response_model=MultipleRoleResponse) async def get_roles( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], page: int = 1, page_size: int = 10, sort: str | None = None, @@ -101,59 +159,87 @@ async def get_roles( ) @app.get("/api/v1/roles/{role_id}", response_model=RoleResponse) - async def get_role_by_id(role_id: str) -> RoleResponse: + async def get_role_by_id( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + role_id: str, + ) -> RoleResponse: return await auth_client.get_role_by_id(role_id) @app.post( "/api/v1/roles/bulk", response_model=list[RoleResponse], ) - async def create_role_bulk(data: list[RoleCreateWithPermissions]): + async def create_role_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: list[RoleCreateWithPermissions], + ) -> list[RoleResponse]: return await auth_client.create_role_bulk( - [row.with_audit(created_by="system") for row in data] + [row.with_audit(created_by=current_user.id) for row in data] ) @app.post( "/api/v1/roles", response_model=RoleResponse, ) - async def create_role(data: RoleCreateWithPermissions): - return await auth_client.create_role(data.with_audit(created_by="system")) + async def create_role( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: RoleCreateWithPermissions, + ) -> RoleResponse: + return await auth_client.create_role( + data.with_audit(created_by=current_user.id) + ) @app.put( "/api/v1/roles/bulk", response_model=list[RoleResponse], ) - async def update_role_bulk(role_ids: list[str], data: RoleUpdateWithPermissions): + async def update_role_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + role_ids: list[str], + data: RoleUpdateWithPermissions, + ) -> list[RoleResponse]: return await auth_client.update_role_bulk( - role_ids, data.with_audit(updated_by="system") + role_ids, data.with_audit(updated_by=current_user.id) ) @app.put( "/api/v1/roles/{role_id}", response_model=RoleResponse, ) - async def update_role(role_id: str, data: RoleUpdateWithPermissions): - return await auth_client.update_role(data.with_audit(updated_by="system")) + async def update_role( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + role_id: str, + data: RoleUpdateWithPermissions, + ) -> RoleResponse: + return await auth_client.update_role( + data.with_audit(updated_by=current_user.id) + ) @app.delete( "/api/v1/roles/bulk", response_model=list[RoleResponse], ) - async def delete_role_bulk(role_ids: list[str]): - return await auth_client.delete_role_bulk(role_ids, deleted_by="system") + async def delete_role_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + role_ids: list[str], + ) -> list[RoleResponse]: + return await auth_client.delete_role_bulk(role_ids, deleted_by=current_user.id) @app.delete( "/api/v1/roles/{role_id}", response_model=RoleResponse, ) - async def delete_role(role_id: str): - return await auth_client.delete_role(role_id, deleted_by="system") + async def delete_role( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + role_id: str, + ) -> RoleResponse: + return await auth_client.delete_role(role_id, deleted_by=current_user.id) # User routes @app.get("/api/v1/users", response_model=MultipleUserResponse) async def get_users( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], page: int = 1, page_size: int = 10, sort: str | None = None, @@ -164,51 +250,78 @@ async def get_users( ) @app.get("/api/v1/users/{user_id}", response_model=UserResponse) - async def get_user_by_id(user_id: str) -> UserResponse: + async def get_user_by_id( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + user_id: str, + ) -> UserResponse: return await auth_client.get_user_by_id(user_id) @app.post( "/api/v1/users/bulk", response_model=list[UserResponse], ) - async def create_user_bulk(data: list[UserCreateWithRoles]): - return await auth_client.create_user( - [row.with_audit(created_by="system") for row in data] + async def create_user_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: list[UserCreateWithRoles], + ) -> list[UserResponse]: + return await auth_client.create_user_bulk( + [row.with_audit(created_by=current_user.id) for row in data] ) @app.post( "/api/v1/users", response_model=UserResponse, ) - async def create_user(data: UserCreateWithRoles): - return await auth_client.create_user(data.with_audit(created_by="system")) + async def create_user( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + data: UserCreateWithRoles, + ) -> UserResponse: + return await auth_client.create_user( + data.with_audit(created_by=current_user.id) + ) @app.put( "/api/v1/users/bulk", response_model=list[UserResponse], ) - async def update_user_bulk(user_ids: list[str], data: UserUpdateWithRoles): + async def update_user_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + user_ids: list[str], + data: UserUpdateWithRoles, + ) -> list[UserResponse]: return await auth_client.update_user_bulk( - user_ids, data.with_audit(updated_by="system") + user_ids, data.with_audit(updated_by=current_user.id) ) @app.put( "/api/v1/users/{user_id}", response_model=UserResponse, ) - async def update_user(user_id: str, data: UserUpdateWithRoles): - return await auth_client.update_user(data.with_audit(updated_by="system")) + async def update_user( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + user_id: str, + data: UserUpdateWithRoles, + ) -> UserResponse: + return await auth_client.update_user( + data.with_audit(updated_by=current_user.id) + ) @app.delete( "/api/v1/users/bulk", response_model=list[UserResponse], ) - async def delete_user_bulk(user_ids: list[str]): - return await auth_client.delete_user_bulk(user_ids, deleted_by="system") + async def delete_user_bulk( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + user_ids: list[str], + ) -> list[UserResponse]: + return await auth_client.delete_user_bulk(user_ids, deleted_by=current_user.id) @app.delete( "/api/v1/users/{user_id}", response_model=UserResponse, ) - async def delete_user(user_id: str): - return await auth_client.delete_user(user_id, deleted_by="system") + async def delete_user( + current_user: Annotated[AuthUserResponse, Depends(get_current_user)], + user_id: str, + ) -> UserResponse: + return await auth_client.delete_user(user_id, deleted_by=current_user.id) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/util/auth.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/util/auth.py new file mode 100644 index 00000000..81aa37d7 --- /dev/null +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/util/auth.py @@ -0,0 +1,19 @@ +from fastapi import Depends, Request +from fastapi.security import OAuth2PasswordBearer +from my_app_name.config import APP_AUTH_ACCESS_TOKEN_COOKIE_NAME +from my_app_name.module.auth.client.auth_client_factory import auth_client +from my_app_name.schema.user import AuthUserResponse +from typing_extensions import Annotated + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/user-sessions", auto_error=False) + + +async def get_current_user( + request: Request, bearer_access_token: Annotated[str, Depends(oauth2_scheme)] +) -> AuthUserResponse: + bearer_user = await auth_client.get_current_user(bearer_access_token) + if bearer_user is None or bearer_user.is_guest: + cookie_access_token = request.cookies.get(APP_AUTH_ACCESS_TOKEN_COOKIE_NAME) + if cookie_access_token is not None and cookie_access_token != "": + return await auth_client.get_current_user(cookie_access_token) + return bearer_user diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/requirements.txt b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/requirements.txt index 64202ae4..7405d5f4 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/requirements.txt +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/requirements.txt @@ -3,4 +3,5 @@ alembic~=1.14.0 sqlmodel~=0.0.22 ulid-py~=1.1.0 passlib~=1.7.4 -Jinja2==3.1.5 +Jinja2~=3.1.5 +python-jose~=3.3.0 \ No newline at end of file diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py index f4fb986f..3654fd00 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py @@ -42,6 +42,7 @@ class MultiplePermissionResponse(BaseModel): class Permission(SQLModel, table=True): + __tablename__ = "permissions" id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) created_at: datetime.datetime | None = Field(index=True) created_by: str | None = Field(index=True) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py index 7dbe2d27..f8d3e51a 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py @@ -1,7 +1,6 @@ import datetime import ulid -from my_app_name.schema.permission import Permission from pydantic import BaseModel from sqlmodel import Field, SQLModel @@ -22,7 +21,7 @@ class RoleCreateWithAudit(RoleCreate): class RoleCreateWithPermissions(RoleCreate): - permission_ids: list[str] | None = None + permission_names: list[str] | None = None def with_audit(self, created_by: str) -> "RoleCreateWithPermissionsAndAudit": return RoleCreateWithPermissionsAndAudit( @@ -37,14 +36,14 @@ def get_role_create_with_audit(self) -> RoleCreateWithAudit: data = { key: val for key, val in self.model_dump().items() - if key != "permission_ids" + if key != "permission_names" } return RoleCreateWithAudit(**data) - def get_permission_ids(self) -> list[str]: - if self.permission_ids is None: + def get_permission_names(self) -> list[str]: + if self.permission_names is None: return [] - return self.permission_ids + return self.permission_names class RoleUpdate(SQLModel): @@ -60,7 +59,7 @@ class RoleUpdateWithAudit(RoleUpdate): class RoleUpdateWithPermissions(RoleUpdate): - permission_ids: list[str] | None = None + permission_names: list[str] | None = None def with_audit(self, updated_by: str) -> "RoleUpdateWithPermissionsAndAudit": return RoleUpdateWithPermissionsAndAudit( @@ -75,19 +74,19 @@ def get_role_update_with_audit(self) -> RoleUpdateWithAudit: data = { key: val for key, val in self.model_dump().items() - if key != "permission_ids" + if key != "permission_names" } return RoleUpdateWithAudit(**data) - def get_permission_ids(self) -> list[str]: - if self.permission_ids is None: + def get_permission_names(self) -> list[str]: + if self.permission_names is None: return [] - return self.permission_ids + return self.permission_names class RoleResponse(RoleBase): id: str - permissions: list[Permission] + permission_names: list[str] class MultipleRoleResponse(BaseModel): @@ -96,6 +95,7 @@ class MultipleRoleResponse(BaseModel): class Role(SQLModel, table=True): + __tablename__ = "roles" id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) created_at: datetime.datetime | None = Field(index=True) created_by: str | None = Field(index=True) @@ -106,6 +106,7 @@ class Role(SQLModel, table=True): class RolePermission(SQLModel, table=True): + __tablename__ = "role_permissions" id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) role_id: str = Field(index=True) permission_id: str = Field(index=True) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py deleted file mode 100644 index 6aa6f37b..00000000 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py +++ /dev/null @@ -1,48 +0,0 @@ -import datetime - -import ulid -from pydantic import BaseModel -from sqlmodel import Field, SQLModel - - -class SessionBase(SQLModel): - user_id: str - token: str - expired_at: datetime.datetime | None - - -class SessionCreate(SessionBase): - - def with_audit(self, created_by: str) -> "SessionCreateWithAudit": - return SessionCreateWithAudit(**self.model_dump(), created_by=created_by) - - -class SessionCreateWithAudit(SessionCreate): - created_by: str - - -class SessionUpdate(SessionBase): - - def with_audit(self, updated_by: str) -> "SessionUpdateWithAudit": - return SessionUpdateWithAudit(**self.model_dump(), updated_by=updated_by) - - -class SessionUpdateWithAudit(SessionUpdate): - updated_by: str - - -class SessionResponse(SessionBase): - id: str - created_at: datetime.datetime = Field(index=True) - updated_at: datetime.datetime | None = Field(index=True) - - -class MultipleSessionResponse(BaseModel): - data: list[SessionResponse] - count: int - - -class Session(SQLModel, table=True): - id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) - user_id: str = Field(index=True) - token: str = Field(index=True, unique=True) diff --git a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py index b5482945..8fc7cf5b 100644 --- a/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py +++ b/src/zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py @@ -9,6 +9,7 @@ class UserBase(SQLModel): username: str + active: bool class UserCreate(UserBase): @@ -23,7 +24,7 @@ class UserCreateWithAudit(UserCreate): class UserCreateWithRoles(UserCreate): - role_ids: list[str] | None = None + role_names: list[str] | None = None def with_audit(self, created_by: str) -> "UserCreateWithRolesAndAudit": return UserCreateWithRolesAndAudit(**self.model_dump(), created_by=created_by) @@ -33,18 +34,21 @@ class UserCreateWithRolesAndAudit(UserCreateWithRoles): created_by: str def get_user_create_with_audit(self) -> UserCreateWithAudit: - data = {key: val for key, val in self.model_dump().items() if key != "role_ids"} + data = { + key: val for key, val in self.model_dump().items() if key != "role_names" + } return UserCreateWithAudit(**data) - def get_role_ids(self) -> list[str]: - if self.role_ids is None: + def get_role_names(self) -> list[str]: + if self.role_names is None: return [] - return self.role_ids + return self.role_names class UserUpdate(SQLModel): username: str | None = None password: str | None = None + active: bool | None = None def with_audit(self, updated_by: str) -> "UserUpdateWithAudit": return UserUpdateWithAudit(**self.model_dump(), updated_by=updated_by) @@ -55,7 +59,7 @@ class UserUpdateWithAudit(UserUpdate): class UserUpdateWithRoles(UserUpdate): - role_ids: list[str] | None = None + role_names: list[str] | None = None def with_audit(self, updated_by: str) -> "UserUpdateWithRolesAndAudit": return UserUpdateWithRolesAndAudit(**self.model_dump(), updated_by=updated_by) @@ -65,19 +69,32 @@ class UserUpdateWithRolesAndAudit(UserUpdateWithRoles): updated_by: str def get_user_update_with_audit(self) -> UserUpdateWithAudit: - data = {key: val for key, val in self.model_dump().items() if key != "role_ids"} + data = { + key: val for key, val in self.model_dump().items() if key != "role_names" + } return UserUpdateWithAudit(**data) - def get_role_ids(self) -> list[str]: - if self.role_ids is None: + def get_role_names(self) -> list[str]: + if self.role_names is None: return [] - return self.role_ids + return self.role_names class UserResponse(UserBase): id: str - roles: list[Role] - permissions: list[Permission] + role_names: list[str] + permission_names: list[str] + + +class AuthUserResponse(UserResponse): + is_super_user: bool + is_guest: bool + + def has_permission(self, permission_name: str): + return self.is_super_user or permission_name in self.permission_names + + def has_role(self, role_name: str): + return self.is_super_user or role_name in self.role_names class MultipleUserResponse(BaseModel): @@ -85,7 +102,27 @@ class MultipleUserResponse(BaseModel): count: int +class UserCredentials(SQLModel): + username: str + password: str + + +class UserTokenData(SQLModel): + access_token: str + refresh_token: str + access_token_expired_at: datetime.datetime + refresh_token_expired_at: datetime.datetime + + +class UserSessionResponse(SQLModel): + id: str + user_id: str + access_token_expired_at: datetime.datetime + refresh_token_expired_at: datetime.datetime + + class User(SQLModel, table=True): + __tablename__ = "users" id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) created_at: datetime.datetime = Field(index=True) created_by: str = Field(index=True) @@ -93,11 +130,23 @@ class User(SQLModel, table=True): updated_by: str | None = Field(index=True) username: str = Field(index=True, unique=True) password: str + active: bool = Field(index=True) class UserRole(SQLModel, table=True): + __tablename__ = "user_roles" id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) user_id: str = Field(index=True) role_id: str = Field(index=True) created_at: datetime.datetime | None created_by: str | None + + +class UserSession(SQLModel, table=True): + __tablename__ = "user_sessions" + id: str = Field(default_factory=lambda: ulid.new().str, primary_key=True) + user_id: str = Field(index=True) + access_token: str = Field(index=True) + refresh_token: str = Field(index=True) + access_token_expired_at: datetime.datetime = Field(index=True) + refresh_token_expired_at: datetime.datetime = Field(index=True) diff --git a/src/zrb/builtin/setup/zsh/zsh.py b/src/zrb/builtin/setup/zsh/zsh.py index 5ece8d24..5d8fc0d0 100644 --- a/src/zrb/builtin/setup/zsh/zsh.py +++ b/src/zrb/builtin/setup/zsh/zsh.py @@ -53,4 +53,3 @@ def setup_zsh(ctx: AnyContext): # Write config write_file(zsh_config_file, [content, zsh_config, ""]) ctx.print("Setup complete, restart your terminal to continue") - diff --git a/src/zrb/task/base_task.py b/src/zrb/task/base_task.py index 26b3eaf9..08d3417b 100644 --- a/src/zrb/task/base_task.py +++ b/src/zrb/task/base_task.py @@ -295,16 +295,16 @@ def __fill_shared_context_envs(self, shared_context: AnySharedContext): async def exec_root_tasks(self, session: AnySession): session.set_main_task(self) session.state_logger.write(session.as_state_log()) - log_state = asyncio.create_task(self._log_session_state(session)) - root_tasks = [ - task - for task in session.get_root_tasks(self) - if session.is_allowed_to_run(task) - ] - root_task_coros = [ - run_async(root_task.exec_chain(session)) for root_task in root_tasks - ] try: + log_state = asyncio.create_task(self._log_session_state(session)) + root_tasks = [ + task + for task in session.get_root_tasks(self) + if session.is_allowed_to_run(task) + ] + root_task_coros = [ + run_async(root_task.exec_chain(session)) for root_task in root_tasks + ] await asyncio.gather(*root_task_coros) await session.wait_deferred() session.terminate() diff --git a/src/zrb/task/cmd_task.py b/src/zrb/task/cmd_task.py index 2c4f85b2..3e73d2d5 100644 --- a/src/zrb/task/cmd_task.py +++ b/src/zrb/task/cmd_task.py @@ -126,7 +126,7 @@ async def _exec_action(self, ctx: AnyContext) -> CmdResult: if self._get_should_warn_unrecommended_commands(): self._check_unrecommended_commands(ctx, shell, cmd_script) ctx.log_info("Running script") - log_method = ( + print_method = ( partial(ctx.print, plain=True) if self._should_plain_print else ctx.print ) xcom_pid_key = f"{self.name}-pid" @@ -135,17 +135,14 @@ async def _exec_action(self, ctx: AnyContext) -> CmdResult: cmd=[shell, shell_flag, cmd_script], cwd=cwd, env_map=env_map, - print_method=log_method, + print_method=print_method, register_pid_method=lambda pid: ctx.xcom.get(xcom_pid_key).push(pid), max_output_line=self._max_output_line, max_error_line=self._max_error_line, ) # Check for errors - if return_code != 0: - ctx.log_error(f"Exit status: {return_code}") - raise Exception( - f"Process {self._name} exited ({return_code}): {cmd_result.error}" - ) + if return_code > 0: + raise Exception(f"Process {self._name} exited ({return_code})") ctx.log_info(f"Exit status: {return_code}") return cmd_result diff --git a/src/zrb/util/cmd/command.py b/src/zrb/util/cmd/command.py index 9d8fe25d..252e24a6 100644 --- a/src/zrb/util/cmd/command.py +++ b/src/zrb/util/cmd/command.py @@ -1,26 +1,12 @@ import asyncio import os import re -import signal import sys from collections.abc import Callable -from zrb.cmd.cmd_result import CmdResult - -_RUNNING_PROCESSES = [] - - -def _cleanup_processes(): - """Terminate all running subprocesses on exit.""" - for process in _RUNNING_PROCESSES: - if process.returncode is None: - process.terminate() - process.kill() - +import psutil -# Register a single cleanup function once -signal.signal(signal.SIGINT, lambda sig, frame: _cleanup_processes()) -signal.signal(signal.SIGTERM, lambda sig, frame: _cleanup_processes()) +from zrb.cmd.cmd_result import CmdResult def check_unrecommended_commands(cmd_script: str) -> dict[str, str]: @@ -65,7 +51,7 @@ async def run_command( max_error_line: int = 1000, ) -> tuple[CmdResult, int]: async def __read_stream( - stream, log_method: Callable[..., None], max_lines: int + stream, print_method: Callable[..., None], max_lines: int ) -> str: lines = [] while True: @@ -76,40 +62,45 @@ async def __read_stream( lines.append(line) if len(lines) > max_lines: lines.pop(0) # Keep only the last max_lines - log_method(line) + print_method(line) return "\n".join(lines) actual_print_method = print_method if print_method is not None else print cmd_process = None - try: - if cwd is None: - cwd = os.getcwd() - if env_map is None: - env_map = os.environ - cmd_process = await asyncio.create_subprocess_exec( - *cmd, - cwd=cwd, - stdin=sys.stdin if sys.stdin.isatty() else None, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=env_map, - bufsize=0, - ) - if register_pid_method is not None: - register_pid_method(cmd_process.pid) - stdout_task = asyncio.create_task( - __read_stream(cmd_process.stdout, actual_print_method, max_output_line) - ) - stderr_task = asyncio.create_task( - __read_stream(cmd_process.stderr, actual_print_method, max_error_line) - ) - # Wait for process to complete and gather stdout/stderr - return_code = await cmd_process.wait() - stdout = await stdout_task - stderr = await stderr_task - return CmdResult(stdout, stderr), return_code - finally: - if cmd_process in _RUNNING_PROCESSES: - _RUNNING_PROCESSES.remove(cmd_process) - if cmd_process is not None and cmd_process.returncode is None: - cmd_process.terminate() + if cwd is None: + cwd = os.getcwd() + if env_map is None: + env_map = os.environ + cmd_process = await asyncio.create_subprocess_exec( + *cmd, + cwd=cwd, + stdin=sys.stdin if sys.stdin.isatty() else None, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env_map, + bufsize=0, + ) + if register_pid_method is not None: + register_pid_method(cmd_process.pid) + stdout_task = asyncio.create_task( + __read_stream(cmd_process.stdout, actual_print_method, max_output_line) + ) + stderr_task = asyncio.create_task( + __read_stream(cmd_process.stderr, actual_print_method, max_error_line) + ) + # Wait for process to complete and gather stdout/stderr + return_code = await cmd_process.wait() + stdout = await stdout_task + stderr = await stderr_task + return CmdResult(stdout, stderr), return_code + + +def kill_pid(pid: int, print_method: Callable[..., None] | None = None): + actual_print_method = print_method if print_method is not None else print + parent = psutil.Process(pid) + children = parent.children(recursive=True) + for child in children: + actual_print_method(f"Killing child process {child.pid}") + child.terminate() + actual_print_method(f"Killing process {pid}") + parent.terminate() diff --git a/src/zrb/util/git.py b/src/zrb/util/git.py index fad9f7cc..109751dd 100644 --- a/src/zrb/util/git.py +++ b/src/zrb/util/git.py @@ -17,12 +17,12 @@ async def get_diff( repo_dir: str, source_commit: str, current_commit: str, - log_method: Callable[..., Any] = print, + print_method: Callable[..., Any] = print, ) -> DiffResult: cmd_result, exit_code = await run_command( cmd=["git", "diff", source_commit, current_commit], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") @@ -54,10 +54,10 @@ async def get_diff( ) -async def get_repo_dir(log_method: Callable[..., Any] = print) -> str: +async def get_repo_dir(print_method: Callable[..., Any] = print) -> str: cmd_result, exit_code = await run_command( cmd=["git", "rev-parse", "--show-toplevel"], - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") @@ -65,12 +65,12 @@ async def get_repo_dir(log_method: Callable[..., Any] = print) -> str: async def get_current_branch( - repo_dir: str, log_method: Callable[..., Any] = print + repo_dir: str, print_method: Callable[..., Any] = print ) -> str: cmd_result, exit_code = await run_command( cmd=["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") @@ -78,12 +78,12 @@ async def get_current_branch( async def get_branches( - repo_dir: str, log_method: Callable[..., Any] = print + repo_dir: str, print_method: Callable[..., Any] = print ) -> list[str]: cmd_result, exit_code = await run_command( cmd=["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") @@ -93,35 +93,35 @@ async def get_branches( async def delete_branch( - repo_dir: str, branch_name: str, log_method: Callable[..., Any] = print + repo_dir: str, branch_name: str, print_method: Callable[..., Any] = print ) -> str: cmd_result, exit_code = await run_command( cmd=["git", "branch", "-D", branch_name], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") return cmd_result.output.strip() -async def add(repo_dir: str, log_method: Callable[..., Any] = print): +async def add(repo_dir: str, print_method: Callable[..., Any] = print): _, exit_code = await run_command( cmd=["git", "add", ".", "-A"], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") async def commit( - repo_dir: str, message: str, log_method: Callable[..., Any] = print + repo_dir: str, message: str, print_method: Callable[..., Any] = print ) -> str: cmd_result, exit_code = await run_command( cmd=["git", "commit", "-m", message], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: ignored_error_message = "nothing to commit, working tree clean" @@ -133,24 +133,24 @@ async def commit( async def pull( - repo_dir: str, remote: str, branch: str, log_method: Callable[..., Any] = print + repo_dir: str, remote: str, branch: str, print_method: Callable[..., Any] = print ) -> str: _, exit_code = await run_command( cmd=["git", "pull", remote, branch], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") async def push( - repo_dir: str, remote: str, branch: str, log_method: Callable[..., Any] = print + repo_dir: str, remote: str, branch: str, print_method: Callable[..., Any] = print ) -> str: _, exit_code = await run_command( cmd=["git", "push", "-u", remote, branch], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") diff --git a/src/zrb/util/git_subtree.py b/src/zrb/util/git_subtree.py index 14c1ec8d..d4973689 100644 --- a/src/zrb/util/git_subtree.py +++ b/src/zrb/util/git_subtree.py @@ -36,7 +36,7 @@ async def add_subtree( repo_url: str, branch: str, prefix: str, - log_method: Callable[..., Any] = print, + print_method: Callable[..., Any] = print, ): config = load_config() if os.path.isdir(prefix): @@ -54,7 +54,7 @@ async def add_subtree( branch, ], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") @@ -69,7 +69,7 @@ async def pull_subtree( prefix: str, repo_url: str, branch: str, - log_method: Callable[..., Any] = print, + print_method: Callable[..., Any] = print, ): _, exit_code = await run_command( cmd=[ @@ -82,7 +82,7 @@ async def pull_subtree( branch, ], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") @@ -93,7 +93,7 @@ async def push_subtree( prefix: str, repo_url: str, branch: str, - log_method: Callable[..., Any] = print, + print_method: Callable[..., Any] = print, ): _, exit_code = await run_command( cmd=[ @@ -106,7 +106,7 @@ async def push_subtree( branch, ], cwd=repo_dir, - print_method=log_method, + print_method=print_method, ) if exit_code != 0: raise Exception(f"Non zero exit code: {exit_code}") diff --git a/zrb_init.py b/zrb_init.py index e60c4d9a..46fef3d3 100644 --- a/zrb_init.py +++ b/zrb_init.py @@ -2,11 +2,11 @@ import json import os import shutil -import signal import traceback from functools import partial from typing import Any +import psutil import requests import tomlkit @@ -26,7 +26,8 @@ ) from zrb.builtin.git import git_commit from zrb.config import DEFAULT_SHELL -from zrb.util.cmd.command import run_command +from zrb.util.cli.style import stylize_green, stylize_magenta +from zrb.util.cmd.command import kill_pid, run_command from zrb.util.file import read_file from zrb.util.load import load_file @@ -189,7 +190,7 @@ async def remove_generated(ctx: AnyContext): name="test-generate", description="๐Ÿช„ Generate app", group=test_generator_group, - alias="generate", + alias="scaffold", retries=0, ) async def test_generate(ctx: AnyContext): @@ -218,14 +219,18 @@ async def test_generate(ctx: AnyContext): ctx.print("Generate entity") await _run_shell_script( ctx, - "zrb project fastapp create entity --module library --entity book --plural books --column title", # noqa + "zrb project fastapp create entity --module library --entity book --plural books --column isbn", # noqa ) assert os.path.isfile(os.path.join(app_dir_path, "schema", "book.py")) # Create column ctx.print("Generate column") await _run_shell_script( ctx, - "zrb project fastapp create column --module library --entity book --column isbn --type str", # noqa + "zrb project fastapp create column --module library --entity book --column title --type str", # noqa + ) + await _run_shell_script( + ctx, + "zrb project fastapp create column --module library --entity book --column author --type str", # noqa ) # Create migration ctx.print("Generate migration") @@ -234,54 +239,56 @@ async def test_generate(ctx: AnyContext): ) -run_generated_fastapp = CmdTask( - name="run-generated-app", - description="๐Ÿƒ Run generated app", - readiness_check=[ - HttpCheck(name="check-monolith", url="http://localhost:3000/readiness"), - HttpCheck(name="check-gateway", url="http://localhost:3001/readiness"), - HttpCheck(name="check-auth-svc", url="http://localhost:3002/readiness"), - HttpCheck(name="check-lib-svc", url="http://localhost:3003/readiness"), - ], - cmd="zrb project fastapp run all", - plain_print=True, - retries=0, +run_generated_fastapp = test_generator_group.add_task( + CmdTask( + name="run-generated-app", + description="๐Ÿƒ Run generated app", + readiness_check=[ + HttpCheck(name="check-monolith", url="http://localhost:3000/readiness"), + HttpCheck(name="check-gateway", url="http://localhost:3001/readiness"), + HttpCheck(name="check-auth-svc", url="http://localhost:3002/readiness"), + HttpCheck(name="check-lib-svc", url="http://localhost:3003/readiness"), + ], + cmd="zrb project fastapp run all --env prod", + plain_print=True, + retries=0, + ), + alias="launch", ) -test_generator_group.add_task(run_generated_fastapp, alias="run") @make_task( name="test-generated-app", description="๐Ÿงช Test generated app", group=test_generator_group, - alias="eval", + alias="validate", retries=0, ) async def test_generated_fastapp(ctx: AnyContext) -> str: try: await asyncio.sleep(2) - ctx.print("Test fastapp monolith") + ctx.print(stylize_magenta("Test fastapp monolith")) await _test_fastapp_permission_api(ctx, "http://localhost:3000") await _test_fastapp_book_api(ctx, "http://localhost:3000") - ctx.print("Test fastapp gateway") + ctx.print(stylize_magenta("Test fastapp gateway")) await _test_fastapp_permission_api(ctx, "http://localhost:3001") await _test_fastapp_book_api(ctx, "http://localhost:3001") - print("\a") return "Test succeed, here have a beer ๐Ÿบ" finally: app_pid_xcom: Xcom = ctx.xcom.get("run-generated-app-pid") app_pid = app_pid_xcom.pop() - ctx.print(f"Killing process {app_pid}") - os.kill(app_pid, signal.SIGTERM) + kill_pid(app_pid, print_method=ctx.print) + print("\a") remove_generated >> test_generate >> run_generated_fastapp >> test_generated_fastapp async def _test_fastapp_permission_api(ctx: AnyContext, base_url: str): - ctx.print("Test creating permission") + ctx.print(stylize_green("Test creating permission")) url = f"{base_url}/api/v1/permissions" json_data = json.dumps({"name": "admin", "description": "Can do everything"}) + ctx.print(url, json_data) response = requests.post( url, data=json_data, headers={"Content-Type": "application/json"} ) @@ -292,26 +299,187 @@ async def _test_fastapp_permission_api(ctx: AnyContext, base_url: str): async def _test_fastapp_book_api(ctx: AnyContext, base_url: str): - ctx.print("Test creating books") + ids = await _test_fastapp_book_api_bulk_insert(ctx, base_url) + await _test_fastapp_book_api_bulk_update(ctx, base_url, ids) + bleach_id = await _test_fastapp_book_api_get_with_filter(ctx, base_url) + await _test_fastapp_book_api_update(ctx, base_url, id=bleach_id) + naruto_id = await _test_fastapp_book_api_insert(ctx, base_url) + await _test_fastapp_book_api_get_by_id(ctx, base_url, id=naruto_id) + await _test_fastapp_book_api_get_without_filter(ctx, base_url) + await _test_fastapp_book_api_delete(ctx, base_url, naruto_id) + await _test_fastapp_book_api_bulk_delete(ctx, base_url, [bleach_id]) + + +async def _test_fastapp_book_api_delete(ctx: AnyContext, base_url: str, id: str): + ctx.print(stylize_green("Test delete book by id")) + url = f"{base_url}/api/v1/books/{id}" + ctx.print(url) + response = requests.delete(url, headers={"Content-Type": "application/json"}) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + + +async def _test_fastapp_book_api_bulk_delete( + ctx: AnyContext, base_url: str, ids: list[str] +): + ctx.print(stylize_green("Test delete books")) + url = f"{base_url}/api/v1/books/bulk" + json_data = json.dumps(ids) + ctx.print(url, json_data) + response = requests.delete( + url, data=json_data, headers={"Content-Type": "application/json"} + ) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + + +async def _test_fastapp_book_api_get_by_id( + ctx: AnyContext, base_url: str, id: str +) -> str: + ctx.print(stylize_green("Test get book by id")) + url = f"{base_url}/api/v1/books/{id}" + ctx.print(url) + response = requests.get(url, headers={"Content-Type": "application/json"}) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + book = response.json() + assert book.get("title") == "Naruto" + assert book.get("author") == "Masashi Kishimoto" + assert book.get("isbn") == "978-0-1490-0625-5" + + +async def _test_fastapp_book_api_get_without_filter( + ctx: AnyContext, base_url: str +) -> str: + ctx.print(stylize_green("Test get books without filter")) + url = f"{base_url}/api/v1/books" + ctx.print(url) + response = requests.get(url, headers={"Content-Type": "application/json"}) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + response_json = response.json() + count = response_json.get("count") + data = response_json.get("data") + assert count == 5 + assert len(data) == 5 + + +async def _test_fastapp_book_api_get_with_filter(ctx: AnyContext, base_url: str) -> str: + ctx.print(stylize_green("Test get books with filter")) + url = f"{base_url}/api/v1/books?filter=title:eq:Bleach" + ctx.print(url) + response = requests.get(url, headers={"Content-Type": "application/json"}) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + response_json = response.json() + count = response_json.get("count") + data = response_json.get("data") + assert count == 1 + book = data[0] + assert book.get("title") == "Bleach" + assert book.get("author") == "Fujiko F. Fujio" + assert book.get("isbn") == "978-8-1744-3763-1" + return book.get("id") + + +async def _test_fastapp_book_api_bulk_insert( + ctx: AnyContext, base_url: str +) -> list[str]: + ctx.print(stylize_green("Test creating books")) url = f"{base_url}/api/v1/books/bulk" json_data = json.dumps( [ - {"title": "Doraemon"}, - {"title": "P Man"}, - {"title": "Kobochan"}, + {"title": "Doraemon", "isbn": "978-6-6625-3489-3", "author": "anonymous"}, + {"title": "P Man", "isbn": "978-0-9259-2124-6", "author": "anonymous"}, + {"title": "Kobochan", "isbn": "978-8-8818-0448-1", "author": "anonymous"}, + {"title": "Bleach", "isbn": "978-8-1744-3763-1", "author": "anonymous"}, ] ) + ctx.print(url, json_data) response = requests.post( url, data=json_data, headers={"Content-Type": "application/json"} ) ctx.print(response.status_code, response.text) assert response.status_code == 200 response_json = response.json() - assert len(response_json) == 3 + assert len(response_json) == 4 + isbns = [row.get("isbn") for row in response_json] + assert "978-6-6625-3489-3" in isbns + assert "978-0-9259-2124-6" in isbns + assert "978-8-8818-0448-1" in isbns + assert "978-8-1744-3763-1" in isbns + titles = [row.get("title") for row in response_json] + assert "Doraemon" in titles + assert "P Man" in titles + assert "Kobochan" in titles + assert "Bleach" in titles + authors = [row.get("author") for row in response_json] + assert "anonymous" in authors + return [row.get("id") for row in response_json] + + +async def _test_fastapp_book_api_bulk_update( + ctx: AnyContext, base_url: str, ids: list[str] +): + ctx.print(stylize_green("Test update books")) + url = f"{base_url}/api/v1/books/bulk" + json_data = json.dumps({"book_ids": ids, "data": {"author": "Fujiko F. Fujio"}}) + ctx.print(url, json_data) + response = requests.put( + url, data=json_data, headers={"Content-Type": "application/json"} + ) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + response_json = response.json() + assert len(response_json) == 4 + isbns = [row.get("isbn") for row in response_json] + assert "978-6-6625-3489-3" in isbns + assert "978-0-9259-2124-6" in isbns + assert "978-8-8818-0448-1" in isbns + assert "978-8-1744-3763-1" in isbns titles = [row.get("title") for row in response_json] assert "Doraemon" in titles assert "P Man" in titles assert "Kobochan" in titles + assert "Bleach" in titles + authors = [row.get("author") for row in response_json] + assert "anonymous" not in authors + assert "Fujiko F. Fujio" in authors + + +async def _test_fastapp_book_api_insert(ctx: AnyContext, base_url: str) -> str: + ctx.print(stylize_green("Test creating a book")) + url = f"{base_url}/api/v1/books" + json_data = json.dumps( + {"title": "Naruto", "isbn": "978-0-1490-0625-5", "author": "Masashi Kishimoto"}, + ) + ctx.print(url, json_data) + response = requests.post( + url, data=json_data, headers={"Content-Type": "application/json"} + ) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + response_json = response.json() + assert response_json.get("title") == "Naruto" + assert response_json.get("isbn") == "978-0-1490-0625-5" + assert response_json.get("author") == "Masashi Kishimoto" + return response_json.get("id") + + +async def _test_fastapp_book_api_update(ctx: AnyContext, base_url: str, id: str): + ctx.print(stylize_green("Test update book")) + url = f"{base_url}/api/v1/books/{id}" + json_data = json.dumps({"author": "Tite Kubo"}) + ctx.print(url, json_data) + response = requests.put( + url, data=json_data, headers={"Content-Type": "application/json"} + ) + ctx.print(response.status_code, response.text) + assert response.status_code == 200 + response_json = response.json() + assert response_json.get("title") == "Bleach" + assert response_json.get("isbn") == "978-8-1744-3763-1" + assert response_json.get("author") == "Tite Kubo" # PLAYGROUND ==================================================================