Refine #2900: export with XAI #2905

Closed
wants to merge 13 commits

Conversation

sovrasov (Contributor)

Summary

This PR extends #2900.
The main rationale for these changes is that we need a single mechanism for introducing XAI into the model.
A custom forward used only for export adds a second XAI forward mechanism, which is hard to test.

How to test

Checklist

  • I have added unit tests to cover my changes.
  • I have added integration tests to cover my changes.
  • I have added e2e tests for validation.
  • I have added the description of my changes into CHANGELOG in my target branch (e.g., CHANGELOG in develop).
  • I have updated the documentation in my target branch accordingly (e.g., documentation in develop).
  • I have linked related issues.

License

  • I submit my code changes under the same Apache License that covers the project.
    Feel free to contact the maintainers if that's a concern.
  • I have updated the license header for each file (see an example below).
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

Review thread on this excerpt from the wrapper class added in the diff:

        self.hook = hook
        self.hook.reset()

    def forward(self, x: torch.Tensor) -> tuple:

Collaborator:

Why can't we move this into a global forward function for the entire ExplainableOTXClsModel?
We could call super().forward() and then check whether an explain hook is present, returning either just the logits or a tuple of logits and saliency maps.
I just don't like this wrapper class that substitutes for our OTXModel.
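
For illustration, a minimal sketch of the unified forward suggested above. The explain_hook attribute, its records field, and the inheritance from OTXModel are assumptions for this sketch, not code from the PR:

import torch

class ExplainableOTXClsModel(OTXModel):  # OTXModel: the project's base model class
    def forward(self, x: torch.Tensor):
        logits = super().forward(x)
        hook = getattr(self, "explain_hook", None)  # hypothetical: set when a hook is registered
        if hook is not None:
            # Return logits together with the saliency maps the hook recorded.
            return logits, hook.records
        return logits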

sovrasov (Contributor, Author):

The final mechanism for obtaining the saliency maps is still not clear. Actually, we can already access the hook in _customize_outputs; the problem is what to do with the raw saliency maps. This wrapper doesn't look like a comprehensive replacement for forward; it's rather a workaround for when the batch size is 1.

sovrasov (Contributor, Author), Feb 12, 2024:

Also, looking at it more generally, the whole idea of hooks is that we don't change anything in forward itself, while still being able to extend it with an arbitrary number of hooks.
If we want a custom forward, we don't need hooks; but with hooks we have to generate a custom forward for export anyway, since the torch tracer is quite limited in functionality.
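
For reference, this is the standard torch mechanism being described: a forward hook collects an intermediate activation without any change to forward() itself (a generic example, not code from this PR):

import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1))
captured = []

def save_activation(module, args, output):
    # Record the intermediate activation; forward() stays untouched.
    captured.append(output.detach())

# Attach the hook to the layer whose activations would feed a saliency map.
handle = model[1].register_forward_hook(save_activation)
_ = model(torch.randn(1, 3, 32, 32))
handle.remove()  # detach the hook; the model's forward was never modified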

Contributor:

#2870 (comment)
#2900 (comment)
As in these comments, I think we should introduce a dedicated forward() function for XAI at this point.

sovrasov (Contributor, Author):

@vinnamkim does that mean we have to remove the hook? If we have a standalone forward, keeping the hooks makes no sense, as I described in the previous comment.

vinnamkim (Contributor), Feb 13, 2024:

> @vinnamkim does that mean we have to remove the hook? If we have a standalone forward, keeping the hooks makes no sense, as I described in the previous comment.

Yes

I guess you guys can move forward gradually. For example:

  1. Introduce forward_explain() with the hook:

def forward_explain(self, inputs) -> BatchPredictionWithXAI:
    handle = self.register_explain_hook()
    outputs = self.forward(inputs)
    augmented_outputs = augment_with_xai_outputs(outputs, handle.xai_outputs)
    handle.remove()
    return augmented_outputs

  2. Remove the use of the hook from forward_explain() (a possible hook-free shape is sketched below).
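
A possible shape for step 2, once the hook is gone. Everything here (the backbone/head split, explain_fn, and the BatchPredictionWithXAI fields) is illustrative, not code from this PR:

def forward_explain(self, inputs) -> BatchPredictionWithXAI:
    # Compute the features once and derive both predictions and saliency
    # maps from them directly, so no hook is needed.
    features = self.backbone(inputs)
    logits = self.head(features)
    saliency_maps = self.explain_fn(features)  # e.g. activation-map post-processing
    return BatchPredictionWithXAI(logits=logits, saliency_maps=saliency_maps)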

src/otx/core/model/entity/classification.py (outdated review thread, resolved)
@@ -367,13 +367,15 @@ def export(
        checkpoint: str | Path | None = None,
        export_format: OTXExportFormatType = OTXExportFormatType.OPENVINO,
        export_precision: OTXPrecisionType = OTXPrecisionType.FP32,
        dump_auxiliaries: bool = False,
Collaborator:

I don't really like the naming. Maybe "dump_features" or "dump_saliency_maps", or just "explain".

sovrasov (Contributor, Author):

This comes from the original implementation. I also prefer *explain, since "feature vector" is unclear and is required only because of the Geti integration requirements.
@negvet what do you think?
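
For context, whichever name is chosen, the flag from the hunk above would be passed at export time roughly like this (a hypothetical call; the parameter names come from the diff, the receiver object does not):

model.export(
    checkpoint=None,
    export_format=OTXExportFormatType.OPENVINO,
    export_precision=OTXPrecisionType.FP32,
    dump_auxiliaries=True,  # the name under debate: dump_features / dump_saliency_maps / explain
)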

sovrasov changed the title from "Refine" to "Refine #2900: export with XAI" on Feb 12, 2024
sovrasov closed this on Feb 13, 2024
Labels: TEST (Any changes in tests)
Projects: None yet
4 participants