Skip to content

Commit

Permalink
Modify label info comparison (openvinotoolkit#3442)
Browse files Browse the repository at this point in the history
  • Loading branch information
yunchu authored May 3, 2024
1 parent 1d529f7 commit 8f0dd3f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def test(
model_cls = self.model.__class__
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint)

if model.label_info != self.datamodule.label_info:
if model.label_info.as_dict() != self.datamodule.label_info.as_dict():
msg = (
"To launch a test pipeline, the label information should be same "
"between the training and testing datasets. "
Expand Down Expand Up @@ -452,7 +452,7 @@ def predict(
model_cls = self.model.__class__
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint)

if model.label_info != self.datamodule.label_info:
if model.label_info.as_dict() != self.datamodule.label_info.as_dict():
msg = (
"To launch a predict pipeline, the label information should be same "
"between the training and testing datasets. "
Expand Down Expand Up @@ -691,7 +691,7 @@ def explain(
model_cls = model.__class__
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint)

if model.label_info != self.datamodule.label_info:
if model.label_info.as_dict() != self.datamodule.label_info.as_dict():
msg = (
"To launch a explain pipeline, the label information should be same "
"between the training and testing datasets. "
Expand Down

0 comments on commit 8f0dd3f

Please sign in to comment.