diff --git a/changelogs/fragments/2096-refactor-random-helper-function.yml b/changelogs/fragments/2096-refactor-random-helper-function.yml new file mode 100644 index 00000000000..ab479c0640a --- /dev/null +++ b/changelogs/fragments/2096-refactor-random-helper-function.yml @@ -0,0 +1,2 @@ +minor_changes: + - ssm - add function to generate random strings for SSM CLI delimitation (https://github.com/ansible-collections/community.aws/pull/2235). diff --git a/plugins/connection/aws_ssm.py b/plugins/connection/aws_ssm.py index e50c43e495d..af9133ac96e 100644 --- a/plugins/connection/aws_ssm.py +++ b/plugins/connection/aws_ssm.py @@ -704,20 +704,26 @@ def exec_communicate(self, cmd: str, mark_start: str, mark_begin: str, mark_end: # see https://github.com/pylint-dev/pylint/issues/8909) return (returncode, stdout, self._flush_stderr(self._session)) # pylint: disable=unreachable + @staticmethod + def generate_mark() -> str: + """Generates a random string of characters to delimit SSM CLI commands""" + mark = "".join([random.choice(string.ascii_letters) for i in range(Connection.MARK_LENGTH)]) + return mark + @_ssm_retry def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> Tuple[int, str, str]: - """run a command on the ssm host""" + """When running a command on the SSM host, uses generate_mark to get delimiting strings""" super().exec_command(cmd, in_data=in_data, sudoable=sudoable) self._vvv(f"EXEC: {to_text(cmd)}") - mark_begin = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)]) + mark_begin = self.generate_mark() if self.is_windows: mark_start = mark_begin + " $LASTEXITCODE" else: mark_start = mark_begin - mark_end = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)]) + mark_end = self.generate_mark() # Wrap command in markers accordingly for the shell used cmd = self._wrap_command(cmd, mark_start, mark_end) @@ -745,7 +751,7 @@ def _prepare_terminal(self): disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict") disable_prompt_complete = None - end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)]) + end_mark = self.generate_mark() disable_prompt_cmd = to_bytes( "PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n", errors="surrogate_or_strict", diff --git a/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py b/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py index f6e4480277e..c2d7cba3fbe 100644 --- a/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py +++ b/tests/unit/plugins/connection/aws_ssm/test_aws_ssm.py @@ -11,6 +11,8 @@ from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 +from ansible_collections.community.aws.plugins.connection.aws_ssm import Connection + if not HAS_BOTO3: pytestmark = pytest.mark.skip("test_data_pipeline.py requires the python modules 'boto3' and 'botocore'") @@ -257,3 +259,12 @@ def test_plugins_connection_aws_ssm_close(self, s_check_output): conn._session_id.return_value = "a" conn._client = MagicMock() conn.close() + + def test_generate_mark(self): + """Testing string generation""" + test_a = Connection.generate_mark() + test_b = Connection.generate_mark() + + assert test_a != test_b + assert len(test_a) == Connection.MARK_LENGTH + assert len(test_b) == Connection.MARK_LENGTH