Skip to content

Commit

Permalink
Start to support RemoteV2 in standard server
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Mar 26, 2024
1 parent d09d4d3 commit 186050a
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 62 deletions.
34 changes: 31 additions & 3 deletions milatools/cli/code_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
from milatools.cli import console
from milatools.cli.common import (
check_disk_quota,
find_allocation,
)
from milatools.cli.init_command import DRAC_CLUSTERS
from milatools.cli.local import Local
from milatools.cli.remote import Remote
from milatools.cli.remote import Remote, SlurmRemote
from milatools.cli.utils import (
CLUSTERS,
Cluster,
CommandNotFoundError,
MilatoolsUserError,
SortingHelpFormatter,
cluster_to_connect_kwargs,
currently_in_a_test,
get_fully_qualified_hostname_of_compute_node,
make_process,
Expand Down Expand Up @@ -211,6 +211,34 @@ async def code(
return


def find_allocation_v1(
    remote: Remote,
    node: str | None,
    job: int | None,
    alloc: list[str],
    cluster: Cluster = "mila",
    job_name: str = "mila-tools",
):
    """Pick the remote object that commands should be run through.

    At most one of `node`, `job` and `alloc` may be provided:

    - `node`: a `Remote` is created for the fully qualified hostname of that
      compute node.
    - `job`: `squeue` is queried through `remote` for the job's `%N` field and
      a `Remote` is created for the returned name.
    - otherwise: a `SlurmRemote` is built on top of `remote`'s connection,
      with `-J <job_name>` prepended to the `alloc` arguments.

    Arguments:
        remote: Existing remote, used to run `squeue` or to share its connection.
        node: Name of a compute node to connect to, if any.
        job: ID of an existing job to connect to, if any.
        alloc: Extra allocation arguments forwarded to `SlurmRemote`.
        cluster: Cluster used to resolve hostnames and connection kwargs.
        job_name: Value passed after the `-J` flag for new allocations.

    Exits the program with an error message when more than one of `node`,
    `job` and `alloc` is given.
    """
    selectors_given = sum((node is not None, job is not None, bool(alloc)))
    if selectors_given > 1:
        exit("ERROR: --node, --job and --alloc are mutually exclusive")

    if node is None and job is None:
        # No existing node or job was requested: go through a Slurm allocation.
        slurm_args = ["-J", job_name, *alloc]
        return SlurmRemote(
            connection=remote.connection,
            alloc=slurm_args,
            hostname=remote.hostname,
        )

    if node is not None:
        hostname = get_fully_qualified_hostname_of_compute_node(node, cluster=cluster)
    else:
        # Look up the job's `%N` field by asking squeue through `remote`.
        hostname = remote.get_output(f"squeue --jobs {job} -ho %N")

    return Remote(hostname, connect_kwargs=cluster_to_connect_kwargs.get(cluster))


@deprecated(
"Support for the `mila code` command is now deprecated on Windows machines, as it "
"does not support ssh keys with passphrases or clusters where 2FA is enabled. "
Expand Down Expand Up @@ -288,7 +316,7 @@ def code_v1(
)

if node is None:
cnode = find_allocation(
cnode = find_allocation_v1(
remote,
job_name="mila-code",
job=job,
Expand Down
12 changes: 6 additions & 6 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from ..__version__ import __version__
from .code_command import add_mila_code_arguments
from .common import forward, standard_server
from .common import forward, standard_server_v1
from .init_command import (
print_welcome_message,
setup_keys_on_login_node,
Expand Down Expand Up @@ -581,7 +581,7 @@ def lab(path: str | None, **kwargs: Unpack[StandardServerArgs]):
if path and path.endswith(".ipynb"):
exit("Only directories can be given to the mila serve lab command")

standard_server(
standard_server_v1(
path,
program="jupyter-lab",
installers={
Expand All @@ -604,7 +604,7 @@ def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]):
if path and path.endswith(".ipynb"):
exit("Only directories can be given to the mila serve notebook command")

standard_server(
standard_server_v1(
path,
program="jupyter-notebook",
installers={
Expand All @@ -625,7 +625,7 @@ def tensorboard(logdir: str, **kwargs: Unpack[StandardServerArgs]):
logdir: Path to the experiment logs
"""

standard_server(
standard_server_v1(
logdir,
program="tensorboard",
installers={
Expand All @@ -645,7 +645,7 @@ def mlflow(logdir: str, **kwargs: Unpack[StandardServerArgs]):
logdir: Path to the experiment logs
"""

standard_server(
standard_server_v1(
logdir,
program="mlflow",
installers={
Expand All @@ -663,7 +663,7 @@ def aim(logdir: str, **kwargs: Unpack[StandardServerArgs]):
Arguments:
logdir: Path to the experiment logs
"""
standard_server(
standard_server_v1(
logdir,
program="aim",
installers={
Expand Down
210 changes: 179 additions & 31 deletions milatools/cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path
from urllib.parse import urlencode

import invoke
import questionary as qn
from rich.text import Text

Expand All @@ -21,7 +22,6 @@
Cluster,
MilatoolsUserError,
T,
cluster_to_connect_kwargs,
get_fully_qualified_hostname_of_compute_node,
randname,
with_control_file,
Expand Down Expand Up @@ -159,34 +159,6 @@ def get_colour(used: float, max: float) -> str:
logger.warning(UserWarning(warning_message))


def find_allocation(
    remote: Remote,
    node: str | None,
    job: int | None,
    alloc: list[str],
    cluster: Cluster = "mila",
    job_name: str = "mila-tools",
):
    """Return the remote to use, based on the `node`, `job` or `alloc` options.

    The three options are mutually exclusive; the program exits with an error
    message if more than one of them is set.

    Arguments:
        remote: Remote used to query `squeue`, or whose connection and hostname
            are reused for a `SlurmRemote`.
        node: Compute node to connect to directly, if any.
        job: ID of a job whose `%N` field (as reported by `squeue`) is used
            as the host to connect to, if any.
        alloc: Additional arguments passed on to `SlurmRemote`.
        cluster: Cluster used to qualify the node hostname and to select the
            connection kwargs.
        job_name: Name passed with the `-J` flag when a `SlurmRemote` is created.

    Returns:
        A `Remote` when `node` or `job` is given, otherwise a `SlurmRemote`.
    """
    exclusive_options = [node is not None, job is not None, bool(alloc)]
    if exclusive_options.count(True) > 1:
        exit("ERROR: --node, --job and --alloc are mutually exclusive")

    if node is not None:
        return Remote(
            get_fully_qualified_hostname_of_compute_node(node, cluster=cluster),
            connect_kwargs=cluster_to_connect_kwargs.get(cluster),
        )

    if job is not None:
        # Query squeue through `remote` for the `%N` field of this job.
        return Remote(
            remote.get_output(f"squeue --jobs {job} -ho %N"),
            connect_kwargs=cluster_to_connect_kwargs.get(cluster),
        )

    # Neither a node nor a job was given: build a SlurmRemote with these args.
    named_alloc = ["-J", job_name, *alloc]
    return SlurmRemote(
        connection=remote.connection,
        alloc=named_alloc,
        hostname=remote.hostname,
    )


def forward(
local: Local,
node: str,
Expand Down Expand Up @@ -256,7 +228,7 @@ def forward(
return proc, port


def standard_server(
def standard_server_v1(
path: str | None,
*,
program: str,
Expand Down Expand Up @@ -321,7 +293,183 @@ def standard_server(
):
exit(f"Exit: {program} is not installed.")

cnode = find_allocation(
from milatools.cli.code_command import find_allocation_v1

cnode = find_allocation_v1(
remote,
job_name=f"mila-serve-{program}",
node=node,
job=job,
alloc=alloc,
cluster="mila",
)

patterns = {
"node_name": "#### ([A-Za-z0-9_-]+)",
}

if port_pattern:
patterns["port"] = port_pattern
elif share:
exit(
"Server cannot be shared because it is serving over a Unix domain "
"socket"
)
else:
remote.run("mkdir -p ~/.milatools/sockets", hide=True)

if share:
host = "0.0.0.0"
else:
host = "localhost"

sock_name = name or randname()
command = command.format(
path=path,
sock=f"~/.milatools/sockets/{sock_name}.sock",
host=host,
)

if token_pattern:
patterns["token"] = token_pattern

if persist:
cnode = cnode.persist()

proc, results = (
cnode.with_profile(prof)
.with_precommand("echo '####' $(hostname)")
.extract(
command,
patterns=patterns,
)
)
node_name = results["node_name"]

if port_pattern:
to_forward = int(results["port"])
else:
to_forward = f"{remote.home()}/.milatools/sockets/{sock_name}.sock"

if cf is not None:
remote.simple_run(f"echo program = {program} >> {cf}")
remote.simple_run(f"echo node_name = {results['node_name']} >> {cf}")
remote.simple_run(f"echo host = {host} >> {cf}")
remote.simple_run(f"echo to_forward = {to_forward} >> {cf}")
if token_pattern:
remote.simple_run(f"echo token = {results['token']} >> {cf}")

assert results is not None
assert node_name is not None
assert to_forward is not None
assert proc is not None
if token_pattern:
options = {"token": results["token"]}
else:
options = {}

local_proc, local_port = forward(
local=Local(),
node=get_fully_qualified_hostname_of_compute_node(node_name, cluster="mila"),
to_forward=to_forward,
options=options,
port=port,
)

if cf is not None:
remote.simple_run(f"echo local_port = {local_port} >> {cf}")

try:
local_proc.wait()
except KeyboardInterrupt:
qn.print("Terminated by user.")
if cf is not None:
name = Path(cf).name
qn.print("To reconnect to this server, use the command:")
qn.print(f" mila serve connect {name}", style="bold yellow")
qn.print("To kill this server, use the command:")
qn.print(f" mila serve kill {name}", style="bold red")
finally:
local_proc.kill()
proc.kill()


def standard_server_v2(
path: str | None,
*,
program: str,
installers: dict[str, str],
command: str,
profile: str | None,
persist: bool,
port: int | None,
name: str | None,
node: str | None,
job: int | None,
alloc: list[str],
port_pattern=None,
token_pattern=None,
cluster: Cluster = "mila",
):
# Make the server visible from the login node (other users will be able to connect)
# Temporarily disabled
share = False

if name is not None:
persist = True
elif persist:
name = program

remote = RemoteV2(cluster)

path = path or "~"
if path == "~" or path.startswith("~/"):
path = remote.get_output("echo $HOME", display=False, hide=True) + path[1:]

results: dict | None = None
node_name: str | None = None
to_forward: int | str | None = None
cf: str | None = None
proc = None
raise NotImplementedError("TODO: adapt the rest of this to work with RemoteV2")

with ExitStack() as stack:
if persist:
cf = stack.enter_context(with_control_file(remote, name=name))
else:
cf = None

if profile:
prof = f"~/.milatools/profiles/{profile}.bash"
else:
prof = setup_profile(remote, path)

qn.print(f"Using profile: {prof}")
cat_result = remote.run(f"cat {prof}", hide=True, warn=True)
if (
isinstance(cat_result, invoke.runners.Result)
and cat_result.return_code == 0
) or (
isinstance(cat_result, subprocess.CompletedProcess)
and cat_result.returncode == 0
):
qn.print("=" * 50)
qn.print(cat_result.stdout.rstrip())
qn.print("=" * 50)
else:
exit(f"Could not find or load profile: {prof}")

premote = remote.with_profile(prof)

if not ensure_program(
remote=premote,
program=program,
installers=installers,
):
exit(f"Exit: {program} is not installed.")
from milatools.cli.code_command import find_allocation_v1

cnode = find_allocation_v1(
remote,
job_name=f"mila-serve-{program}",
node=node,
Expand Down
Loading

0 comments on commit 186050a

Please sign in to comment.