-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
- Loading branch information
Showing
6 changed files
with
103 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from pathlib import Path | ||
|
||
from paramiko import SSHConfig | ||
|
||
from milatools.cli import console | ||
from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2 | ||
|
||
|
||
async def login( | ||
ssh_config_path: Path = SSH_CONFIG_FILE, | ||
) -> list[RemoteV2]: | ||
"""Logs in and sets up reusable SSH connections to all the hosts in the SSH config. | ||
Returns the list of remotes where the connection was successfully established. | ||
""" | ||
ssh_config = SSHConfig.from_path(str(ssh_config_path.expanduser())) | ||
potential_clusters = [ | ||
host | ||
for host in ssh_config.get_hostnames() | ||
if not any(c in host for c in ["*", "?", "!"]) | ||
] | ||
# take out entries like `mila-cpu` that have a proxy and remote command. | ||
potential_clusters = [ | ||
hostname | ||
for hostname in potential_clusters | ||
if not ( | ||
(config := ssh_config.lookup(hostname)).get("proxycommand") | ||
and config.get("remotecommand") | ||
) | ||
] | ||
remotes = await asyncio.gather( | ||
*( | ||
RemoteV2.connect(hostname, ssh_config_path=ssh_config_path) | ||
for hostname in potential_clusters | ||
), | ||
return_exceptions=True, | ||
) | ||
remotes = [remote for remote in remotes if isinstance(remote, RemoteV2)] | ||
console.log(f"Successfully connected to {[remote.hostname for remote in remotes]}") | ||
return remotes | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(login()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
usage: mila [-h] [--version] [-v] | ||
{docs,intranet,init,forward,code,sync,serve} ... | ||
{docs,intranet,init,login,forward,code,sync,serve} ... | ||
mila: error: the following arguments are required: <command> |
4 changes: 2 additions & 2 deletions
4
tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
usage: mila [-h] [--version] [-v] | ||
{docs,intranet,init,forward,code,sync,serve} ... | ||
mila: error: argument <command>: invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'forward', 'code', 'sync', 'serve') | ||
{docs,intranet,init,login,forward,code,sync,serve} ... | ||
mila: error: argument <command>: invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'login', 'forward', 'code', 'sync', 'serve') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import textwrap | ||
from logging import getLogger as get_logger | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
from milatools.cli.login import login | ||
from milatools.utils.remote_v2 import SSH_CACHE_DIR, RemoteV2 | ||
|
||
from .common import requires_ssh_to_localhost | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
@requires_ssh_to_localhost | ||
@pytest.mark.asyncio | ||
async def test_login(tmp_path: Path): # ssh_config_file: Path): | ||
assert SSH_CACHE_DIR.exists() | ||
ssh_config_path = tmp_path / "ssh_config" | ||
ssh_config_path.write_text( | ||
textwrap.dedent( | ||
"""\ | ||
Host foo | ||
hostname localhost | ||
Host bar | ||
hostname localhost | ||
""" | ||
) | ||
+ "\n" | ||
) | ||
|
||
# Should create a connection to every host in the ssh config file. | ||
remotes = await login(ssh_config_path=ssh_config_path) | ||
assert all(isinstance(remote, RemoteV2) for remote in remotes) | ||
assert set(remote.hostname for remote in remotes) == {"foo", "bar"} | ||
for remote in remotes: | ||
logger.info(f"Removing control socket at {remote.control_path}") | ||
remote.control_path.unlink() |