Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

connection/aws_ssm - create S3clientmanager class and move related methods #2255

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
09da93e
create class S3ClientManager, move _get_bucket_endpoint's functionali…
mandar242 Feb 13, 2025
f2880a0
move _get_boto_client's functionality to S3ClientManager class
mandar242 Feb 26, 2025
44c18de
move _get_url's functionality to S3ClientManager class
mandar242 Feb 27, 2025
4e335d8
move _generate_encryption_settings's functionality to S3ClientManager…
mandar242 Feb 27, 2025
fb77ca1
add type hints
mandar242 Feb 27, 2025
af0ca82
minor fix
mandar242 Feb 27, 2025
3b5fb25
reorder client initialization to ensure s3_manager is available befor…
mandar242 Feb 27, 2025
6500337
minor fix
mandar242 Feb 27, 2025
3e19397
move s3 client initialization to S3ClientManager class
mandar242 Feb 27, 2025
d7acf92
copy s3 client from S3ClientManager class to Connection class
mandar242 Feb 27, 2025
eddbc1f
adjust unit tests as s3 client is initialized in S3ClientManager class
mandar242 Feb 27, 2025
058619b
add changlog fragment
mandar242 Feb 27, 2025
8022a6a
black and isort linter fixes
mandar242 Feb 27, 2025
f4f640a
isort linter fix
mandar242 Feb 27, 2025
029846d
add unit test for s3clientmanager.initialize_s3_client
mandar242 Feb 28, 2025
1a66b02
add unit tests for s3clientmanager.get_url
mandar242 Feb 28, 2025
21c4870
add unit tests for s3clientmanager.generate_encryption_settings
mandar242 Feb 28, 2025
7b9bfac
add unit tests for s3clientmanager.get_boto_client
mandar242 Mar 1, 2025
4354b31
add unit tests for s3clientmanager.get_bucket_endpoint
mandar242 Mar 1, 2025
b28c13c
merge conflict rebase cleanup
mandar242 Mar 5, 2025
0af6b2e
move S3ClientManager to its own file for modularity
mandar242 Mar 5, 2025
31f899f
mock S3ClientManager class in test_generate_commands
mandar242 Mar 5, 2025
d82f4f8
change initialize_s3_client to initialize_client
mandar242 Mar 6, 2025
80acd5e
minor fix
mandar242 Mar 6, 2025
dcd20bb
remove unused client definition
mandar242 Mar 6, 2025
99c21b7
remove _get_bucket_endpoint in favor of s3clientmanager.get_bucket_en…
mandar242 Mar 6, 2025
7e02ccb
Refactor S3ClientManager to explicitly handle S3 client creation
mandar242 Mar 6, 2025
fa9596c
add back original comments from get_bucket_endpoint
mandar242 Mar 6, 2025
b6340ef
remove _generate_encryption_settings, _get_url, _initialize_s3_client…
mandar242 Mar 6, 2025
c24801a
update tests as _initialize_s3_client does not exist anymore and func…
mandar242 Mar 6, 2025
fde7a6c
sanity and linter fixes
mandar242 Mar 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Refactor connection/aws_ssm to add new S3ClientManager class and move relevant methods to the new class (https://github.com/ansible-collections/community.aws/pull/2255).
145 changes: 37 additions & 108 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,10 @@
import string
import subprocess
import time
from functools import wraps
from typing import Any
from typing import Dict
from typing import Iterator
from typing import List
from typing import NoReturn
from typing import Optional
Expand All @@ -345,8 +348,6 @@
except ImportError:
pass

from functools import wraps

from ansible.errors import AnsibleConnectionFailure
from ansible.errors import AnsibleError
from ansible.errors import AnsibleFileNotFound
Expand All @@ -360,10 +361,12 @@

from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3

from ansible_collections.community.aws.plugins.module_utils.s3clientmanager import S3ClientManager
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if you move this to plugin_utils like for amazon.aws collection?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imo for now module_utils seem a reasonable place, as we will be adding more classes like SSMSessionManager and FileTransferManager with other jira stories.
But I'm open to either module_utils or creating new dir plugin_utils.
@abikouo @alinabuzachis

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, I had to put the amazon.aws pieces into plugin_utils rather than module_utils, because the linter will raise an error if you import from something that won't be sent to a remote host.

Personally I would recommend moving this into plugin_utils.

Over in amazon.aws.plugins.plugin_utils we also have the initial framework for an "AWS" connection plugin (plugin_utils.connection.AWSConnectionBase), which might simplify some of this code further.


display = Display()


def _ssm_retry(func):
def _ssm_retry(func: Any) -> Any:
"""
Decorator to retry in the case of a connection failure
Will retry if:
Expand All @@ -374,7 +377,7 @@ def _ssm_retry(func):
"""

@wraps(func)
def wrapped(self, *args, **kwargs):
def wrapped(self, *args: Any, **kwargs: Any) -> Any:
remaining_tries = int(self.get_option("reconnection_retries")) + 1
cmd_summary = f"{args[0]}..."
for attempt in range(remaining_tries):
Expand Down Expand Up @@ -413,7 +416,7 @@ def wrapped(self, *args, **kwargs):
return wrapped


def chunks(lst, n):
def chunks(lst: List, n: int) -> Iterator[List[Any]]:
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n] # fmt: skip
Expand Down Expand Up @@ -471,7 +474,7 @@ class Connection(ConnectionBase):
_timeout = False
MARK_LENGTH = 26

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

