Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

version.model should return None if no model #276

Merged
merged 9 commits into from
Jul 2, 2024
10 changes: 9 additions & 1 deletion roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ def __init__(

version_without_workspace = os.path.basename(str(version))

if self.type == TYPE_OBJECT_DETECTION:
version_info = requests.get(f"{API_URL}/{workspace}/{project}/{self.version}?api_key={self.__api_key}")

# check if version has a model
if version_info.status_code == 200:
version_info = version_info.json()["version"]

if ("models" in version_info) and (not version_info["models"]):
self.model = None
elif self.type == TYPE_OBJECT_DETECTION:
self.model = ObjectDetectionModel(
self.__api_key,
self.id,
Expand Down
54 changes: 53 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def setUp(self):
},
"versions": [
{
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/2",
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/{PROJECT_VERSION}",
"name": "augmented-416x416",
"created": 1663104679.539,
"images": 240,
Expand Down Expand Up @@ -158,6 +158,58 @@ def setUp(self):
status=200,
)

# Get version
responses.add(
responses.GET,
f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/{PROJECT_VERSION}?api_key={ROBOFLOW_API_KEY}",
json={
"workspace": {"name": WORKSPACE_NAME, "url": WORKSPACE_NAME, "members": 1},
"project": {
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}",
"type": "object-detection",
"name": "Hard Hat Sample",
"created": 1593802673.521,
"updated": 1663269501.654,
"images": 100,
"unannotated": 3,
"annotation": "Workers",
"versions": 2,
"public": False,
"splits": {"test": 10, "train": 70, "valid": 20},
"colors": {
"person": "#FF00FF",
"helmet": "#C7FC00",
"head": "#8622FF",
},
"classes": {"person": 9, "helmet": 287, "head": 90},
},
"version": {
"id": f"{WORKSPACE_NAME}/{PROJECT_NAME}/{PROJECT_VERSION}",
"name": "augmented-416x416",
"created": 1663104679.539,
"images": 240,
"splits": {"train": 210, "test": 10, "valid": 20},
"generating": False,
"progress": 1,
"preprocessing": {
"resize": {"height": "416", "enabled": True, "width": "416", "format": "Stretch to"},
"auto-orient": {"enabled": True},
},
"augmentation": {
"blur": {"enabled": True, "pixels": 1.5},
"image": {"enabled": True, "versions": 3},
"rotate": {"degrees": 15, "enabled": True},
"crop": {"enabled": True, "percent": 40, "min": 0},
"flip": {"horizontal": True, "enabled": True, "vertical": False},
},
"exports": [],
"models": {},
"classes": [],
},
},
status=200,
)

# Upload image
responses.add(
responses.POST,
Expand Down
5 changes: 1 addition & 4 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,5 @@ def test_version_fields(self):
@ordered
def test_version_methods(self):
self.assertTrue(
(
isinstance(self.version.model, ClassificationModel)
or (isinstance(self.version.model, ObjectDetectionModel))
)
self.version.model is None or isinstance(self.version.model, (ClassificationModel, ObjectDetectionModel))
)
Loading