Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #2229/f2bd35b1 backport][stable-9] aws_ssm - refactor _prepare_terminal() method #2246

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Refactor ``_prepare_terminal()`` Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/).
111 changes: 56 additions & 55 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@
import string
import subprocess
import time
from typing import Optional
from typing import NoReturn
from typing import Optional
from typing import Tuple

try:
Expand Down Expand Up @@ -609,7 +609,7 @@ def instance_id(self) -> str:
return self._instance_id

@instance_id.setter
def instance_id(self, instance_id: str) -> NoReturn:
def instance_id(self, instance_id: str) -> None:
self._instance_id = instance_id

def start_session(self):
Expand Down Expand Up @@ -646,7 +646,7 @@ def start_session(self):
self._vvvv(f"SSM COMMAND: {to_text(cmd)}")

stdout_r, stdout_w = pty.openpty()
session = subprocess.Popen(
self._session = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=stdout_w,
Expand All @@ -657,14 +657,13 @@ def start_session(self):

os.close(stdout_w)
self._stdout = os.fdopen(stdout_r, "rb", 0)
self._session = session

# Disable command echo and prompt.
# For non-windows Hosts: Ensure the session has started, and disable command echo and prompt.
self._prepare_terminal()

self._vvvv(f"SSM CONNECTION ID: {self._session_id}")
self._vvvv(f"SSM CONNECTION ID: {self._session_id}") # pylint: disable=unreachable

return session
return self._session

def poll_stdout(self, timeout: int = 1000) -> bool:
"""Polls the stdout file descriptor.
Expand Down Expand Up @@ -767,72 +766,74 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->

return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)

def _prepare_terminal(self):
"""perform any one-time terminal settings"""
# No windows setup for now
if self.is_windows:
return

# *_complete variables are 3 valued:
# - None: not started
# - False: started
# - True: complete

startup_complete = False
disable_echo_complete = None
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")
def _ensure_ssm_session_has_started(self) -> None:
"""Ensure the SSM session has started on the host. We poll stdout
until we match the following string 'Starting session with SessionId'
"""
stdout = ""
for poll_result in self.poll("START SSM SESSION", "start_session"):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"START SSM SESSION stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("Starting session with SessionId")
if match != -1:
self._vvvv("START SSM SESSION startup output received")
break

disable_prompt_complete = None
end_mark = self.generate_mark()
def _disable_prompt_command(self) -> None:
"""Disable prompt command from the host"""
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
disable_prompt_cmd = to_bytes(
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
errors="surrogate_or_strict",
)
disable_prompt_reply = re.compile(r"\r\r\n" + re.escape(end_mark) + r"\r\r\n", re.MULTILINE)

# Send command
self._vvvv(f"DISABLE PROMPT Disabling Prompt: \n{disable_prompt_cmd}")
self._session.stdin.write(disable_prompt_cmd)

stdout = ""
# Custom command execution for when we're waiting for startup
for poll_result in self.poll("PRE", "start_session"):
if disable_prompt_complete:
break
for poll_result in self.poll("DISABLE PROMPT", disable_prompt_cmd):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"PRE stdout line: \n{to_bytes(stdout)}")
self._vvvv(f"DISABLE PROMPT stdout line: \n{to_bytes(stdout)}")
if disable_prompt_reply.search(stdout):
break

# wait til prompt is ready
if startup_complete is False:
match = str(stdout).find("Starting session with SessionId")
if match != -1:
self._vvvv("PRE startup output received")
startup_complete = True
def _disable_echo_command(self) -> None:
"""Disable echo command from the host"""
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

# disable echo
if startup_complete and (disable_echo_complete is None):
self._vvvv(f"PRE Disabling Echo: {disable_echo_cmd}")
self._session.stdin.write(disable_echo_cmd)
disable_echo_complete = False
# Send command
self._vvvv(f"DISABLE ECHO Disabling Prompt: \n{disable_echo_cmd}")
self._session.stdin.write(disable_echo_cmd)

if disable_echo_complete is False:
stdout = ""
for poll_result in self.poll("DISABLE ECHO", disable_echo_cmd):
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"DISABLE ECHO stdout line: \n{to_bytes(stdout)}")
match = str(stdout).find("stty -echo")
if match != -1:
disable_echo_complete = True
break

# disable prompt
if disable_echo_complete and disable_prompt_complete is None:
self._vvvv(f"PRE Disabling Prompt: \n{disable_prompt_cmd}")
self._session.stdin.write(disable_prompt_cmd)
disable_prompt_complete = False
def _prepare_terminal(self) -> None:
"""perform any one-time terminal settings"""
# No Windows setup for now
if self.is_windows:
return

if disable_prompt_complete is False:
match = disable_prompt_reply.search(stdout)
if match:
stdout = stdout[match.end():] # fmt: skip
disable_prompt_complete = True
# Ensure SSM Session has started
self._ensure_ssm_session_has_started()

# see https://github.com/pylint-dev/pylint/issues/8909)
if not disable_prompt_complete: # pylint: disable=unreachable
raise AnsibleConnectionFailure(f"SSM process closed during _prepare_terminal on host: {self.instance_id}")
self._vvvv("PRE Terminal configured")
# Disable echo command
self._disable_echo_command() # pylint: disable=unreachable

# Disable prompt command
self._disable_prompt_command() # pylint: disable=unreachable

self._vvvv("PRE Terminal configured") # pylint: disable=unreachable

def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
"""Wrap command so stdout and status can be extracted"""
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def connection_init(*args, **kwargs):
connection._session = MagicMock()
connection._session.poll = MagicMock()
connection._session.poll.side_effect = lambda: None
connection._session.stdin = MagicMock()
connection._session.stdin.write = MagicMock()
connection._stdout = MagicMock()
connection._flush_stderr = MagicMock()

Expand Down
122 changes: 122 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/test_prepare_terminal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-

# This file is part of Ansible
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

# While it may seem appropriate to import our custom fixtures here, the pytest_ansible pytest plugin
# isn't as agressive as the ansible_test._util.target.pytest.plugins.ansible_pytest_collections plugin
# when it comes to rewriting the import paths and as such we can't import fixtures via their
# absolute import path or across collections.


from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

if not HAS_BOTO3:
pytestmark = pytest.mark.skip("test_poll.py requires the python modules 'boto3' and 'botocore'")


def poll_mock(x, y):
while poll_mock.results:
yield poll_mock.results.pop(0)
raise TimeoutError("-- poll_stdout_mock() --- Process has timeout...")


@pytest.mark.parametrize(
"stdout_lines,timeout_failure",
[
(["Starting ", "session ", "with SessionId"], False),
(["Starting session", " with SessionId"], False),
(["Init - Starting", " session", " with SessionId"], False),
(["Starting", " session", " with SessionId "], False),
(["Starting ", "session"], True),
(["Starting ", "session with Session"], True),
(["session ", "with SessionId"], True),
],
)
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
def test_ensure_ssm_session_has_started(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = str
connection_aws_ssm._stdout.read = MagicMock()

connection_aws_ssm._stdout.read.side_effect = stdout_lines

poll_mock.results = [True for i in range(len(stdout_lines))]
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._ensure_ssm_session_has_started()
else:
connection_aws_ssm._ensure_ssm_session_has_started()


@pytest.mark.parametrize(
"stdout_lines,timeout_failure",
[
(["stty -echo"], False),
(["stty ", "-echo"], False),
(["stty"], True),
(["stty ", "-ech"], True),
],
)
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
def test_disable_echo_command(m_to_text, m_to_bytes, connection_aws_ssm, stdout_lines, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
connection_aws_ssm._stdout.read = MagicMock()

connection_aws_ssm._stdout.read.side_effect = stdout_lines

poll_mock.results = [True for i in range(len(stdout_lines))]
connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_echo_command()
else:
connection_aws_ssm._disable_echo_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with("stty -echo\n")


@pytest.mark.parametrize("timeout_failure", [True, False])
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.random")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_bytes")
@patch("ansible_collections.community.aws.plugins.connection.aws_ssm.to_text")
def test_disable_prompt_command(m_to_text, m_to_bytes, m_random, connection_aws_ssm, timeout_failure):
m_to_text.side_effect = str
m_to_bytes.side_effect = lambda x, **kw: str(x)
connection_aws_ssm._stdout.read = MagicMock()

connection_aws_ssm.poll = MagicMock()
connection_aws_ssm.poll.side_effect = poll_mock

m_random.choice = MagicMock()
m_random.choice.side_effect = lambda x: "a"

end_mark = "".join(["a" for i in range(connection_aws_ssm.MARK_LENGTH)])

connection_aws_ssm._stdout.read.return_value = (
f"\r\r\n{end_mark}\r\r\n" if not timeout_failure else "unmatching value"
)
poll_mock.results = [True]

prompt_cmd = f"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '{end_mark}'\n"

if timeout_failure:
with pytest.raises(TimeoutError):
connection_aws_ssm._disable_prompt_command()
else:
connection_aws_ssm._disable_prompt_command()

connection_aws_ssm._session.stdin.write.assert_called_once_with(prompt_cmd)
Loading