if not HAS_BOTO3:
Expand All @@ -492,12 +495,11 @@ def __init__(self, *args, **kwargs):
self._shell_type = "powershell"
self.is_windows = True

def __del__(self):
def __del__(self) -> None:
self.close()

def _connect(self):
def _connect(self) -> Any:
"""connect to the host via ssm"""

self._play_context.remote_user = getpass.getuser()

if not self._session_id:
Expand All @@ -509,16 +511,23 @@ def _init_clients(self) -> None:
Initializes required AWS clients (SSM and S3).
Delegates client initialization to specialized methods.
"""

self._vvvv("INITIALIZE BOTO3 CLIENTS")
profile_name = self.get_option("profile") or ""
region_name = self.get_option("region")

# Initialize SSM client
self._initialize_ssm_client(region_name, profile_name)
# Initialize S3ClientManager
self.s3_manager = S3ClientManager(self)

# Initialize S3 client
self._initialize_s3_client(profile_name)
s3_endpoint_url, s3_region_name = self.s3_manager.get_bucket_endpoint()
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
self.s3_manager.initialize_client(
region_name=s3_region_name, endpoint_url=s3_endpoint_url, profile_name=profile_name
)
self._s3_client = self.s3_manager._s3_client

# Initialize SSM client
self._initialize_ssm_client(region_name, profile_name)

def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str) -> None:
"""
Expand All @@ -538,84 +547,26 @@ def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str)
profile_name=profile_name,
)

def _initialize_s3_client(self, profile_name: str) -> None:
"""
Initializes the S3 client used for accessing S3 buckets.

Args:
profile_name (str): AWS profile name for authentication.

Returns:
None
"""

s3_endpoint_url, s3_region_name = self._get_bucket_endpoint()
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
self._s3_client = self._get_boto_client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
profile_name=profile_name,
)

def _display(self, f, message):
def _display(self, f: Any, message: str) -> None:
if self.host:
host_args = {"host": self.host}
else:
host_args = {}
f(to_text(message), **host_args)

def _v(self, message):
def _v(self, message: str) -> None:
self._display(display.v, message)

def _vv(self, message):
def _vv(self, message: str) -> None:
self._display(display.vv, message)

def _vvv(self, message):
def _vvv(self, message: str) -> None:
self._display(display.vvv, message)

def _vvvv(self, message):
def _vvvv(self, message: str) -> None:
self._display(display.vvvv, message)

def _get_bucket_endpoint(self):
"""
Fetches the correct S3 endpoint and region for use with our bucket.
If we don't explicitly set the endpoint then some commands will use the global
endpoint and fail
(new AWS regions and new buckets in a region other than the one we're running in)
"""

region_name = self.get_option("region") or "us-east-1"
profile_name = self.get_option("profile") or ""
self._vvvv("_get_bucket_endpoint: S3 (global)")
tmp_s3_client = self._get_boto_client(
"s3",
region_name=region_name,
profile_name=profile_name,
)
# Fetch the location of the bucket so we can open a client against the 'right' endpoint
# This /should/ always work
head_bucket = tmp_s3_client.head_bucket(
Bucket=(self.get_option("bucket_name")),
)
bucket_region = head_bucket.get("ResponseMetadata", {}).get("HTTPHeaders", {}).get("x-amz-bucket-region", None)
if bucket_region is None:
bucket_region = "us-east-1"

if self.get_option("bucket_endpoint_url"):
return self.get_option("bucket_endpoint_url"), bucket_region

