Enable smart weight loading #2758

Merged: 11 commits merged on Jan 8, 2024

Contributor

@jaegukhyun jaegukhyun commented Jan 4, 2024

Summary

Enable smart weight loading for all tasks.

Note

The copying algorithm has some exceptions for two-stage detectors and ssd-mobilenetv2.
They will be resolved in the next PR.

How to test

I added unit tests that cover the general cases of the smart weight loading algorithm.

Checklist

  • I have added unit tests to cover my changes.
  • I have added integration tests to cover my changes.
  • I have added e2e tests for validation.
  • I have added the description of my changes into CHANGELOG in my target branch (e.g., CHANGELOG in develop).
  • I have updated the documentation in my target branch accordingly (e.g., documentation in develop).
  • I have linked related issues.

License

  • I submit my code changes under the same Apache License that covers the project.
    Feel free to contact the maintainers if that's a concern.
  • I have updated the license header for each file (see an example below).
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

@github-actions github-actions bot added the TEST Any changes in tests label Jan 4, 2024
@jaegukhyun jaegukhyun changed the title from "Enable smart weight loading for detection" to "Enable smart weight loading" Jan 4, 2024
@jaegukhyun jaegukhyun marked this pull request as ready for review January 4, 2024 07:49
Contributor

@vinnamkim vinnamkim left a comment

I have questions:

  1. What is smart weight loading?
  2. Which user scenario does smart weight loading cover?

@jaegukhyun
Contributor Author

I have questions:

  1. What is smart weight loading?
  2. Which user scenario does smart weight loading cover?

Smart weight loading is a weight loading method for class-incremental learning.
Suppose a user has trained a model on a dataset whose classes are ["car", "bus"]
and now wants to continue training that model on a new dataset whose classes are ["car", "bus", "truck"].
Without smart weight loading, the final classification layer is re-initialized, because its dimension differs between the checkpoint and the model.
With smart weight loading, the weights for the previous classes are loaded from the checkpoint and only the weights for the newly added classes are initialized.
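
To make the idea concrete, here is a minimal, framework-free sketch of the per-class copy described above. It is not the code added in this PR: `remap_class_rows` and its arguments are hypothetical names, and the head is modeled as a plain list with one weight row per class rather than as a tensor.

```python
from __future__ import annotations


def remap_class_rows(
    ckpt_rows: list[list[float]],
    ckpt_classes: list[str],
    model_rows: list[list[float]],
    model_classes: list[str],
) -> list[list[float]]:
    """Build head rows for `model_classes`, reusing checkpoint rows by class name.

    Hypothetical helper for illustration; not the function added in this PR.
    """
    ckpt_index = {name: i for i, name in enumerate(ckpt_classes)}
    merged = []
    for fresh_row, name in zip(model_rows, model_classes):
        if name in ckpt_index:
            # Class existed before: reuse its trained row.
            merged.append(list(ckpt_rows[ckpt_index[name]]))
        else:
            # Newly added class: keep the freshly initialized row.
            merged.append(list(fresh_row))
    return merged


ckpt_rows = [[1.0, 1.0], [2.0, 2.0]]               # trained rows for car, bus
fresh_rows = [[0.0, 0.0], [0.0, 0.0], [0.5, 0.5]]  # new model's initial rows
merged = remap_class_rows(
    ckpt_rows, ["car", "bus"], fresh_rows, ["car", "bus", "truck"]
)
print(merged)  # [[1.0, 1.0], [2.0, 2.0], [0.5, 0.5]]
```

Matching by class name rather than by position means the copy still works if the new dataset lists the old classes in a different order.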

@vinnamkim
Contributor

vinnamkim commented Jan 4, 2024

Smart weight loading is a weight loading method for class-incremental learning. Suppose a user has trained a model on a dataset whose classes are ["car", "bus"] and now wants to continue training that model on a new dataset whose classes are ["car", "bus", "truck"].

It seems that we should add dataset-level meta info to OTXModel. Let's call it PredictionMetaInfo (feel free to pick a shorter, more appropriate name), for example:

from dataclasses import dataclass

class PredictionMetaInfo:
    pass

@dataclass
class MultiClassClsPredictionMetaInfo(PredictionMetaInfo):
    num_classes: int
    class_names: list[str]

Then, assume OTXModel is created with OTXModel(prediction_meta_info=PredictionMetaInfo).

Now you can handle weight loading in class-incremental learning scenarios like this:

ckpt = old_model.state_dict()

assert "prediction_meta_info" in ckpt

new_model = OTXModel(prediction_meta_info=new_prediction_meta_info)

new_model.load_state_dict(ckpt) # Smart weight loading can be done

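from dataclasses import asdict

from torch import nn

class OTXModel(nn.Module):
    ...

    def load_state_dict(self, ckpt):
        # Copy first so the caller's checkpoint is not mutated, then strip the
        # meta entry, which is not a parameter or buffer key.
        ckpt = dict(ckpt)
        old_info = ckpt.pop("prediction_meta_info")
        # state_dict() stores the meta info as a dict, so compare dict to dict.
        new_info = asdict(self.prediction_meta_info)

        if old_info != new_info:
            return self._load_state_dict_incrementally(ckpt)
        return super().load_state_dict(ckpt)

    def _load_state_dict_incrementally(self, ckpt):
        ...

    def state_dict(self):
        ckpt = super().state_dict()
        ckpt["prediction_meta_info"] = asdict(self.prediction_meta_info)
        return ckpt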

This requirement appears not only in this PR but also in @sungmanc's presentation today:
[image]
This HLabelInfo includes dataset-level meta information, but it is passed with each data sample. In addition, it is related to ModelAPI, so it is required not just for H-label but for all other tasks.

P.S. Could the following problem be solved fundamentally by implementing a dedicated _load_state_dict_incrementally() for each OTXModel? Some model classes can share a general solution, but there is no way to infer automatically which layer is the model's head.

The copying algorithm has some exceptions for two-stage detectors and ssd-mobilenetv2. They will be resolved in the next PR.
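
One way to read this suggestion is that each model declares which of its checkpoint entries depend on the class list, and a shared loader remaps only those. The sketch below illustrates that under stated assumptions: `HEAD_KEYS`, `load_incrementally`, and every key name in it are hypothetical, the state is modeled as plain Python lists instead of tensors, and it does not attempt the two-stage detector or ssd-mobilenetv2 cases, whose details are not given in this thread.

```python
from __future__ import annotations

import copy

# Hypothetical per-model declarations of class-dependent state_dict keys.
HEAD_KEYS = {
    "toy_classifier": {"head.weight", "head.bias"},
    "toy_detector": {"bbox_head.cls.weight", "bbox_head.cls.bias"},
}


def load_incrementally(model_state, ckpt_state, head_keys, old_classes, new_classes):
    """Copy `ckpt_state` into `model_state`, remapping `head_keys` by class name."""
    old_index = {name: i for i, name in enumerate(old_classes)}
    for key, value in ckpt_state.items():
        if key not in model_state:
            continue  # Ignore entries the new model does not have.
        if key in head_keys:
            # Class-dependent entry: copy only the rows of classes seen before.
            for new_i, name in enumerate(new_classes):
                if name in old_index:
                    model_state[key][new_i] = copy.deepcopy(value[old_index[name]])
        else:
            # Class-independent entry (e.g. backbone): load as-is.
            model_state[key] = copy.deepcopy(value)
    return model_state


model_state = {
    "backbone.weight": [0.0],
    "head.weight": [[0.0, 0.0], [0.0, 0.0], [9.0, 9.0]],
    "head.bias": [0.0, 0.0, 9.0],
}
ckpt_state = {
    "backbone.weight": [3.0],
    "head.weight": [[1.0, 1.0], [2.0, 2.0]],
    "head.bias": [0.1, 0.2],
}
loaded = load_incrementally(
    model_state,
    ckpt_state,
    HEAD_KEYS["toy_classifier"],
    ["car", "bus"],
    ["car", "bus", "truck"],
)
```

With this split, the generic loop stays in one place and only the key declaration varies per model, which sidesteps having to detect the head layer automatically.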

@sungmanc
Contributor

sungmanc commented Jan 5, 2024

It seems that we should add a dataset-level meta info to OTXModel. Let's call it as PredictionMetaInfo (you can name a more appropriate one for it, shorter one)

Sounds good to me. If we can feed dataset-level meta info to the ModelEntity, that would be great.

@jaegukhyun
Contributor Author

jaegukhyun commented Jan 5, 2024

@vinnamkim @sungmanc
I'll change the code so that the meta info is updated from the model rather than from the dataset.

@jaegukhyun jaegukhyun closed this Jan 5, 2024
src/otx/core/engine/train.py (outdated, resolved)
src/otx/core/engine/train.py (outdated, resolved)
jaegukhyun and others added 2 commits January 5, 2024 15:36
Co-authored-by: Sungman Cho <sungman.cho@intel.com>
sungmanc previously approved these changes Jan 5, 2024
src/otx/core/engine/train.py (outdated, resolved)
src/otx/core/model/module/base.py (resolved)
src/otx/core/model/entity/base.py (outdated, resolved)
@sungmanc sungmanc mentioned this pull request Jan 5, 2024
8 tasks
sungmanc previously approved these changes Jan 8, 2024
@jaegukhyun jaegukhyun requested a review from sungmanc January 8, 2024 06:43
@jaegukhyun jaegukhyun merged commit a48426c into openvinotoolkit:v2 Jan 8, 2024
6 checks passed
Labels
TEST Any changes in tests

4 participants