Issue #30: enforce signature V4 to get bucket region
Harry Zhang committed Aug 31, 2017
1 parent 917d637 commit 2a5cd0b
Showing 6 changed files with 362 additions and 41 deletions.
20 changes: 20 additions & 0 deletions .argo/test-yamls/aws-s3-regional-test.yaml
@@ -0,0 +1,20 @@
---
type: workflow
version: 1
name: argo-aws-s3-regional-test
description: Test S3 wrapper against a real AWS server in a randomly chosen region
inputs:
  parameters:
    COMMIT:
      default: "%%session.commit%%"
    REPO:
      default: "%%session.repo%%"
steps:
  - checkout:
      template: argo-checkout
  - aws-s3:
      template: argo-platform-unit-test-base
      arguments:
        artifacts.CODE: "%%steps.checkout.outputs.artifacts.CODE%%"
        parameters.COMMAND: "pytest -vv /src/platform/tests/aws_s3/"

57 changes: 16 additions & 41 deletions common/python/ax/cloud/aws/aws_s3.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
-# Copyright 2015-2016 Applatix, Inc. All rights reserved.
+# Copyright 2015-2017 Applatix, Inc. All rights reserved.
#

"""
@@ -16,6 +16,8 @@

import boto3
import requests
+from ax.cloud import Cloud
+from botocore.client import Config
from botocore.exceptions import ClientError
from retrying import retry

@@ -81,7 +83,8 @@ def __init__(self, bucket_name, aws_profile=None, region=None):
        :return:
        """
        self._name = bucket_name
-        self._region = region if region else self._get_bucket_region_from_aws(aws_profile)
+        self._aws_profile = aws_profile
+        self._region = region if region else self._get_bucket_region_from_aws()
        assert self._region, "Please make sure bucket {} is created, or provide a region name to create bucket".format(self._name)
        logger.info("Using region %s for bucket %s", self._region, self._name)

@@ -102,18 +105,16 @@ def get_bucket_name(self):
        wait_exponential_multiplier=1000,
        stop_max_attempt_number=3
    )
-    def _get_bucket_region_from_aws(self, profile):
-        """
-        Find out location of a bucket.
-        """
-        # There is actually no easy way to achieve this.
-        # Most APIs require region first.
-
-        # Step 1. Call head_bucket() to get location.
-        # Don't give up when there is error.
-        # It's likely response headers contain location info even return code is not 200.
-        s3 = boto3.Session(profile_name=profile).client("s3")
-        logger.debug("Looking for region for bucket %s from head_bucket.", self._name)
+    def _get_bucket_region_from_aws(self):
+        # We assume the cluster does not access any resource outside its partition,
+        # e.g. clusters in partition "aws" will not access resources in partition "aws-us-gov"
+        instance_region = Cloud().meta_data().get_region()
+        s3 = boto3.Session(
+            profile_name=self._aws_profile,
+            region_name=instance_region
+        ).client("s3", config=Config(signature_version='s3v4'))
+
+        logger.debug("Finding region for bucket %s with initial region %s", self._name, instance_region)
        try:
            response = s3.head_bucket(Bucket=self._name)
            logger.debug("Head_bucket returned OK %s", response)
@@ -125,33 +126,7 @@ def _get_bucket_region_from_aws(self, profile):
            headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
            region = headers.get("x-amz-bucket-region", headers.get("x-amz-region", None))
            logger.debug("Found region %s from head_bucket for %s, headers %s", region, self._name, headers)
-        if region is not None:
-            return region
-
-        # Step 2. In the unlikely event head_bucket fails, try to get it from get_bucket_location.
-        logger.debug("Looking for region for bucket %s from get_bucket_location.", self._name)
-        try:
-            return s3.get_bucket_location(Bucket=self._name)["LocationConstraint"]
-        except Exception as e:
-            if "NoSuchBucket" in str(e):
-                # Just in case someone deleted it.
-                return None
-
-        # Step 3. This is very similar to step 1. However we access API directly.
-        # We don't call this at first as it might cause slow down problem.
-        logger.debug("Looking for region for bucket %s from API endpoint directly.", self._name)
-        # According to https://github.com/aws/aws-sdk-go/issues/720#issuecomment-243891223
-        # performing a HEAD request and examine header is the best way to do so
-        head_bucket_url = "http://{bucket_name}.s3.amazonaws.com".format(bucket_name=self._name)
-        ret = requests.head(head_bucket_url, timeout=3)
-        try:
-            return ret.headers["x-amz-bucket-region"]
-        except KeyError:
-            logger.debug("Cannot get region from headers. Headers: %s. HTTP status code: %s", ret.headers, ret.status_code)
-            if ret.status_code == 404:
-                return None
-            else:
-                ret.raise_for_status()
+        return region

    def region(self):
        return self._region
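For reference, the technique this commit standardizes on: pin the client to Signature Version 4, call head_bucket(), and read the bucket's region out of the response headers, even when the call itself fails. A minimal standalone sketch (illustrative only, not part of the commit; assumes valid credentials and a placeholder bucket name):

import boto3
from botocore.client import Config
from botocore.exceptions import ClientError

def lookup_bucket_region(bucket_name, initial_region):
    # Regions launched after January 2014 (e.g. eu-central-1) accept only
    # Signature Version 4, hence the explicit signature_version pin.
    s3 = boto3.Session(region_name=initial_region).client(
        "s3", config=Config(signature_version="s3v4")
    )
    try:
        response = s3.head_bucket(Bucket=bucket_name)
    except ClientError as e:
        # A 301/403 error response usually still carries the bucket's real
        # region in its HTTP headers, so inspect the error response too.
        response = e.response
    headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
    return headers.get("x-amz-bucket-region")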
Empty file.
53 changes: 53 additions & 0 deletions platform/tests/aws_s3/mock.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2015-2017 Applatix, Inc. All rights reserved.
#

import random
import logging

from ax.cloud import Cloud
from ax.aws.meta_data import AWSMetaData
from .testdata import *

logger = logging.getLogger(__name__)


class AWSMetaDataMock(AWSMetaData):
    def __init__(self):
        pass

    def get_region(self):
        return random.choice(AWS_REGIONS)

    def get_security_groups(self):
        raise NotImplementedError()

    def get_zone(self):
        raise NotImplementedError()

    def get_public_ip(self):
        raise NotImplementedError()

    def get_instance_id(self):
        raise NotImplementedError()

    def get_instance_type(self):
        raise NotImplementedError()

    def get_private_ip(self):
        raise NotImplementedError()

    def get_user_data(self, attr=None, plain_text=False):
        raise NotImplementedError()


class CloudAWSMock(Cloud):
    def __init__(self):
        super(CloudAWSMock, self).__init__(target_cloud=self.CLOUD_AWS)
        self._own_cloud = self.AX_CLOUD_AWS

    def meta_data(self):
        return AWSMetaDataMock()
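
A note on ordering, since test_aws_s3.py below depends on it: aws_s3.py runs "from ax.cloud import Cloud" at import time, so the mock must be installed on the ax.cloud module before AXS3Bucket is imported, or the real Cloud class is already bound. A minimal sketch of the pattern (module paths as in this repo):

import ax.cloud
from .mock import CloudAWSMock

ax.cloud.Cloud = CloudAWSMock        # rebind the name first...
from ax.cloud.aws import AXS3Bucket  # ...then import the module that uses it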

210 changes: 210 additions & 0 deletions platform/tests/aws_s3/test_aws_s3.py
@@ -0,0 +1,210 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2015-2017 Applatix, Inc. All rights reserved.
#

import json
import pytest
import logging
import random
import requests
import time

import ax.cloud
from .mock import CloudAWSMock
ax.cloud.Cloud = CloudAWSMock

from ax.cloud.aws import AXS3Bucket, BUCKET_CLEAN_KEYWORD
from .testdata import *

logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(threadName)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S"
)
logging.getLogger("botocore").setLevel(logging.WARNING)
logging.getLogger("boto3").setLevel(logging.WARNING)
logging.getLogger("ax").setLevel(logging.DEBUG)
logging.getLogger("requests").setLevel(logging.WARNING)


def test_s3_get_region():
    for region in AWS_REGIONS:
        bucket_name = TEST_BUCKET_NAME_TEMPLATE.format(region=region)
        logger.info("Testing GetRegion for bucket %s", bucket_name)

        bucket = AXS3Bucket(bucket_name=bucket_name, aws_profile=TEST_AWS_PROFILE, region=region)
        bucket.create()

        assert bucket._get_bucket_region_from_aws() == region

        # Need to cool down a bit as bucket creation / deletion is a very heavyweight operation
        bucket.delete()
        time.sleep(5)


def test_put_policy_no_bucket():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE, region=TEST_AWS_REGION)
    assert not bucket.put_policy(policy="")


def test_delete_nonexist_bucket():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE, region=TEST_AWS_REGION)
    assert bucket.delete(force=True)


def test_bucket_create():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE, region=TEST_AWS_REGION)
    assert bucket.create()
    assert bucket.exists()
    assert bucket.empty()
    assert bucket.clean()

    # Recreate should return True
    assert bucket.create()


# From this test on, the bucket is already created, so we don't need to pass `region`
# to the class; this matches most of our use cases
def test_put_policy_invalid_format():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    assert not bucket.put_policy(policy=TEST_INVALID_POLICY_FORMAT)


def test_put_policy_invalid_content():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    assert not bucket.put_policy(policy=TEST_INVALID_POLICY_CONTENT)


def test_cors_operation():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)

    # Do the operations twice to ensure idempotency
    bucket.put_cors(TEST_CORS_CONFIG)
    bucket.put_cors(TEST_CORS_CONFIG)
    bucket.delete_cors()
    bucket.delete_cors()


def test_single_object_operations():
    file_name = "test_file"
    file_content = "test_content"
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    assert bucket.put_object(file_name, file_content)

    assert not bucket.clean()
    assert not bucket.empty()

    file_content_s3 = bucket.get_object(file_name)
    if isinstance(file_content_s3, bytes):
        assert file_content == file_content_s3.decode("utf-8")
    else:
        assert file_content == file_content_s3

    file_name_cpy = "test_file_copy"
    bucket.copy_object(file_name, file_name_cpy)

    file_content_s3_cpy = bucket.get_object(file_name_cpy)
    assert file_content_s3 == file_content_s3_cpy

    assert bucket.delete_object(file_name)
    assert bucket.get_object(file_name) is None

    assert bucket.delete_object(file_name_cpy)
    assert bucket.get_object(file_name_cpy) is None

    assert bucket.clean()
    assert bucket.empty()


def test_generate_object_url():
    file_name = "test_file_url"
    file_content = "test_content_url"
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.put_object(file_name, file_content, ACL="public-read")

    url = bucket.get_object_url_from_key(key=file_name)
    data = requests.get(url, timeout=5).text
    assert data == file_content

    assert bucket.delete_object(file_name)
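
get_object_url_from_key() is part of this project's wrapper. For comparison, plain boto3 can produce a time-limited link with generate_presigned_url, which embeds the signature and so avoids the public-read ACL used above (a sketch with placeholder bucket and region names):

import boto3
from botocore.client import Config

s3 = boto3.Session(region_name="us-west-2").client(
    "s3", config=Config(signature_version="s3v4")
)
url = s3.generate_presigned_url(
    "get_object",
    Params={"Bucket": "my-bucket", "Key": "test_file_url"},
    ExpiresIn=300,  # link validity in seconds
)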


def test_bucket_clean():
    file_name = BUCKET_CLEAN_KEYWORD + "{:05d}".format(random.randint(1, 99999))
    file_content = "test_content"
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    assert bucket.put_object(file_name, file_content)
    assert bucket.clean()

    if not bucket.delete_object(file_name):
        pytest.fail("Failed to delete object {}".format(file_name))


def test_list_objects():
    file_name_prefix = "test_file-"
    file_content = "test_content"
    file_name_set = set()
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    for i in range(50):
        file_name = file_name_prefix + "{:03d}".format(i)
        file_name_set.add(file_name)
        assert bucket.put_object(key=file_name, data=file_content)
    file_name_set_s3 = set(bucket.list_objects(keyword=file_name_prefix))
    assert file_name_set_s3 == file_name_set


def test_delete_all_without_prefix():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    file_name_prefix = "test_file-"
    file_content = "test_content"
    for i in range(50):
        file_name = file_name_prefix + "{:03d}".format(i)
        assert bucket.put_object(key=file_name, data=file_content)

    bucket.delete_all(use_prefix=False)

    assert bucket.clean()
    assert bucket.empty()


def test_delete_all_with_prefix_and_exemption():
    file_name_prefix = "test_file-"
    file_content = "test_content"
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    for i in range(50):
        file_name = file_name_prefix + "{:03d}".format(i)
        assert bucket.put_object(key=file_name, data=file_content)

    bucket.put_object(key="special_file", data=file_content)
    bucket.delete_all(obj_prefix=file_name_prefix, exempt=["test_file-015"])

    remaining_file_s3 = set(bucket.list_objects(list_all=True))
    remaining_file = {"special_file", "test_file-015"}

    assert remaining_file_s3 == remaining_file

    bucket.delete_all(use_prefix=False)
    assert bucket.clean()
    assert bucket.empty()


def test_bucket_delete():
    bucket = AXS3Bucket(bucket_name=TEST_BUCKET_NAME, aws_profile=TEST_AWS_PROFILE)
    assert bucket.create()
    assert bucket.delete(force=True)

    assert bucket.clean()
    assert bucket.empty()

    # Re-delete should return True
    assert bucket.delete(force=True)

