Skip to content

Commit

Permalink
Add invoke tasks to a collection
Browse files Browse the repository at this point in the history
Previously, the invoke tasks were defined in the global namespace.
Now we add them to a collection in order to facilitate the overloading
of tasks by the user plugins and to better organize them.

ASIM-5375
  • Loading branch information
ThalesCarl committed Jan 30, 2024
1 parent 2068b66 commit 78063d9
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/source/plugins/03_plugin_structure.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ Below there is detailed list on what they are meant to do. However, in general,
:prog: invoke

You can check their implementations `here <https://github.com/ESSS/alfasim-sdk/blob/master/src/alfasim_sdk/default_tasks.py>`_
and you can also overwrite them by just defining a function with the same name of the default task with the ``@task`` decorator.
and you can also overwrite them by just defining a function with the same name of the default task with the ``@sdk_task`` decorator.

For instance, if you want to overwrite the ``clean`` task, define the following inside your ``tasks.py`` in the root of your plugin.

.. code-block:: python
from invoke import task
@task
@sdk_task
def clean():
print("Overwriting the clean task")
67 changes: 58 additions & 9 deletions src/alfasim_sdk/default_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,49 @@
import subprocess
import sys
from pathlib import Path
from typing import Any
from typing import Union
from zipfile import ZipFile

from colorama import Fore
from colorama import Style
from hookman.hookman_generator import HookManGenerator
from invoke import Collection
from invoke import Exit
from invoke import Task
from invoke import task
from strictyaml.ruamel import YAML

sdk_ns = Collection()


def sdk_task(*args: object, **kwargs: object) -> Any:
"""
Similar to the native @task decorator, but also registers
the task in the global ``sdk`` namespace.
"""
if len(args) == 1 and not kwargs:
# Direct decoration:
# @task
# def cog(...)
fn = args[0]
assert callable(fn)
t = task(fn)
# error: Argument 1 to "add_task" of "Collection" has incompatible type "Callable[..., Any]"; expected "Task[Any]" [arg-type]
sdk_ns.add_task(t) # type:ignore[arg-type]
return t
else:
# Indirect decoration:
# @ns_task(help="...")
# def codegen(...):
def inner(fn) -> Task:
assert callable(fn)
t = task(*args, **kwargs)(fn)
sdk_ns.add_task(t)
return t

return inner


def print_message(
message: str,
Expand Down Expand Up @@ -61,7 +94,7 @@ def get_msvc_cmake_generator(msvc_compiler: str) -> str:
# =============================================================
# ========================= Tasks ===========================
# =============================================================
@task(
@sdk_task(
help={
"cmake_extra_args": "Extra arguments that will be passed to cmake",
"debug": "Compile in debug mode",
Expand Down Expand Up @@ -170,7 +203,7 @@ def _get_hook_specs_file_path() -> Path:
return Path(alfasim_sdk._internal.hook_specs.__file__)


@task(
@sdk_task(
help={
"package-name": "Name of the package. If empty, the package-name will be assumed to be the pluginid",
"dst": "A path to where the output package should be created.",
Expand Down Expand Up @@ -212,7 +245,7 @@ def package_only(ctx, package_name="", dst=os.getcwd()):
)


@task(
@sdk_task(
help={
"package-name": "Name of the package. If empty, the package-name will be assumed to be the plugin_id",
"dst": "A path to where the output package should be created.",
Expand All @@ -237,7 +270,7 @@ def package(
package_only(ctx, package_name=package_name, dst=dst)


@task()
@sdk_task()
def update(ctx):
"""
Update plugin files automatically generated by ALFAsim-SDK.
Expand All @@ -255,10 +288,26 @@ def update(ctx):
plugin_folder = Path(ctx.config._project_prefix)
plugin_id = plugin_folder.name
plugin_location = plugin_folder.parent

# Delete previously generated file, if present
generated_hook_file = plugin_folder / "src" / "hook_specs.h"
if generated_hook_file.is_file():
generated_hook_file.unlink()

# Generate updated hook specs file
hm.generate_hook_specs_header(plugin_id, plugin_location)
if generated_hook_file.is_file():
print_message(
"Successfully updated alfasim-sdk's files", color=Fore.GREEN, bright=True
)
else: # pragma: no cover (not reachable using mock)
print_message(
"Failed to update alfasim-sdk's files", color=Fore.RED, bright=True
)
raise Exit(message=None, code=1) # `code != 0`


@task()
@sdk_task()
def install_plugin(ctx, install_dir=None):
r"""
Install a plugin to ``install_dir`` folder.
Expand Down Expand Up @@ -316,7 +365,7 @@ def install_plugin(ctx, install_dir=None):
)


@task()
@sdk_task()
def uninstall_plugin(ctx, install_dir=None):
r"""
Remove plugin from install folder.
Expand Down Expand Up @@ -354,7 +403,7 @@ def uninstall_plugin(ctx, install_dir=None):
)


@task()
@sdk_task()
def reinstall_plugin(ctx, package_name, install_dir=None):
r"""
Remove, package and install an specified plugin to install_dir
Expand All @@ -372,7 +421,7 @@ def reinstall_plugin(ctx, package_name, install_dir=None):
install_plugin(ctx, install_dir=install_dir)


@task
@sdk_task
def clean(ctx):
"""Remove all build folders and .hmplugin files from plugin root and children folders"""

Expand All @@ -381,7 +430,7 @@ def clean(ctx):
_remove_hmplugin_files(plugin_folder)


@task
@sdk_task
def msvc(ctx):
"""Create a MSVC solution file for the plugin."""
if sys.platform != "win32":
Expand Down
3 changes: 2 additions & 1 deletion tests/test_invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def test_update_task(new_plugin_dir: Path, monkeypatch: MonkeyPatch):
[f"{invoke_cmd}", "update"],
capture_output=True,
)
assert result.stdout.decode("utf-8") == ""
success_message_chunk = "Successfully updated alfasim-sdk's files"
assert success_message_chunk in result.stdout.decode("utf-8")
assert result.stderr.decode("utf-8") == ""
assert plugin_hook_spec_h_path.stat().st_size > 0
assert plugin_hook_spec_h_path.read_text(encoding="utf-8") != ""
Expand Down

0 comments on commit 78063d9

Please sign in to comment.