Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR #2226/d84f3a16 backport][stable-9] connection/aws_ssm - refactor exec_command function to improve maintanability #2231

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Refactor exec_command Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/2224).
199 changes: 120 additions & 79 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@
name: nginx
state: present
"""
import os
import getpass
import json
import os
import pty
import random
import re
Expand All @@ -296,6 +296,8 @@
import subprocess
import time
from typing import Optional
from typing import NoReturn
from typing import Tuple

try:
import boto3
Expand All @@ -304,18 +306,19 @@
pass

from functools import wraps
from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible.errors import AnsibleConnectionFailure
from ansible.errors import AnsibleError
from ansible.errors import AnsibleFileNotFound
from ansible.module_utils.basic import missing_required_lib
from ansible.module_utils.six.moves import xrange
from ansible.module_utils._text import to_bytes
from ansible.module_utils._text import to_text
from ansible.module_utils.basic import missing_required_lib
from ansible.plugins.connection import ConnectionBase
from ansible.plugins.shell.powershell import _common_args
from ansible.utils.display import Display

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

display = Display()


Expand Down Expand Up @@ -375,6 +378,29 @@ def chunks(lst, n):
yield lst[i:i + n] # fmt: skip


def filter_ansi(line: str, is_windows: bool) -> str:
"""Remove any ANSI terminal control codes.

:param line: The input line.
:param is_windows: Whether the output is coming from a Windows host.
:returns: The result line.
"""
line = to_text(line)

if is_windows:
osc_filter = re.compile(r"\x1b\][^\x07]*\x07")
line = osc_filter.sub("", line)
ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
line = ansi_filter.sub("", line)

# Replace or strip sequence (at terminal width)
line = line.replace("\r\r\n", "\n")
if len(line) == 201:
line = line[:-1]

return line


class Connection(ConnectionBase):
"""AWS SSM based connections"""

Expand All @@ -401,6 +427,9 @@ def __init__(self, *args, **kwargs):
raise AnsibleError(missing_required_lib("boto3"))

self.host = self._play_context.remote_addr
self._instance_id = None
self._polling_obj = None
self._has_timeout = False

if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
self.delegate = None
Expand Down Expand Up @@ -542,14 +571,19 @@ def reset(self):
self.close()
return self.start_session()

@property
def instance_id(self) -> str:
if not self._instance_id:
self._instance_id = self.host if self.get_option("instance_id") is None else self.get_option("instance_id")
return self._instance_id

@instance_id.setter
def instance_id(self, instance_id: str) -> NoReturn:
self._instance_id = instance_id

def start_session(self):
"""start ssm session"""

if self.get_option("instance_id") is None:
self.instance_id = self.host
else:
self.instance_id = self.get_option("instance_id")

self._vvv(f"ESTABLISH SSM CONNECTION TO: {self.instance_id}")

executable = self.get_option("plugin")
Expand Down Expand Up @@ -593,8 +627,6 @@ def start_session(self):
os.close(stdout_w)
self._stdout = os.fdopen(stdout_r, "rb", 0)
self._session = session
self._poll_stdout = select.poll()
self._poll_stdout.register(self._stdout, select.POLLIN)

# Disable command echo and prompt.
self._prepare_terminal()
Expand All @@ -603,49 +635,56 @@ def start_session(self):

return session

@_ssm_retry
def exec_command(self, cmd, in_data=None, sudoable=True):
"""run a command on the ssm host"""

super().exec_command(cmd, in_data=in_data, sudoable=sudoable)

self._vvv(f"EXEC: {to_text(cmd)}")

session = self._session

mark_begin = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
if self.is_windows:
mark_start = mark_begin + " $LASTEXITCODE"
else:
mark_start = mark_begin
mark_end = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
def poll_stdout(self, timeout: int = 1000) -> bool:
"""Polls the stdout file descriptor.

# Wrap command in markers accordingly for the shell used
cmd = self._wrap_command(cmd, sudoable, mark_start, mark_end)

