Skip to content

Commit

Permalink
Merge pull request #312 from lrosemberg/rosemberg/refactoring-dataset-upload
Browse files Browse the repository at this point in the history

Refactoring dataset upload
  • Loading branch information
tonylampada authored Aug 20, 2024
2 parents 6f1b08a + f391451 commit 86cd5bf
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 127 deletions.
72 changes: 52 additions & 20 deletions roboflow/adapters/rfapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,19 @@ class RoboflowError(Exception):
pass


class UploadError(RoboflowError):
    """Error raised when an image upload or annotation save request fails."""

    pass
class ImageUploadError(RoboflowError):
    """Error describing a failed or rejected image upload.

    Attributes:
        message: Human-readable description of the failure.
        status_code: HTTP status code of the failing response, or ``None``
            when no status code was supplied.
        retries: Number of retries attempted before giving up. Starts at 0;
            code that drives the retry loop may overwrite it.
    """

    def __init__(self, message, status_code=None):
        # Initialise the Exception machinery first so ``str(err)`` yields the message.
        super().__init__(message)
        self.retries = 0
        self.status_code = status_code
        self.message = message


class AnnotationSaveError(RoboflowError):
    """Error describing a failed attempt to save an annotation.

    Attributes:
        message: Description of the failure (also used as the exception text).
        status_code: HTTP status code of the failing response, or ``None``
            when no status code was supplied.
    """

    def __init__(self, message, status_code=None):
        # Initialise the Exception machinery first so ``str(err)`` yields the message.
        super().__init__(message)
        self.status_code = status_code
        self.message = message


