Skip to content

Commit

Permalink
feat: improve /image_predictions route (#1489)
Browse files Browse the repository at this point in the history
* fix: set better default for with_logo param in /image_predictions route

* feat: allow to filter by image_id in /api/v1/image_predictions
  • Loading branch information
raphael0202 authored Dec 6, 2024
1 parent 33645a6 commit 63add1f
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 23 deletions.
23 changes: 14 additions & 9 deletions doc/references/api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -516,27 +516,32 @@ paths:
- $ref: "#/components/parameters/server_type"
- $ref: "#/components/parameters/barcode_query_filter"
- name: with_logo
description: if True, only return image predictions that have associated logos (only valid for universal-logo-detector image predictions)
description: if True, only return image predictions that have associated logos
(only valid for universal-logo-detector image predictions). If false, only return image predictions that have no associated logos.
Otherwise, return all image predictions.
in: query
schema:
type: boolean
default: false
default: null
nullable: true
- name: model_name
description: filter by name of the image predictor model
in: query
schema:
type: string
enum:
- universal-logo-detector
- nutrition-table
- nutriscore
example: universal-logo-detector
- name: image_id
description: filter by image ID. It should be a digit (raw images only), otherwise no result will be returned.
in: query
schema:
type: string
example: 1
- name: type
description: filter by type of the image predictor model, currently only 'object_detection'
description: filter by type of the image predictor model
in: query
schema:
type: string
enum:
- 'object_detection'
example: object_detection
- name: model_version
description: filter by model version value
in: query
Expand Down
14 changes: 8 additions & 6 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,18 +834,20 @@ class ImagePredictionResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
count: int = req.get_param_as_int("count", min_value=1, default=25)
page: int = req.get_param_as_int("page", min_value=1, default=1)
with_logo: Optional[bool] = req.get_param_as_bool("with_logo", default=False)
model_name: Optional[str] = req.get_param("model_name")
type_: Optional[str] = req.get_param("type")
model_version: Optional[str] = req.get_param("model_version")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
min_confidence: Optional[float] = req.get_param_as_float("min_confidence")
with_logo: bool | None = req.get_param_as_bool("with_logo", default=None)
model_name: str | None = req.get_param("model_name")
type_: str | None = req.get_param("type")
model_version: str | None = req.get_param("model_version")
barcode: str | None = normalize_req_barcode(req.get_param("barcode"))
image_id: str | None = req.get_param("image_id")
min_confidence: float | None = req.get_param_as_float("min_confidence")
server_type = get_server_type_from_req(req)

get_image_predictions_ = functools.partial(
get_image_predictions,
with_logo=with_logo,
barcode=barcode,
image_id=image_id,
type=type_,
server_type=server_type,
model_name=model_name,
Expand Down
20 changes: 12 additions & 8 deletions robotoff/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,21 +303,25 @@ def get_predictions(

def get_image_predictions(
server_type: ServerType,
with_logo: Optional[bool] = False,
barcode: Optional[str] = None,
type: Optional[str] = None,
model_name: Optional[str] = None,
model_version: Optional[str] = None,
min_confidence: Optional[float] = None,
offset: Optional[int] = None,
with_logo: bool | None = False,
barcode: str | None = None,
type: str | None = None,
model_name: str | None = None,
model_version: str | None = None,
min_confidence: float | None = None,
image_id: str | None = None,
offset: int | None = None,
count: bool = False,
limit: Optional[int] = None,
limit: int | None = None,
) -> Iterable[ImagePrediction]:
query = ImagePrediction.select()

query = query.switch(ImagePrediction).join(ImageModel)
where_clauses = [ImagePrediction.image.server_type == server_type.name]

if image_id is not None:
where_clauses.append(ImagePrediction.image.image_id == image_id)

if barcode is not None:
where_clauses.append(ImagePrediction.image.barcode == barcode)

Expand Down

0 comments on commit 63add1f

Please sign in to comment.