Skip to content

Commit

Permalink
Fix vpm intg test error (#3554)
Browse files Browse the repository at this point in the history
  • Loading branch information
sungchul2 authored May 28, 2024
1 parent 80acb86 commit 326ec37
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/otx/core/model/visual_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,8 @@ def _customize_outputs(
masks: list[tv_tensors.Mask] = []
scores: list[torch.Tensor] = []
for output in outputs:
masks.append(torch.as_tensor(output["hard_prediction"]))
scores.append(torch.as_tensor(output["scores"]))
masks.append(torch.as_tensor(output["hard_prediction"], device=self.device))
scores.append(torch.as_tensor(output["scores"], device=self.device))

return VisualPromptingBatchPredEntity(
batch_size=len(outputs),
Expand Down Expand Up @@ -1025,17 +1025,21 @@ def _customize_outputs( # type: ignore[override]
tv_tensors.Mask(
torch.stack([torch.as_tensor(m) for m in predicted_mask], dim=0),
dtype=torch.float32,
device=self.device,
),
)
prompts.append(
Points(
torch.stack([torch.as_tensor(p[:2]) for p in used_points[label]], dim=0),
canvas_size=inputs.imgs_info[0].ori_shape,
dtype=torch.float32,
device=self.device,
),
)
scores.append(torch.stack([torch.as_tensor(p[2]) for p in used_points[label]], dim=0))
labels.append(torch.stack([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0))
scores.append(torch.stack([torch.as_tensor(p[2]) for p in used_points[label]], dim=0).to(self.device))
labels.append(
torch.stack([torch.LongTensor([label]) for _ in range(len(scores[-1]))], dim=0).to(self.device),
)

return ZeroShotVisualPromptingBatchPredEntity(
batch_size=len(outputs),
Expand Down

0 comments on commit 326ec37

Please sign in to comment.