Skip to content

Commit

Permalink
connection/aws_ssm - refactor exec_command function to improve mainta…
Browse files Browse the repository at this point in the history
…nability (#2226) (#2231)

This is a backport of PR #2226 as merged into main (d84f3a1).
SUMMARY

Refer to https://issues.redhat.com/browse/ACA-2093
Refactor exec_command() and add unit tests

ISSUE TYPE


Feature Pull Request

COMPONENT NAME

connection/aws_ssm

Reviewed-by: Bikouo Aubin
  • Loading branch information
patchback[bot] authored Feb 4, 2025
1 parent aa4c8fc commit b7690aa
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 79 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Refactor exec_command Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/2224).
199 changes: 120 additions & 79 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@
name: nginx
state: present
"""
import os
import getpass
import json
import os
import pty
import random
import re
Expand All @@ -296,6 +296,8 @@
import subprocess
import time
from typing import Optional
from typing import NoReturn
from typing import Tuple

try:
import boto3
Expand All @@ -304,18 +306,19 @@
pass

from functools import wraps
from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible.errors import AnsibleConnectionFailure
from ansible.errors import AnsibleError
from ansible.errors import AnsibleFileNotFound
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.six.moves import xrange
from ansible.module_utils._text import to_bytes
from ansible.module_utils._text import to_text
from ansible.module_utils.basic import missing_required_lib
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.shell.powershell import _common_args
from ansible.utils.display import Display

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

display = Display()


Expand Down Expand Up @@ -375,6 +378,29 @@ def chunks(lst, n):
yield lst[i:i + n] # fmt: skip


def filter_ansi(line: str, is_windows: bool) -> str:
"""Remove any ANSI terminal control codes.
:param line: The input line.
:param is_windows: Whether the output is coming from a Windows host.
:returns: The result line.
"""
line = to_text(line)

if is_windows:
osc_filter = re.compile(r"\x1b\][^\x07]*\x07")
line = osc_filter.sub("", line)
ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
line = ansi_filter.sub("", line)

# Replace or strip sequence (at terminal width)
line = line.replace("\r\r\n", "\n")
if len(line) == 201:
line = line[:-1]

return line


class Connection(ConnectionBase):
"""AWS SSM based connections"""

Expand All @@ -401,6 +427,9 @@ def __init__(self, *args, **kwargs):
raise AnsibleError(missing_required_lib("boto3"))

self.host = self._play_context.remote_addr
self._instance_id = None
self._polling_obj = None
self._has_timeout = False

if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
self.delegate = None
Expand Down Expand Up @@ -542,14 +571,19 @@ def reset(self):
self.close()
return self.start_session()

@property
def instance_id(self) -> str:
if not self._instance_id:
self._instance_id = self.host if self.get_option("instance_id") is None else self.get_option("instance_id")
return self._instance_id

@instance_id.setter
def instance_id(self, instance_id: str) -> NoReturn:
self._instance_id = instance_id

def start_session(self):
"""start ssm session"""

if self.get_option("instance_id") is None:
self.instance_id = self.host
else:
self.instance_id = self.get_option("instance_id")

self._vvv(f"ESTABLISH SSM CONNECTION TO: {self.instance_id}")

executable = self.get_option("plugin")
Expand Down Expand Up @@ -593,8 +627,6 @@ def start_session(self):
os.close(stdout_w)
self._stdout = os.fdopen(stdout_r, "rb", 0)
self._session = session
self._poll_stdout = select.poll()
self._poll_stdout.register(self._stdout, select.POLLIN)

# Disable command echo and prompt.
self._prepare_terminal()
Expand All @@ -603,49 +635,56 @@ def start_session(self):

return session

@_ssm_retry
def exec_command(self, cmd, in_data=None, sudoable=True):
"""run a command on the ssm host"""

super().exec_command(cmd, in_data=in_data, sudoable=sudoable)

self._vvv(f"EXEC: {to_text(cmd)}")

session = self._session

mark_begin = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
if self.is_windows:
mark_start = mark_begin + " $LASTEXITCODE"
else:
mark_start = mark_begin
mark_end = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
def poll_stdout(self, timeout: int = 1000) -> bool:
"""Polls the stdout file descriptor.
# Wrap command in markers accordingly for the shell used
cmd = self._wrap_command(cmd, sudoable, mark_start, mark_end)

