Feature/learner updates #184
Conversation
Recompiles the Keras model. This way the optimizer history gets erased,
which is needed before a new training round; otherwise the outdated history is used.
"""
compile_args = self.model._get_compile_args()  # pylint: disable=protected-access
I did not find a better way of resetting the optimizer than re-compiling the model, which requires calling this protected function.
When I simply replace the optimizer:
new_opt = keras.optimizers.SGD(lr=0.1)
model.optimizer = new_opt
then I can no longer change the batch size, and the optimizer does not appear to be updated either: model.optimizer.get_weights() should store the step count and the momentums, but it stays an empty list this way.
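For context, a minimal sketch of the recompile-based reset this diff relies on (assuming a tf.keras model; _get_compile_args is a protected TF 2.x API, so this may break across releases):

```python
import tensorflow as tf


def reset_optimizer(model: tf.keras.Model) -> None:
    """Re-compile the model with its original compile arguments.

    Re-compiling drops the optimizer's accumulated state (step count,
    momentum slots), which plain `model.optimizer = new_opt` does not
    achieve reliably.
    """
    # Protected API: returns the kwargs the model was last compiled with.
    compile_args = model._get_compile_args()  # pylint: disable=protected-access
    model.compile(**compile_args)
```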
Maybe name the function reset_optimizer, to better reflect why we have this and when we want to use it.
Two changes:

1.) mli_get_current_weights now returns all trainable and non-trainable parameters. Previously only the trainable parameters were returned for the PyTorch learner, but normalization layers have non-trainable parameters (their running statistics) which are also needed to reproduce the model's performance; see the sketch below.

2.) The optimizer history is erased at the beginning of mli_propose_weights, which is more general, as the history becomes outdated when others improve the model.

The learner test files were also modified accordingly.
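To illustrate point 1, a minimal PyTorch sketch (illustrative names, not the PR's actual code) showing that normalization layers carry non-trainable buffers which a trainable-parameters-only export silently drops, while state_dict() captures both:

```python
import torch.nn as nn

# A tiny model with a BatchNorm layer, which keeps running statistics.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

# Trainable parameters only: weights and biases, including BatchNorm's
# affine parameters -- but NOT its running statistics.
trainable = {name for name, p in model.named_parameters() if p.requires_grad}

# state_dict() additionally contains the non-trainable buffers that
# BatchNorm needs at inference time to reproduce trained behaviour.
full_state = set(model.state_dict())

print(sorted(full_state - trainable))
# -> ['1.num_batches_tracked', '1.running_mean', '1.running_var']
```

Sharing only the trainable parameters would omit running_mean and running_var, so a peer loading those weights could not reproduce the sender's evaluation-mode performance.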