Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to async logic for region extraction #49

Merged
merged 5 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:

jupyter labextension list
jupyter labextension list 2>&1 | grep -ie "jupyter-drives.*OK"
python -m jupyterlab.browser_check
python -m jupyterlab.browser_check --no-chrome-test

- name: Package the extension
run: |
Expand Down
22 changes: 9 additions & 13 deletions jupyter_drives/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from libcloud.storage.types import Provider
from libcloud.storage.providers import get_driver
import pyarrow
import boto3
from aiobotocore.session import get_session

from .log import get_logger
from .base import DrivesConfig
Expand All @@ -42,14 +42,11 @@ def __init__(self, config: traitlets.config.Config) -> None:
self._content_managers = {}
self._max_files_listed = 1000

# initiate boto3 session if we are dealing with S3 drives
# initiate aiobotocore session if we are dealing with S3 drives
if self._config.provider == 's3':
self._s3_clients = {}
if self._config.access_key_id and self._config.secret_access_key:
if self._config.session_token is None:
self._s3_session = boto3.Session(aws_access_key_id = self._config.access_key_id, aws_secret_access_key = self._config.secret_access_key)
else:
self._s3_session = boto3.Session(aws_access_key_id = self._config.access_key_id, aws_secret_access_key = self._config.secret_access_key, aws_session_token = self._config.session_token)
self._s3_clients = {}
self._s3_session = get_session()
else:
raise tornado.web.HTTPError(
status_code= httpx.codes.BAD_REQUEST,
Expand Down Expand Up @@ -149,7 +146,7 @@ async def mount_drive(self, drive_name, provider):
if drive_name not in self._content_managers or self._content_managers[drive_name] is None:
if provider == 's3':
# get region of drive
region = self._get_drive_location(drive_name)
region = await self._get_drive_location(drive_name)
if self._config.session_token is None:
configuration = {
"aws_access_key_id": self._config.access_key_id,
Expand Down Expand Up @@ -555,7 +552,7 @@ async def presigned_link(self, drive_name, path):
}
return response

def _get_drive_location(self, drive_name):
async def _get_drive_location(self, drive_name):
"""Helping function for getting drive region.

Args:
Expand All @@ -564,10 +561,9 @@ def _get_drive_location(self, drive_name):
location = 'eu-north-1'
try:
# set temporary client for location extraction
s3 = self._s3_session.client('s3')
result = s3.get_bucket_location(Bucket = drive_name)

location = result['LocationConstraint']
async with self._s3_session.create_client('s3', aws_secret_access_key=self._config.secret_access_key, aws_access_key_id=self._config.access_key_id, aws_session_token=self._config.session_token) as client:
result = await client.get_bucket_location(Bucket=drive_name)
location = result['LocationConstraint']
except Exception as e:
raise tornado.web.HTTPError(
status_code= httpx.codes.BAD_REQUEST,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ dependencies = [
"obstore>=0.3.0b,<0.4",
"arro3-core>=0.2.1,<0.3",
"pyarrow>=18.0.0,<19.0.0",
"aiobotocore>=2.15.2,<2.16.0",
"jupyter_server>=2.14.2,<3",
"s3contents>=0.11.1,<0.12.0",
"apache-libcloud>=3.8.0, <4",
"entrypoints>=0.4, <0.5",
"httpx>=0.25.1, <0.26"
Expand Down
Loading