Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added S3_ENDPOINT variable #3368

Merged
merged 1 commit into from
Mar 16, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions tensorboard/compat/tensorflow_stub/io/gfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ class S3FileSystem(object):
def __init__(self):
    """Initialize the S3 filesystem wrapper.

    Reads the optional ``S3_ENDPOINT`` environment variable so that all
    subsequent boto3 clients/resources can target a custom S3-compatible
    endpoint (e.g. MinIO) instead of the default AWS endpoint.

    Raises:
        ImportError: if boto3 is not installed.
    """
    if not boto3:
        raise ImportError("boto3 must be installed for S3 support.")
    # os.environ.get already returns None when the key is absent, so the
    # explicit None default is redundant. None makes boto3 fall back to
    # its default AWS endpoint resolution.
    self._s3_endpoint = os.environ.get("S3_ENDPOINT")

def bucket_and_path(self, url):
"""Split an S3-prefixed URL into bucket and path."""
Expand All @@ -233,7 +234,7 @@ def bucket_and_path(self, url):

def exists(self, filename):
"""Determines whether a path exists or not."""
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(filename)
r = client.list_objects(Bucket=bucket, Prefix=path, Delimiter="/")
if r.get("Contents") or r.get("CommonPrefixes"):
Expand Down Expand Up @@ -264,7 +265,7 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None):
is an opaque value that can be passed to the next invocation of
`read(...)` in order to continue from the last read position.
"""
s3 = boto3.resource("s3")
s3 = boto3.resource("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(filename)
args = {}

Expand Down Expand Up @@ -292,7 +293,7 @@ def read(self, filename, binary_mode=False, size=None, continue_from=None):
if size is not None:
# Asked for too much, so request just to the end. Do this
# in a second request so we don't check length in all cases.
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
obj = client.head_object(Bucket=bucket, Key=path)
content_length = obj["ContentLength"]
endpoint = min(content_length, offset + size)
Expand Down Expand Up @@ -321,7 +322,7 @@ def write(self, filename, file_content, binary_mode=False):
file_content: string, the contents
binary_mode: bool, write as binary if True, otherwise text
"""
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(filename)
# Always convert to bytes for writing
if binary_mode:
Expand All @@ -348,7 +349,7 @@ def glob(self, filename):
# filesystems in some way.
return []
filename = filename[:-1]
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(filename)
p = client.get_paginator("list_objects")
keys = []
Expand All @@ -361,7 +362,7 @@ def glob(self, filename):

def isdir(self, dirname):
"""Returns whether the path is a directory or not."""
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(dirname)
if not path.endswith("/"):
path += "/" # This will now only retrieve subdir content
Expand All @@ -372,7 +373,7 @@ def isdir(self, dirname):

def listdir(self, dirname):
"""Returns a list of entries contained within a directory."""
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(dirname)
p = client.get_paginator("list_objects")
if not path.endswith("/"):
Expand All @@ -394,7 +395,7 @@ def makedirs(self, dirname):
raise errors.AlreadyExistsError(
None, None, "Directory already exists"
)
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(dirname)
if not path.endswith("/"):
path += "/" # This will make sure we don't override a file
Expand All @@ -404,7 +405,7 @@ def stat(self, filename):
"""Returns file statistics for a given path."""
# NOTE: Size of the file is given by ContentLength from S3,
# but we convert to .length
client = boto3.client("s3")
client = boto3.client("s3", endpoint_url=self._s3_endpoint)
bucket, path = self.bucket_and_path(filename)
try:
obj = client.head_object(Bucket=bucket, Key=path)
Expand Down