Skip to content

Commit

Permalink
aws_ssm: Refactor _init_clients Method (#2223) (#2230)
Browse files Browse the repository at this point in the history
This is a backport of PR #2223 as merged into main (56b0886).
SUMMARY


Refer: https://issues.redhat.com/browse/ACA-2092
This PR Refactors the _init_clients method
ISSUE TYPE


Bugfix Pull Request
Docs Pull Request
Feature Pull Request
New Module Pull Request

COMPONENT NAME

ADDITIONAL INFORMATION

Reviewed-by: Bikouo Aubin
  • Loading branch information
patchback[bot] authored Feb 4, 2025
1 parent 4a98551 commit aa4c8fc
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 59 deletions.
2 changes: 2 additions & 0 deletions changelogs/fragments/refactor_ssm_init_client.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
minor_changes:
- aws_ssm - Refactor _init_clients Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/2223).
151 changes: 92 additions & 59 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# Based on the ssh connection plugin by Michael DeHaan


DOCUMENTATION = r"""
name: aws_ssm
author:
Expand Down Expand Up @@ -284,7 +285,6 @@
name: nginx
state: present
"""

import os
import getpass
import json
Expand All @@ -295,6 +295,7 @@
import string
import subprocess
import time
from typing import Optional

try:
import boto3
Expand Down Expand Up @@ -347,7 +348,10 @@ def wrapped(self, *args, **kwargs):
if isinstance(e, AnsibleConnectionFailure):
msg = f"ssm_retry: attempt: {attempt}, cmd ({cmd_summary}), pausing for {pause} seconds"
else:
msg = f"ssm_retry: attempt: {attempt}, caught exception({e}) from cmd ({cmd_summary}), pausing for {pause} seconds"
msg = (
f"ssm_retry: attempt: {attempt}, caught exception({e})"
f"from cmd ({cmd_summary}),pausing for {pause} seconds"
)

self._vv(msg)

Expand Down Expand Up @@ -390,6 +394,90 @@ class Connection(ConnectionBase):
_timeout = False
MARK_LENGTH = 26

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if not HAS_BOTO3:
raise AnsibleError(missing_required_lib("boto3"))

self.host = self._play_context.remote_addr

if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
self.delegate = None
self.has_native_async = True
self.always_pipeline_modules = True
self.module_implementation_preferences = (".ps1", ".exe", "")
self.protocol = None
self.shell_id = None
self._shell_type = "powershell"
self.is_windows = True

def __del__(self):
self.close()

def _connect(self):
"""connect to the host via ssm"""

self._play_context.remote_user = getpass.getuser()

if not self._session_id:
self.start_session()
return self

def _init_clients(self) -> None:
"""
Initializes required AWS clients (SSM and S3).
Delegates client initialization to specialized methods.
"""

self._vvvv("INITIALIZE BOTO3 CLIENTS")
profile_name = self.get_option("profile") or ""
region_name = self.get_option("region")

# Initialize SSM client
self._initialize_ssm_client(region_name, profile_name)

# Initialize S3 client
self._initialize_s3_client(profile_name)

def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str) -> None:
"""
Initializes the SSM client used to manage sessions.
Args:
region_name (Optional[str]): AWS region for the SSM client.
profile_name (str): AWS profile name for authentication.
Returns:
None
"""

self._vvvv("SETUP BOTO3 CLIENTS: SSM")
self._client = self._get_boto_client(
"ssm",
region_name=region_name,
profile_name=profile_name,
)

def _initialize_s3_client(self, profile_name: str) -> None:
"""
Initializes the S3 client used for accessing S3 buckets.
Args:
profile_name (str): AWS profile name for authentication.
Returns:
None
"""

s3_endpoint_url, s3_region_name = self._get_bucket_endpoint()
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
self._s3_client = self._get_boto_client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
profile_name=profile_name,
)

def _display(self, f, message):
if self.host:
host_args = {"host": self.host}
Expand Down Expand Up @@ -448,62 +536,6 @@ def _get_bucket_endpoint(self):

return s3_bucket_client.meta.endpoint_url, s3_bucket_client.meta.region_name

def _init_clients(self):
self._vvvv("INITIALIZE BOTO3 CLIENTS")
profile_name = self.get_option("profile") or ""
region_name = self.get_option("region")