self._flush_stderr(session)
:param timeout: Specifies the length of time in milliseconds which the system will wait.
:returns: A boolean to specify the polling result
"""
if self._polling_obj is None:
self._polling_obj = select.poll()
self._polling_obj.register(self._stdout, select.POLLIN)
return bool(self._polling_obj.poll(timeout))

for chunk in chunks(cmd, 1024):
session.stdin.write(to_bytes(chunk, errors="surrogate_or_strict"))
def poll(self, label: str, cmd: str) -> NoReturn:
"""Poll session to retrieve content from stdout.
:param label: A label for the display (EXEC, PRE...)
:param cmd: The command being executed
"""
start = round(time.time())
yield self.poll_stdout()
timeout = self.get_option("ssm_timeout")
while self._session.poll() is None:
remaining = start + timeout - round(time.time())
self._vvvv(f"{label} remaining: {remaining} second(s)")
if remaining < 0:
self._has_timeout = True
raise AnsibleConnectionFailure(f"{label} command '{cmd}' timeout on host: {self.instance_id}")
yield self.poll_stdout()

def exec_communicate(self, cmd: str, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]:
"""Interact with session.
Read stdout between the markers until 'mark_end' is reached.
:param cmd: The command being executed.
:param mark_start: The marker which starts the output.
:param mark_begin: The begin marker.
:param mark_end: The end marker.
:returns: A tuple with the return code, the stdout and the stderr content.
"""
# Read stdout between the markers
stdout = ""
win_line = ""
begin = False
stop_time = int(round(time.time())) + self.get_option("ssm_timeout")
while session.poll() is None:
remaining = stop_time - int(round(time.time()))
if remaining < 1:
self._timeout = True
self._vvvv(f"EXEC timeout stdout: \n{to_text(stdout)}")
raise AnsibleConnectionFailure(f"SSM exec_command timeout on host: {self.instance_id}")
if self._poll_stdout.poll(1000):
line = self._filter_ansi(self._stdout.readline())
self._vvvv(f"EXEC stdout line: \n{to_text(line)}")
else:
self._vvvv(f"EXEC remaining: {remaining}")
returncode = None
for poll_result in self.poll("EXEC", cmd):
if not poll_result:
continue

line = filter_ansi(self._stdout.readline(), self.is_windows)
self._vvvv(f"EXEC stdout line: \n{line}")

if not begin and self.is_windows:
win_line = win_line + line
line = win_line
Expand All @@ -663,9 +702,33 @@ def exec_command(self, cmd, in_data=None, sudoable=True):
break
stdout = stdout + line

stderr = self._flush_stderr(session)
# see https://github.com/pylint-dev/pylint/issues/8909)
return (returncode, stdout, self._flush_stderr(self._session)) # pylint: disable=unreachable

@_ssm_retry
def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> Tuple[int, str, str]:
"""run a command on the ssm host"""

super().exec_command(cmd, in_data=in_data, sudoable=sudoable)

return (returncode, stdout, stderr)
self._vvv(f"EXEC: {to_text(cmd)}")

mark_begin = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
if self.is_windows:
mark_start = mark_begin + " $LASTEXITCODE"
else:
mark_start = mark_begin
mark_end = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])

# Wrap command in markers accordingly for the shell used
cmd = self._wrap_command(cmd, mark_start, mark_end)

self._flush_stderr(self._session)

for chunk in chunks(cmd, 1024):
self._session.stdin.write(to_bytes(chunk, errors="surrogate_or_strict"))

return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)

def _prepare_terminal(self):
"""perform any one-time terminal settings"""
Expand All @@ -683,7 +746,7 @@ def _prepare_terminal(self):
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

disable_prompt_complete = None
end_mark = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
disable_prompt_cmd = to_bytes(
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
errors="surrogate_or_strict",
Expand All @@ -692,18 +755,12 @@ def _prepare_terminal(self):

stdout = ""
# Custom command execution for when we're waiting for startup
stop_time = int(round(time.time())) + self.get_option("ssm_timeout")
while (not disable_prompt_complete) and (self._session.poll() is None):
remaining = stop_time - int(round(time.time()))
if remaining < 1:
self._timeout = True
self._vvvv(f"PRE timeout stdout: \n{to_bytes(stdout)}")
raise AnsibleConnectionFailure(f"SSM start_session timeout on host: {self.instance_id}")
if self._poll_stdout.poll(1000):
for poll_result in self.poll("PRE", "start_session"):
if disable_prompt_complete:
break
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"PRE stdout line: \n{to_bytes(stdout)}")
else:
self._vvvv(f"PRE remaining: {remaining}")

# wait til prompt is ready
if startup_complete is False:
Expand Down Expand Up @@ -735,12 +792,13 @@ def _prepare_terminal(self):
stdout = stdout[match.end():] # fmt: skip
disable_prompt_complete = True

if not disable_prompt_complete:
# see https://github.com/pylint-dev/pylint/issues/8909)
if not disable_prompt_complete: # pylint: disable=unreachable
raise AnsibleConnectionFailure(f"SSM process closed during _prepare_terminal on host: {self.instance_id}")
self._vvvv("PRE Terminal configured")

def _wrap_command(self, cmd, sudoable, mark_start, mark_end):
"""wrap command so stdout and status can be extracted"""
def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
"""Wrap command so stdout and status can be extracted"""

if self.is_windows:
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
Expand Down Expand Up @@ -790,23 +848,6 @@ def _post_process(self, stdout, mark_begin):

return (returncode, stdout)

def _filter_ansi(self, line):
"""remove any ANSI terminal control codes"""
line = to_text(line)

if self.is_windows:
osc_filter = re.compile(r"\x1b\][^\x07]*\x07")
line = osc_filter.sub("", line)
ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
line = ansi_filter.sub("", line)

# Replace or strip sequence (at terminal width)
line = line.replace("\r\r\n", "\n")
if len(line) == 201:
line = line[:-1]

return line

def _flush_stderr(self, session_process):
"""read and return stderr with minimal blocking"""

Expand Down Expand Up @@ -996,7 +1037,7 @@ def close(self):
"""terminate the connection"""
if self._session_id:
self._vvv(f"CLOSING SSM CONNECTION TO: {self.instance_id}")
if self._timeout:
if self._has_timeout:
self._session.terminate()
else:
cmd = b"\nexit\n"
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

# This file is part of Ansible
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from unittest.mock import MagicMock

import pytest

from ansible_collections.community.aws.plugins.connection.aws_ssm import Connection
from ansible_collections.community.aws.plugins.connection.aws_ssm import ConnectionBase


@pytest.fixture(name="connection_aws_ssm")
def fixture_connection_aws_ssm():
play_context = MagicMock()
play_context.shell = True

def connection_init(*args, **kwargs):
pass

Connection.__init__ = connection_init
ConnectionBase.exec_command = MagicMock()
connection = Connection()

connection._instance_id = "i-0a1b2c3d4e5f"
connection._polling_obj = None
connection._has_timeout = False
connection.is_windows = False

connection.poll_stdout = MagicMock()
connection._session = MagicMock()
connection._session.poll = MagicMock()
connection._session.poll.side_effect = lambda: None
connection._stdout = MagicMock()
connection._flush_stderr = MagicMock()

def display_msg(msg):
print("--- AWS SSM CONNECTION --- ", msg)

connection._v = MagicMock()
connection._v.side_effect = display_msg

connection._vv = MagicMock()
connection._vv.side_effect = display_msg

connection._vvv = MagicMock()
connection._vvv.side_effect = display_msg

connection._vvvv = MagicMock()
connection._vvvv.side_effect = display_msg

connection.get_option = MagicMock()
return connection
Loading

0 comments on commit b7690aa

Please sign in to comment.