Skip to content

Commit

Permalink
Add placeholder for int. tests for mila serve
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Mar 26, 2024
1 parent 3ba4b40 commit eed46e7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
8 changes: 4 additions & 4 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def mila():
help="Path to open on the remote machine",
)
_add_standard_server_args(serve_notebook_parser)
serve_notebook_parser.set_defaults(function=notebook)
serve_notebook_parser.set_defaults(function=serve_notebook)

# ----- mila serve tensorboard ------

Expand All @@ -321,7 +321,7 @@ def mila():
"LOGDIR", type=str, help="Path to the experiment logs"
)
_add_standard_server_args(serve_tensorboard_parser)
serve_tensorboard_parser.set_defaults(function=tensorboard)
serve_tensorboard_parser.set_defaults(function=serve_tensorboard)

# ----- mila serve mlflow ------

Expand Down Expand Up @@ -595,7 +595,7 @@ def lab(path: str | None, **kwargs: Unpack[StandardServerArgs]):
)


def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]):
def serve_notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]):
"""Start a Jupyter Notebook server.
Arguments:
Expand All @@ -618,7 +618,7 @@ def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]):
)


def tensorboard(logdir: str, **kwargs: Unpack[StandardServerArgs]):
def serve_tensorboard(logdir: str, **kwargs: Unpack[StandardServerArgs]):
"""Start a Tensorboard server.
Arguments:
Expand Down
56 changes: 56 additions & 0 deletions tests/integration/test_serve_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

import pytest

from milatools.cli.common import standard_server_v1, standard_server_v2

from ..conftest import launches_jobs

persist_or_not = pytest.mark.parametrize("persist", [True, False])


@launches_jobs
@persist_or_not
def test_standard_server_v1(cluster: str, allocation_flags: list[str], persist: bool):
if cluster != "mila":
pytest.skip(reason="Needs to be run on the Mila cluster for now.")
standard_server_v1(
path="bob",
program="echo",
installers={},
command="ls",
profile=None,
persist=persist,
port=None,
name=None,
node=None,
job=None,
alloc=allocation_flags, # : list[str],
)
raise NotImplementedError("TODO: Add checks in this test.")


@launches_jobs
@persist_or_not
def test_standard_server_v2(cluster: str, allocation_flags: list[str], persist: bool):
raise NotImplementedError("TODO: Design this test.")
standard_server_v2(
path="bob",
program="echo",
installers={},
command="ls",
profile=None,
persist=persist,
port=None,
name=None,
node=None,
job=None,
alloc=allocation_flags,
cluster=cluster, # type: ignore
)


@launches_jobs
@persist_or_not
def test_mila_serve_notebook(cluster: str, allocation_flags: list[str]):
raise NotImplementedError("TODO: Design this test.")

0 comments on commit eed46e7

Please sign in to comment.