# Create another client for the region the bucket lives in, so we can nab the endpoint URL
self._vvvv(f"_get_bucket_endpoint: S3 (bucket region) - {bucket_region}")
s3_bucket_client = self._get_boto_client(
"s3",
region_name=bucket_region,
profile_name=profile_name,
)

return s3_bucket_client.meta.endpoint_url, s3_bucket_client.meta.region_name

def reset(self):
def reset(self) -> Any:
"""start a fresh ssm session"""
self._vvvv("reset called on ssm connection")
self.close()
Expand Down Expand Up @@ -885,7 +836,7 @@ def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
self._vvvv(f"_wrap_command: \n'{to_text(cmd)}'")
return cmd

def _post_process(self, stdout, mark_begin):
def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]:
"""extract command status and strip unwanted lines"""

if not self.is_windows:
Expand Down Expand Up @@ -919,7 +870,7 @@ def _post_process(self, stdout, mark_begin):

return (returncode, stdout)

def _flush_stderr(self, session_process):
def _flush_stderr(self, session_process) -> str:
"""read and return stderr with minimal blocking"""

poll_stderr = select.poll()
Expand All @@ -935,15 +886,6 @@ def _flush_stderr(self, session_process):

return stderr

def _get_url(self, client_method, bucket_name, out_path, http_method, extra_args=None):
"""Generate URL for get_object / put_object"""

client = self._s3_client
params = {"Bucket": bucket_name, "Key": out_path}
if extra_args is not None:
params.update(extra_args)
return client.generate_presigned_url(client_method, Params=params, ExpiresIn=3600, HttpMethod=http_method)

def _get_boto_client(self, service, region_name=None, profile_name=None, endpoint_url=None):
"""Gets a boto3 client based on the STS token"""

Expand Down Expand Up @@ -971,22 +913,9 @@ def _get_boto_client(self, service, region_name=None, profile_name=None, endpoin
)
return client

def _escape_path(self, path):
def _escape_path(self, path: str) -> str:
return path.replace("\\", "/")

def _generate_encryption_settings(self):
put_args = {}
put_headers = {}
if not self.get_option("bucket_sse_mode"):
return put_args, put_headers

put_args["ServerSideEncryption"] = self.get_option("bucket_sse_mode")
put_headers["x-amz-server-side-encryption"] = self.get_option("bucket_sse_mode")
if self.get_option("bucket_sse_mode") == "aws:kms" and self.get_option("bucket_sse_kms_key_id"):
put_args["SSEKMSKeyId"] = self.get_option("bucket_sse_kms_key_id")
put_headers["x-amz-server-side-encryption-aws-kms-key-id"] = self.get_option("bucket_sse_kms_key_id")
return put_args, put_headers

def _generate_commands(
self,
bucket_name: str,
Expand All @@ -1006,11 +935,11 @@ def _generate_commands(
:returns: A tuple containing a list of command dictionaries along with any ``put_args`` dictionaries.
"""

put_args, put_headers = self._generate_encryption_settings()
put_args, put_headers = self.s3_manager.generate_encryption_settings()
commands = []

put_url = self._get_url("put_object", bucket_name, s3_path, "PUT", extra_args=put_args)
get_url = self._get_url("get_object", bucket_name, s3_path, "GET")
put_url = self.s3_manager.get_url("put_object", bucket_name, s3_path, "PUT", extra_args=put_args)
get_url = self.s3_manager.get_url("get_object", bucket_name, s3_path, "GET")

if self.is_windows:
put_command_headers = "; ".join([f"'{h}' = '{v}'" for h, v in put_headers.items()])
Expand Down Expand Up @@ -1150,7 +1079,7 @@ def _file_transport_command(
# Remove the files from the bucket after they've been transferred
client.delete_object(Bucket=bucket_name, Key=s3_path)

def put_file(self, in_path, out_path):
def put_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
"""transfer a file from local to remote"""

super().put_file(in_path, out_path)
Expand All @@ -1161,15 +1090,15 @@ def put_file(self, in_path, out_path):

return self._file_transport_command(in_path, out_path, "put")

def fetch_file(self, in_path, out_path):
def fetch_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
"""fetch a file from remote to local"""

super().fetch_file(in_path, out_path)

self._vvv(f"FETCH {in_path} TO {out_path}")
return self._file_transport_command(in_path, out_path, "get")

def close(self):
def close(self) -> None:
"""terminate the connection"""
if self._session_id:
self._vvv(f"CLOSING SSM CONNECTION TO: {self.instance_id}")
Expand Down
Loading
Loading