-
Notifications
You must be signed in to change notification settings - Fork 446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable smart weight loading #2758
Conversation
b751c81
to
57f18cf
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have questions:
- What is the smart weight loading?
- Which does an user scenario this smart weight loading cover?
Smart weight loading is weight loading method for class incremental learning. |
It seems that we should add a dataset-level meta info to class PredictionMetaInfo:
pass
@dataclasses
class MultiClassClsPredictionMetaInfo(PredictionMetaInfo):
num_classes: int
class_names: list[str] Then, assume Now, you can handle weight loading in class incremental learning scenarios by ckpt = old_model.state_dict()
assert "prediction_meta_info" in ckpt
new_model = OTXModel(prediction_meta_info=new_prediction_meta_info)
new_model.load_state_dict(ckpt) # Smart weight loading can be done
class OTXModel(nn.Module):
...
def load_state_dict(self, ckpt):
old_info = ckpt["prediction_meta_info"]
new_info = self.prediction_meta_info
if old_info != new_info:
self._load_state_dict_incrementally(ckpt)
else:
super().load_state_dict(ckpt)
def _load_state_dict_incrementally(self, ckpt):
...
def state_dict(self):
ckpt = super().state_dict()
ckpt["prediction_meta_info"] = asdict(self.prediction_meta_info) This requirement is not only observed in this PR, but also @sungmanc's today presentation, p.s. Is it fundamentally possible to solve the following problem by implementing own
|
Sounds good to me, if we can feed a dataset-lele meta info to the |
@vinnamkim @sungmanc |
Co-authored-by: Sungman Cho <sungman.cho@intel.com>
Summary
Enable smart weight loading for all task.
Note
There are some exceptions in copying algorithms for two stage detectors, and ssd-mobilenetv2
They are resolved by next PR
How to test
I added unit tests for smart weight loading algo for general cases
Checklist
License
Feel free to contact the maintainers if that's a concern.