self._flush_stderr(session)
:param timeout: Specifies the length of time in milliseconds which the system will wait.
:returns: A boolean to specify the polling result
"""
if self._polling_obj is None:
self._polling_obj = select.poll()
self._polling_obj.register(self._stdout, select.POLLIN)
return bool(self._polling_obj.poll(timeout))

for chunk in chunks(cmd, 1024):
session.stdin.write(to_bytes(chunk, errors="surrogate_or_strict"))
def poll(self, label: str, cmd: str) -> NoReturn:
"""Poll session to retrieve content from stdout.

:param label: A label for the display (EXEC, PRE...)
:param cmd: The command being executed
"""
start = round(time.time())
yield self.poll_stdout()
timeout = self.get_option("ssm_timeout")
while self._session.poll() is None:
remaining = start + timeout - round(time.time())
self._vvvv(f"{label} remaining: {remaining} second(s)")
if remaining < 0:
self._has_timeout = True
raise AnsibleConnectionFailure(f"{label} command '{cmd}' timeout on host: {self.instance_id}")
yield self.poll_stdout()

def exec_communicate(self, cmd: str, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]:
"""Interact with session.
Read stdout between the markers until 'mark_end' is reached.

:param cmd: The command being executed.
:param mark_start: The marker which starts the output.
:param mark_begin: The begin marker.
:param mark_end: The end marker.
:returns: A tuple with the return code, the stdout and the stderr content.
"""
# Read stdout between the markers
stdout = ""
win_line = ""
begin = False
stop_time = int(round(time.time())) + self.get_option("ssm_timeout")
while session.poll() is None:
remaining = stop_time - int(round(time.time()))
if remaining < 1:
self._timeout = True
self._vvvv(f"EXEC timeout stdout: \n{to_text(stdout)}")
raise AnsibleConnectionFailure(f"SSM exec_command timeout on host: {self.instance_id}")
if self._poll_stdout.poll(1000):
line = self._filter_ansi(self._stdout.readline())
self._vvvv(f"EXEC stdout line: \n{to_text(line)}")
else:
self._vvvv(f"EXEC remaining: {remaining}")
returncode = None
for poll_result in self.poll("EXEC", cmd):
if not poll_result:
continue

line = filter_ansi(self._stdout.readline(), self.is_windows)
self._vvvv(f"EXEC stdout line: \n{line}")

if not begin and self.is_windows:
win_line = win_line + line
line = win_line
Expand All @@ -663,9 +702,33 @@ def exec_command(self, cmd, in_data=None, sudoable=True):
break
stdout = stdout + line

stderr = self._flush_stderr(session)
# see https://github.com/pylint-dev/pylint/issues/8909)
return (returncode, stdout, self._flush_stderr(self._session)) # pylint: disable=unreachable

@_ssm_retry
def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> Tuple[int, str, str]:
"""run a command on the ssm host"""

super().exec_command(cmd, in_data=in_data, sudoable=sudoable)

return (returncode, stdout, stderr)
self._vvv(f"EXEC: {to_text(cmd)}")

mark_begin = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
if self.is_windows:
mark_start = mark_begin + " $LASTEXITCODE"
else:
mark_start = mark_begin
mark_end = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])

# Wrap command in markers accordingly for the shell used
cmd = self._wrap_command(cmd, mark_start, mark_end)

self._flush_stderr(self._session)

for chunk in chunks(cmd, 1024):
self._session.stdin.write(to_bytes(chunk, errors="surrogate_or_strict"))

return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)

def _prepare_terminal(self):
"""perform any one-time terminal settings"""
Expand All @@ -683,7 +746,7 @@ def _prepare_terminal(self):
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")

disable_prompt_complete = None
end_mark = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
disable_prompt_cmd = to_bytes(
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
errors="surrogate_or_strict",
Expand All @@ -692,18 +755,12 @@ def _prepare_terminal(self):

stdout = ""
# Custom command execution for when we're waiting for startup
stop_time = int(round(time.time())) + self.get_option("ssm_timeout")
while (not disable_prompt_complete) and (self._session.poll() is None):
remaining = stop_time - int(round(time.time()))
if remaining < 1:
self._timeout = True
self._vvvv(f"PRE timeout stdout: \n{to_bytes(stdout)}")
raise AnsibleConnectionFailure(f"SSM start_session timeout on host: {self.instance_id}")
if self._poll_stdout.poll(1000):
for poll_result in self.poll("PRE", "start_session"):
if disable_prompt_complete:
break
if poll_result:
stdout += to_text(self._stdout.read(1024))
self._vvvv(f"PRE stdout line: \n{to_bytes(stdout)}")
else:
self._vvvv(f"PRE remaining: {remaining}")

# wait til prompt is ready
if startup_complete is False:
Expand Down Expand Up @@ -735,12 +792,13 @@ def _prepare_terminal(self):
stdout = stdout[match.end():] # fmt: skip
disable_prompt_complete = True

if not disable_prompt_complete:
# see https://github.com/pylint-dev/pylint/issues/8909)
if not disable_prompt_complete: # pylint: disable=unreachable
raise AnsibleConnectionFailure(f"SSM process closed during _prepare_terminal on host: {self.instance_id}")
self._vvvv("PRE Terminal configured")

def _wrap_command(self, cmd, sudoable, mark_start, mark_end):
"""wrap command so stdout and status can be extracted"""
def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
"""Wrap command so stdout and status can be extracted"""

if self.is_windows:
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
Expand Down Expand Up @@ -790,23 +848,6 @@ def _post_process(self, stdout, mark_begin):

return (returncode, stdout)

def _filter_ansi(self, line):
"""remove any ANSI terminal control codes"""
line = to_text(line)

if self.is_windows:
osc_filter = re.compile(r"\x1b\][^\x07]*\x07")
line = osc_filter.sub("", line)
ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
line = ansi_filter.sub("", line)

# Replace or strip sequence (at terminal width)
line = line.replace("\r\r\n", "\n")
if len(line) == 201:
line = line[:-1]

return line

def _flush_stderr(self, session_process):
"""read and return stderr with minimal blocking"""

Expand Down Expand Up @@ -996,7 +1037,7 @@ def close(self):
"""terminate the connection"""
if self._session_id:
self._vvv(f"CLOSING SSM CONNECTION TO: {self.instance_id}")
if self._timeout:
if self._has_timeout:
self._session.terminate()
else:
cmd = b"\nexit\n"
Expand Down
54 changes: 54 additions & 0 deletions tests/unit/plugins/connection/aws_ssm/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

# This file is part of Ansible
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from unittest.mock import MagicMock

import pytest

from ansible_collections.community.aws.plugins.connection.aws_ssm import Connection
from ansible_collections.community.aws.plugins.connection.aws_ssm import ConnectionBase


@pytest.fixture(name="connection_aws_ssm")
def fixture_connection_aws_ssm():
play_context = MagicMock()
play_context.shell = True

def connection_init(*args, **kwargs):
pass

Connection.__init__ = connection_init
ConnectionBase.exec_command = MagicMock()
connection = Connection()

connection._instance_id = "i-0a1b2c3d4e5f"
connection._polling_obj = None
connection._has_timeout = False
connection.is_windows = False

connection.poll_stdout = MagicMock()
connection._session = MagicMock()
connection._session.poll = MagicMock()
connection._session.poll.side_effect = lambda: None
connection._stdout = MagicMock()
connection._flush_stderr = MagicMock()

def display_msg(msg):
print("--- AWS SSM CONNECTION --- ", msg)

connection._v = MagicMock()
connection._v.side_effect = display_msg

connection._vv = MagicMock()
connection._vv.side_effect = display_msg

connection._vvv = MagicMock()
connection._vvv.side_effect = display_msg

connection._vvvv = MagicMock()
connection._vvvv.side_effect = display_msg

connection.get_option = MagicMock()
return connection
Loading
Loading