Skip to content

Commit

Permalink
Pass 'user_project' if set for blob downloads w/ 'mediaLink' set (#3500)
Browse files Browse the repository at this point in the history
  • Loading branch information
tseaver authored Jun 26, 2017
1 parent cbf073b commit dbdc6a6
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 11 deletions.
53 changes: 43 additions & 10 deletions storage/google/cloud/storage/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
import warnings

import httplib2
from six.moves.urllib.parse import parse_qsl
from six.moves.urllib.parse import quote
from six.moves.urllib.parse import urlencode
from six.moves.urllib.parse import urlsplit
from six.moves.urllib.parse import urlunsplit

import google.auth.transport.requests
from google import resumable_media
Expand Down Expand Up @@ -403,15 +407,19 @@ def _get_download_url(self):
:rtype: str
:returns: The download URL for the current blob.
"""
name_value_pairs = []
if self.media_link is None:
download_url = _DOWNLOAD_URL_TEMPLATE.format(path=self.path)
base_url = _DOWNLOAD_URL_TEMPLATE.format(path=self.path)
if self.generation is not None:
download_url += u'&generation={:d}'.format(self.generation)
if self.user_project is not None:
download_url += u'&userProject={}'.format(self.user_project)
return download_url
name_value_pairs.append(
('generation', '{:d}'.format(self.generation)))
else:
return self.media_link
base_url = self.media_link

if self.user_project is not None:
name_value_pairs.append(('userProject', self.user_project))

return _add_query_parameters(base_url, name_value_pairs)

def _do_download(self, transport, file_obj, download_url, headers):
"""Perform a download without any error handling.
Expand Down Expand Up @@ -658,12 +666,14 @@ def _do_multipart_upload(self, client, stream, content_type,
info = self._get_upload_arguments(content_type)
headers, object_metadata, content_type = info

upload_url = _MULTIPART_URL_TEMPLATE.format(
base_url = _MULTIPART_URL_TEMPLATE.format(
bucket_path=self.bucket.path)
name_value_pairs = []

if self.user_project is not None:
upload_url += '&userProject={}'.format(self.user_project)
name_value_pairs.append(('userProject', self.user_project))

upload_url = _add_query_parameters(base_url, name_value_pairs)
upload = MultipartUpload(upload_url, headers=headers)

if num_retries is not None:
Expand Down Expand Up @@ -734,12 +744,14 @@ def _initiate_resumable_upload(self, client, stream, content_type,
if extra_headers is not None:
headers.update(extra_headers)

upload_url = _RESUMABLE_URL_TEMPLATE.format(
base_url = _RESUMABLE_URL_TEMPLATE.format(
bucket_path=self.bucket.path)
name_value_pairs = []

if self.user_project is not None:
upload_url += '&userProject={}'.format(self.user_project)
name_value_pairs.append(('userProject', self.user_project))

upload_url = _add_query_parameters(base_url, name_value_pairs)
upload = ResumableUpload(upload_url, chunk_size, headers=headers)

if num_retries is not None:
Expand Down Expand Up @@ -1676,3 +1688,24 @@ def _raise_from_invalid_response(error, error_info=None):
faux_response = httplib2.Response({'status': response.status_code})
raise make_exception(faux_response, response.content,
error_info=error_info, use_json=False)


def _add_query_parameters(base_url, name_value_pairs):
"""Add one query parameter to a base URL.
:type base_url: string
:param base_url: Base URL (may already contain query parameters)
:type name_value_pairs: list of (string, string) tuples.
:param name_value_pairs: Names and values of the query parameters to add
:rtype: string
:returns: URL with additional query strings appended.
"""
if len(name_value_pairs) == 0:
return base_url

scheme, netloc, path, query, frag = urlsplit(base_url)
query = parse_qsl(query)
query.extend(name_value_pairs)
return urlunsplit((scheme, netloc, path, urlencode(query), frag))
46 changes: 45 additions & 1 deletion storage/tests/unit/test_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def test__make_transport(self, fake_session_factory):

def test__get_download_url_with_media_link(self):
blob_name = 'something.txt'
bucket = mock.Mock(spec=[])
bucket = _Bucket(name='IRRELEVANT')
blob = self._make_one(blob_name, bucket=bucket)
media_link = 'http://test.invalid'
# Set the media link on the blob
Expand All @@ -375,6 +375,19 @@ def test__get_download_url_with_media_link(self):
download_url = blob._get_download_url()
self.assertEqual(download_url, media_link)

def test__get_download_url_with_media_link_w_user_project(self):
blob_name = 'something.txt'
user_project = 'user-project-123'
bucket = _Bucket(name='IRRELEVANT', user_project=user_project)
blob = self._make_one(blob_name, bucket=bucket)
media_link = 'http://test.invalid'
# Set the media link on the blob
blob._properties['mediaLink'] = media_link

download_url = blob._get_download_url()
self.assertEqual(
download_url, '{}?userProject={}'.format(media_link, user_project))

def test__get_download_url_on_the_fly(self):
blob_name = 'bzzz-fly.txt'
bucket = _Bucket(name='buhkit')
Expand Down Expand Up @@ -2430,6 +2443,37 @@ def test_with_error_info(self):
self.assertEqual(exc_info.exception.errors, [])


class Test__add_query_parameters(unittest.TestCase):

@staticmethod
def _call_fut(*args, **kwargs):
from google.cloud.storage.blob import _add_query_parameters

return _add_query_parameters(*args, **kwargs)

def test_w_empty_list(self):
BASE_URL = 'https://test.example.com/base'
self.assertEqual(self._call_fut(BASE_URL, []), BASE_URL)

def test_wo_existing_qs(self):
BASE_URL = 'https://test.example.com/base'
NV_LIST = [('one', 'One'), ('two', 'Two')]
expected = '&'.join([
'{}={}'.format(name, value) for name, value in NV_LIST])
self.assertEqual(
self._call_fut(BASE_URL, NV_LIST),
'{}?{}'.format(BASE_URL, expected))

def test_w_existing_qs(self):
BASE_URL = 'https://test.example.com/base?one=Three'
NV_LIST = [('one', 'One'), ('two', 'Two')]
expected = '&'.join([
'{}={}'.format(name, value) for name, value in NV_LIST])
self.assertEqual(
self._call_fut(BASE_URL, NV_LIST),
'{}&{}'.format(BASE_URL, expected))


class _Connection(object):

API_BASE_URL = 'http://example.com'
Expand Down

0 comments on commit dbdc6a6

Please sign in to comment.