def get_workspace(api_key, workspace_url):
Expand Down Expand Up @@ -78,24 +89,38 @@ def upload_image(

else:
# Hosted image upload url

upload_url = _hosted_upload_url(api_key, project_url, image_path, split, coalesced_batch_name, tag_names)

# Get response
response = requests.post(upload_url, timeout=(300, 300))

responsejson = None
try:
responsejson = response.json()
except Exception:
pass

if response.status_code != 200:
if responsejson:
raise UploadError(f"Bad response: {response.status_code}: {responsejson}")
err_msg = responsejson

if err_msg.get("error"):
err_msg = err_msg["error"]

if err_msg.get("message"):
err_msg = err_msg["message"]

raise ImageUploadError(err_msg, status_code=response.status_code)
else:
raise UploadError(f"Bad response: {response}")
raise ImageUploadError(str(response), status_code=response.status_code)

if not responsejson: # fail fast
raise UploadError(f"upload image {image_path} 200 OK, unexpected response: {response}")
raise ImageUploadError(str(response), status_code=response.status_code)

if not (responsejson.get("success") or responsejson.get("duplicate")):
raise UploadError(f"Server rejected image: {responsejson}")
message = responsejson.get("message") or str(responsejson)
raise ImageUploadError(message)

return responsejson


Expand Down Expand Up @@ -128,24 +153,28 @@ def save_annotation(
headers={"Content-Type": "application/json"},
timeout=(60, 60),
)

# Handle response
responsejson = None
try:
responsejson = response.json()
except Exception:
pass

if not responsejson:
raise _save_annotation_error(image_id, response)
raise _save_annotation_error(response)
if response.status_code not in (200, 409):
raise _save_annotation_error(image_id, response)
raise _save_annotation_error(response)
if response.status_code == 409:
if "already annotated" in responsejson.get("error", {}).get("message"):
return {"warn": "already annotated"}
else:
raise _save_annotation_error(image_id, response)
raise _save_annotation_error(response)
if responsejson.get("error"):
raise _save_annotation_error(image_id, response)
raise _save_annotation_error(response)
if not responsejson.get("success"):
raise _save_annotation_error(image_id, response)
raise _save_annotation_error(response)

return responsejson


Expand Down Expand Up @@ -191,17 +220,20 @@ def _local_upload_url(api_key, project_url, batch_name, tag_names, sequence_numb
return _upload_url(api_key, project_url, **query_params)


def _save_annotation_error(response):
    """Build (but do not raise) an AnnotationSaveError for a failed response.

    The most specific description available is used as the error message:
    ``error.message`` from the JSON body, then ``error`` itself, then the
    whole JSON body, and finally the string form of the response object when
    the body is empty or not JSON.

    Args:
        response: The HTTP response of the save-annotation request. It must
            expose ``status_code`` and ``json()``.

    Returns:
        An ``AnnotationSaveError`` carrying the message and the HTTP status
        code. The caller is responsible for raising it.
    """
    status_code = response.status_code

    try:
        payload = response.json()
    except Exception:
        # Deliberately best-effort: a non-JSON body is reported via str(response) below.
        payload = None

    if not payload:
        # Always hand the exception a string, never the raw response object.
        return AnnotationSaveError(str(response), status_code=status_code)

    error = payload.get("error") if isinstance(payload, dict) else None
    if error:
        # ``error`` may be a plain string or a dict with a nested "message".
        if isinstance(error, dict) and error.get("message"):
            error = error["message"]
        if not isinstance(error, str):
            error = str(error)
        return AnnotationSaveError(error, status_code=status_code)

    return AnnotationSaveError(str(payload), status_code=status_code)
159 changes: 96 additions & 63 deletions roboflow/core/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import mimetypes
import os
import re
import sys
import time
import warnings
Expand All @@ -12,6 +11,7 @@
import requests

from roboflow.adapters import rfapi
from roboflow.adapters.rfapi import ImageUploadError
from roboflow.config import API_URL, DEMO_KEYS
from roboflow.core.version import Version
from roboflow.util.general import Retry
Expand Down Expand Up @@ -465,6 +465,76 @@ def upload(
print("[ " + path + " ] was skipped.")
continue

def upload_image(
    self,
    image_path=None,
    hosted_image=False,
    split="train",
    num_retry_uploads=0,
    batch_name=None,
    tag_names=None,
    sequence_number=None,
    sequence_size=None,
    **kwargs,
):
    """Upload a single image to this project, retrying on upload errors.

    Args:
        image_path: Path (or URL when ``hosted_image`` is true) of the image;
            forwarded to ``rfapi.upload_image``.
        hosted_image: Forwarded to ``rfapi.upload_image``.
        split: Dataset split to upload into; defaults to ``"train"``.
        num_retry_uploads: Maximum number of retries handed to ``Retry``.
        batch_name: Forwarded to ``rfapi.upload_image``.
        tag_names: List of tag names; ``None`` is treated as an empty list.
        sequence_number: Forwarded to ``rfapi.upload_image``.
        sequence_size: Forwarded to ``rfapi.upload_image``.
        **kwargs: Extra keyword arguments forwarded to ``rfapi.upload_image``.

    Returns:
        A ``(image, upload_time, upload_retry_attempts)`` tuple: the value
        returned by ``rfapi.upload_image``, the elapsed wall-clock seconds,
        and the number of retries that were needed.

    Raises:
        ImageUploadError: When the upload still fails after the allowed
            retries. Its ``retries`` attribute holds the number of retries
            that were actually attempted.
    """
    # NOTE(review): assumes self.id looks like "<workspace>/<project>" - confirm.
    project_url = self.id.rsplit("/")[1]
    # Avoid a shared mutable default; downstream code still receives a list.
    tag_names = [] if tag_names is None else tag_names

    t0 = time.time()
    retry = Retry(num_retry_uploads, ImageUploadError)

    try:
        image = retry(
            rfapi.upload_image,
            self.__api_key,
            project_url,
            image_path,
            hosted_image=hosted_image,
            split=split,
            batch_name=batch_name,
            tag_names=tag_names,
            sequence_number=sequence_number,
            sequence_size=sequence_size,
            **kwargs,
        )
    except ImageUploadError as e:
        # Read the counter from the Retry helper: a local updated only after
        # a successful call would always report 0 here.
        e.retries = retry.retries
        raise  # bare raise keeps the original traceback intact

    upload_time = time.time() - t0

    return image, upload_time, retry.retries

def save_annotation(
    self,
    annotation_path=None,
    annotation_labelmap=None,
    image_id=None,
    job_name=None,
    is_prediction: bool = False,
    annotation_overwrite=False,
):
    """Save the annotation at ``annotation_path`` for an already uploaded image.

    The annotation name and content are obtained from
    ``self._annotation_params(annotation_path)`` and sent through
    ``rfapi.save_annotation``.

    Returns:
        A ``(annotation, upload_time)`` tuple: the value returned by
        ``rfapi.save_annotation`` and the elapsed wall-clock seconds of
        that call.
    """
    project_slug = self.id.rsplit("/")[1]
    name, content = self._annotation_params(annotation_path)

    started = time.time()
    # Keyword options are gathered in one mapping and expanded into the call.
    options = {
        "job_name": job_name,
        "is_prediction": is_prediction,
        "annotation_labelmap": annotation_labelmap,
        "overwrite": annotation_overwrite,
    }
    result = rfapi.save_annotation(
        self.__api_key,
        project_slug,
        name,  # type: ignore[type-var]
        content,  # type: ignore[type-var]
        image_id,
        **options,
    )
    elapsed = time.time() - started

    return result, elapsed

def single_upload(
self,
image_path=None,
Expand All @@ -482,64 +552,41 @@ def single_upload(
sequence_size=None,
**kwargs,
):
project_url = self.id.rsplit("/")[1]
if image_path and image_id:
raise Exception("You can't pass both image_id and image_path")
if not (image_path or image_id):
raise Exception("You need to pass image_path or image_id")
if isinstance(annotation_labelmap, str):
annotation_labelmap = load_labelmap(annotation_labelmap)

uploaded_image, uploaded_annotation = None, None
upload_time = None
upload_time, annotation_time = None, None
upload_retry_attempts = 0

if image_path:
t0 = time.time()
try:
retry = Retry(num_retry_uploads, Exception)
uploaded_image = retry(
rfapi.upload_image,
self.__api_key,
project_url,
image_path,
hosted_image=hosted_image,
split=split,
batch_name=batch_name,
tag_names=tag_names,
sequence_number=sequence_number,
sequence_size=sequence_size,
**kwargs,
)
image_id = uploaded_image["id"] # type: ignore[index]
upload_retry_attempts = retry.retries
except rfapi.UploadError as e:
raise RuntimeError(f"Error uploading image: {self._parse_upload_error(e)}")
except BaseException as e:
uploaded_image = {"error": e}
finally:
upload_time = time.time() - t0

annotation_time = None
uploaded_image, upload_time, upload_retry_attempts = self.upload_image(
image_path,
hosted_image,
split,
num_retry_uploads,
batch_name,
tag_names,
sequence_number,
sequence_size,
**kwargs,
)
image_id = uploaded_image["id"] # type: ignore[index]

if annotation_path and image_id:
annotation_name, annotation_str = self._annotation_params(annotation_path)
try:
t0 = time.time()
uploaded_annotation = rfapi.save_annotation(
self.__api_key,
project_url,
annotation_name, # type: ignore[type-var]
annotation_str, # type: ignore[type-var]
image_id,
job_name=batch_name, # type: ignore[type-var]
is_prediction=is_prediction,
annotation_labelmap=annotation_labelmap,
overwrite=annotation_overwrite,
)
except rfapi.UploadError as e:
raise RuntimeError(f"Error uploading annotation: {self._parse_upload_error(e)}")
except BaseException as e:
uploaded_annotation = {"error": e}
finally:
annotation_time = time.time() - t0
uploaded_annotation, annotation_time = self.save_annotation(
annotation_path,
annotation_labelmap,
image_id,
batch_name,
is_prediction,
annotation_overwrite,
)

return {
"image": uploaded_image,
"annotation": uploaded_annotation,
Expand Down Expand Up @@ -568,20 +615,6 @@ def _annotation_params(self, annotation_path):
)
return annotation_name, annotation_string

def _parse_upload_error(self, error: rfapi.UploadError) -> str:
dict_part = str(error).split(": ", 2)[2]
dict_part = dict_part.replace("True", "true")
dict_part = dict_part.replace("False", "false")
dict_part = dict_part.replace("None", "null")
if re.search(r"'\w+':", dict_part):
temp_str = dict_part.replace(r"\'", "<PLACEHOLDER>")
temp_str = temp_str.replace('"', r"\"")
temp_str = temp_str.replace("'", '"')
dict_part = temp_str.replace("<PLACEHOLDER>", "'")
parsed_dict: dict = json.loads(dict_part)
message = parsed_dict.get("message")
return message or str(parsed_dict)

def search(
self,
like_image: Optional[str] = None,
Expand Down
Loading

0 comments on commit 86cd5bf

Please sign in to comment.