Skip to content

Commit

Permalink
update test, multithreading is not evenly distributed
Browse files Browse the repository at this point in the history
  • Loading branch information
betolink committed Jan 18, 2025
1 parent fee6958 commit 5140afa
Showing 1 changed file with 67 additions and 62 deletions.
129 changes: 67 additions & 62 deletions tests/unit/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,74 +134,79 @@ def test_store_can_create_s3_fsspec_session(self):
return None

@responses.activate
def test_session_cloning_and_file_download(self):
# Mock URLs and their responses
def test_session_reuses_token_download(self):
mock_creds = {
"accessKeyId": "sure",
"secretAccessKey": "correct",
"sessionToken": "whynot",
}
urls = [f"https://example.com/file{i}" for i in range(1, 11)]
for i, url in enumerate(urls):
responses.add(
responses.GET, url, body=f"Content of file {i + 1}", status=200
)

# Mock authentication and store setup
mock_auth = MagicMock()
mock_auth.authenticated = True
mock_auth.system.edl_hostname = "urs.earthdata.nasa.gov" # Mock hostname
responses.add(
responses.GET,
"https://urs.earthdata.nasa.gov/profile",
json=mock_creds,
status=200,
)
test_cases = [
(2, 500), # 2 threads, 500 files
(4, 400), # 4 threads, 400 files
(8, 5000), # 8 threads, 5k files
]
for n_threads, n_files in test_cases:
with self.subTest(n_threads=n_threads, n_files=n_files):
urls = [f"https://example.com/file{i}" for i in range(1, n_files + 1)]
for i, url in enumerate(urls):
responses.add(
responses.GET, url, body=f"Content of file {i + 1}", status=200
)

mock_auth = MagicMock()
mock_auth.authenticated = True
mock_auth.system.edl_hostname = "urs.earthdata.nasa.gov"
responses.add(
responses.GET,
"https://urs.earthdata.nasa.gov/profile",
json=mock_creds,
status=200,
)

original_session = SessionWithHeaderRedirection()
original_session.cookies.set("sessionid", "mocked-session-id")
mock_auth.get_session.return_value = original_session

# Create the Store instance
store = Store(auth=mock_auth)
store.thread_locals = threading.local() # Use real thread-local storage

# Track cloned sessions
cloned_sessions = set()

def mock_clone_session_in_local_thread(original_session):
"""Mock session cloning to track cloned sessions."""
if not hasattr(store.thread_locals, "local_thread_session"):
session = SessionWithHeaderRedirection()
session.cookies.update(original_session.cookies)
cloned_sessions.add(id(session)) # Track unique sessions by ID
store.thread_locals.local_thread_session = session

with patch.object(
store,
"_clone_session_in_local_thread",
side_effect=mock_clone_session_in_local_thread,
):
mock_directory = Path("/mock/directory")
downloaded_files = []

def mock_download_file(url):
"""Mock file download to track downloaded files."""
# Ensure session cloning happens before downloading
store._clone_session_in_local_thread(original_session)
downloaded_files.append(url)
return mock_directory / f"{url.split('/')[-1]}"

with patch.object(store, "_download_file", side_effect=mock_download_file):
# Test multi-threaded download with 2 threads
pqdm(urls, store._download_file, n_jobs=2) # type: ignore

# Verify sessions cloned
self.assertEqual(len(cloned_sessions), 2) # 2 sessions, one per thread

# Verify files downloaded
self.assertEqual(len(downloaded_files), 10) # 10 files downloaded
self.assertCountEqual(downloaded_files, urls) # All files accounted for
original_session = SessionWithHeaderRedirection()
original_session.cookies.set("sessionid", "mocked-session-cookie")
mock_auth.get_session.return_value = original_session

store = Store(auth=mock_auth)
store.thread_locals = threading.local() # Use real thread-local storage

# Track cloned sessions
cloned_sessions = set()

def mock_clone_session_in_local_thread(original_session):
"""Mock session cloning to track cloned sessions."""
if not hasattr(store.thread_locals, "local_thread_session"):
session = SessionWithHeaderRedirection()
session.cookies.update(original_session.cookies)
cloned_sessions.add(id(session))
store.thread_locals.local_thread_session = session

with patch.object(
store,
"_clone_session_in_local_thread",
side_effect=mock_clone_session_in_local_thread,
):
mock_directory = Path("/mock/directory")
downloaded_files = []

def mock_download_file(url):
"""Mock file download to track downloaded files."""
# Ensure session cloning happens before downloading
store._clone_session_in_local_thread(original_session)
downloaded_files.append(url)
return mock_directory / f"{url.split('/')[-1]}"

with patch.object(
store, "_download_file", side_effect=mock_download_file
):
# Test multi-threaded download
pqdm(urls, store._download_file, n_jobs=n_threads) # type: ignore

# We make sure we reuse the token up to N threads
self.assertTrue(len(cloned_sessions) <= n_threads)

self.assertEqual(len(downloaded_files), n_files) # every one of the n_files URLs was downloaded
self.assertCountEqual(downloaded_files, urls) # All files accounted for


@pytest.mark.xfail(
Expand Down

0 comments on commit 5140afa

Please sign in to comment.