# The SSM Boto client, currently used to initiate and manage the session
# Note: does not handle the actual SSM session traffic
self._vvvv("SETUP BOTO3 CLIENTS: SSM")
ssm_client = self._get_boto_client(
"ssm",
region_name=region_name,
profile_name=profile_name,
)
self._client = ssm_client

s3_endpoint_url, s3_region_name = self._get_bucket_endpoint()
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
s3_bucket_client = self._get_boto_client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
profile_name=profile_name,
)

self._s3_client = s3_bucket_client

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if not HAS_BOTO3:
raise AnsibleError(missing_required_lib("boto3"))

self.host = self._play_context.remote_addr

if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
self.delegate = None
self.has_native_async = True
self.always_pipeline_modules = True
self.module_implementation_preferences = (".ps1", ".exe", "")
self.protocol = None
self.shell_id = None
self._shell_type = "powershell"
self.is_windows = True

def __del__(self):
self.close()

def _connect(self):
"""connect to the host via ssm"""

self._play_context.remote_user = getpass.getuser()

if not self._session_id:
self.start_session()
return self

def reset(self):
"""start a fresh ssm session"""
self._vvvv("reset called on ssm connection")
Expand Down Expand Up @@ -854,7 +886,8 @@ def _generate_commands(self, bucket_name, s3_path, in_path, out_path):
put_commands = [
(
"Invoke-WebRequest -Method PUT "
f"-Headers @{{{put_command_headers}}} " # @{'key' = 'value'; 'key2' = 'value2'}
# @{'key' = 'value'; 'key2' = 'value2'}
f"-Headers @{{{put_command_headers}}} "
f"-InFile '{in_path}' "
f"-Uri '{put_url}' "
f"-UseBasicParsing"
Expand Down
79 changes: 79 additions & 0 deletions tests/unit/plugins/connection/test_aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,85 @@


class TestConnectionBaseClass:
def test_init_clients(self):
pc = PlayContext()
new_stdin = StringIO()
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)

# Mock get_option to return expected region and profile
def mock_get_option(key):
options = {
"profile": "test-profile",
"region": "us-east-1",
}
return options.get(key, None)

conn.get_option = MagicMock(side_effect=mock_get_option)

# Mock the _initialize_ssm_client and _initialize_s3_client methods
conn._initialize_ssm_client = MagicMock()
conn._initialize_s3_client = MagicMock()

conn._init_clients()

conn._initialize_ssm_client.assert_called_once_with("us-east-1", "test-profile")
conn._initialize_s3_client.assert_called_once_with("test-profile")

@patch("boto3.client")
def test_initialize_ssm_client(self, mock_boto3_client):
"""
Test for the _initialize_ssm_client method to ensure the SSM client is initialized correctly.
"""
pc = PlayContext()
new_stdin = StringIO()
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)

test_region_name = "us-west-2"
test_profile_name = "test-profile"

# Mock the _get_boto_client method to return a mock client
conn._get_boto_client = MagicMock(return_value=mock_boto3_client)

conn._initialize_ssm_client(test_region_name, test_profile_name)

conn._get_boto_client.assert_called_once_with(
"ssm",
region_name=test_region_name,
profile_name=test_profile_name,
)

assert conn._client is mock_boto3_client

@patch("boto3.client")
def test_initialize_s3_client(self, mock_boto3_client):
"""
Test for the _initialize_s3_client method to ensure the S3 client is initialized correctly.
"""

pc = PlayContext()
new_stdin = StringIO()
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)

test_profile_name = "test-profile"

# Mock the _get_bucket_endpoint method to return dummy values
conn._get_bucket_endpoint = MagicMock(return_value=("http://example.com", "us-west-2"))

conn._get_boto_client = MagicMock(return_value=mock_boto3_client)

conn._initialize_s3_client(test_profile_name)

conn._get_bucket_endpoint.assert_called_once()

conn._get_boto_client.assert_called_once_with(
"s3",
region_name="us-west-2",
endpoint_url="http://example.com",
profile_name=test_profile_name,
)

assert conn._s3_client is mock_boto3_client

@patch("os.path.exists")
@patch("subprocess.Popen")
@patch("select.poll")
Expand Down

0 comments on commit aa4c8fc

Please sign in to comment.