Add trainable theta and euler as discretizer #41
Conversation
Force-pushed from 7e1b230 to 357044f
Finished up my fixups, and I have one question below.
Summary of changes:
- Simplified the A/B calculations (removed the `_gen_constants` method, and only call `_gen_AB` in `build` rather than in `call`)
- Incorporated the identity matrix into the euler A matrix in the `trainable_theta=False` case
- Some improvements to the efficiency of the `cont2discrete_zoh` implementation
- Renamed `train_theta` to `trainable_theta` (for consistency with the Keras API, e.g. the `self.theta_inv.trainable` attribute)
- Removed A/B caching in zoh training (it had a non-trivial impact on the training speed, and I think in general training speed is more important than inference speed, since users are likely to spend the vast majority of their time in training).
- Stored A/B as constants rather than non-trainable variables. This can offer slight speed improvements. It does mean that you won't be able to load weights from earlier versions, but I think that's fine.
- Disabled FFT if `trainable_theta=True` (this won't work, since the A/B matrices aren't being used in `call`)
- Added a test to make sure that `zoh` and `euler` produce approximately the same output (otherwise we're only testing euler implementations against other euler implementations, so it's possible for the euler implementation to be completely incorrect but still pass all the tests because it is internally consistent)
- Made some simplifications to the other new tests to focus on the aspects being covered in those tests
- Set the minimum TensorFlow version back to 2.1.0 (I'm not sure what wasn't working before, but after I made the other changes everything was OK on 2.1)
- Added a `theta` property to layers for retrieving the (possibly trained) value of theta (see the sketch below)
I also made some updates to the example notebook (had to update the weights for these changes, and then just noticed some other changes that could be made while I was doing that).
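As a quick illustration of the new `theta` property, here is a minimal sketch of retrieving it after training. The model setup and data are placeholders (not from the PR or the example notebook); it assumes the `trainable_theta` argument added in this PR.

```python
import numpy as np
import tensorflow as tf
import keras_lmu

# LMUCell with a trainable theta, wrapped in a standard Keras RNN layer
cell = keras_lmu.LMUCell(
    memory_d=1,
    order=8,
    theta=16,
    hidden_cell=tf.keras.layers.SimpleRNNCell(10),
    trainable_theta=True,
)
model = tf.keras.Sequential(
    [tf.keras.layers.RNN(cell, input_shape=(None, 1)), tf.keras.layers.Dense(1)]
)
model.compile(loss="mse", optimizer="adam")
model.fit(np.random.rand(32, 50, 1), np.random.rand(32, 1), epochs=1, verbose=0)

print(cell.theta)  # the (possibly trained) value of theta
```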
        constraint=tf.keras.constraints.NonNeg(),
    )
else:
    self.theta_inv = tf.constant(1 / self._init_theta, dtype=self.dtype)
It looks here like it will also work when `init_theta` is an array instead of a scalar? That would be nice, as that's how I've done it elsewhere (at least for `euler`; the `zoh` method might be too slow this way).
That might work with broadcasting, but it's definitely not tested/supported. We intentionally just left this with scalar thetas for now.
Force-pushed from e00c270 to 70d7b69
Co-authored-by: Daniel Rasmussen <daniel.rasmussen@appliedbrainresearch.com>
Added a commit to run the docs/examples builds on remote GPU. The FFT implementation on CPU is really slow (see tensorflow/tensorflow#6541), which was causing the notebook to time out.
I did implement a faster version of the CPU FFT (see 1257a40), but it didn't help enough (it scales with the number of available cores, and we only have two when running on TravisCI). Could be useful for some other reason in the future though, so I saved it in the train_theta2 branch.
Also, in the future we'll probably switch to the convolution-based implementation (see #42), which is much faster on CPU. We could probably switch back to running the build on CPU at that point.
With the builds all passing, this LGTM!
What does this PR add?
- The option to make `theta` trainable, and `euler` as a discretization method.

How is a trainable `theta` implemented?
- We work with `theta_inv` = 1/`theta`, since that can result in better gradients if training `theta`. If training of `theta` is disabled, we still work with `theta_inv`, but it does not get updated.
- The user still specifies `theta`, which is then internally inverted to `theta_inv`.
- If `trainable_theta=True`, `theta_inv` is added as a weight of the layer. If not, then it is added as an attribute of the layer. This distinction is made so that this implementation stays compatible with models that were built with previous versions of keras-lmu (without trainable `theta`).
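As a rough, hypothetical illustration of that weight-versus-attribute distinction (a toy layer, not the PR's actual code; names such as `_init_theta` follow the diff excerpt further up this page):

```python
import tensorflow as tf

class ToyThetaLayer(tf.keras.layers.Layer):
    """Toy layer showing how theta_inv could be stored."""

    def __init__(self, theta, trainable_theta=False, **kwargs):
        super().__init__(**kwargs)
        self._init_theta = theta
        self.trainable_theta = trainable_theta

    def build(self, input_shape):
        super().build(input_shape)
        if self.trainable_theta:
            # trainable case: theta_inv is a weight, kept positive by a constraint
            self.theta_inv = self.add_weight(
                name="theta_inv",
                shape=(),
                initializer=tf.keras.initializers.Constant(1 / self._init_theta),
                constraint=tf.keras.constraints.NonNeg(),
            )
        else:
            # non-trainable case: a plain attribute, so weight files saved with
            # older keras-lmu versions still load
            self.theta_inv = tf.constant(1 / self._init_theta, dtype=self.dtype)

    @property
    def theta(self):
        # recover the (possibly trained) theta from its stored inverse
        return 1 / tf.keras.backend.get_value(self.theta_inv)
```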
How does training with Euler work?
- Since `theta` can be decoupled from the `A` and `B` matrices when using euler, `A` and `B` (weights of the layer) are set to `CONST_A` and `CONST_B` and never updated if training theta.
- If not training `theta`, `A` and `B` are set to `CONST_A*theta_inv` and `CONST_B*theta_inv` respectively (they are still not updated, naturally).
- The `call` function implements the memory update as `m = m + theta_inv*(A*m + B*u)`, thus capturing the gradient of `theta_inv` and ensuring that gradients of `theta_inv` are well composed.
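Schematically (hypothetical shapes and function name, not the PR's exact code), that update looks like:

```python
import tensorflow as tf

def euler_memory_update(m, u, A, B, theta_inv):
    """One euler step: m <- m + theta_inv * (A @ m + B @ u).

    m: (batch, order) memory state, u: (batch, 1) input,
    A: (order, order) and B: (order, 1) constant matrices.
    """
    return m + theta_inv * (
        tf.matmul(m, A, transpose_b=True) + tf.matmul(u, B, transpose_b=True)
    )
```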
How does training with Zero Order Hold (zoh) work?
- `theta` cannot be decoupled from the `A` and `B` matrices when using zoh. Thus, when training, new `A` and `B` matrices are generated during the `call` function itself. This will be slower than discretizing with euler.
- A `_cont2discrete` function for discretizing with zoh has been implemented, instead of using the previously default implementation from `scipy.signal`. This is because `scipy.signal.cont2discrete` only accepts numpy inputs and not `tf.Tensor`s, which would break the flow of gradients to `theta_inv`.
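For reference, a rough sketch of a zoh discretization written directly in TensorFlow (assumed function name and layout; not necessarily the PR's `_cont2discrete`), where `dt` would play the role of `theta_inv`:

```python
import tensorflow as tf

def cont2discrete_zoh(A, B, dt):
    """Zero-order-hold discretization via the matrix exponential.

    Standard construction: expm([[A, B], [0, 0]] * dt), with the
    discretized A and B read off the top blocks.
    """
    n = A.shape[0]  # state dimension
    m = B.shape[1]  # input dimension
    em = tf.concat(
        [
            tf.concat([A, B], axis=1),            # top block: [A B]
            tf.zeros((m, n + m), dtype=A.dtype),  # bottom block: zeros
        ],
        axis=0,
    )
    phi = tf.linalg.expm(em * dt)
    return phi[:n, :n], phi[:n, n:]  # (discretized A, discretized B)
```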
Where to start the review?
You can start from the commit `Add trainable theta and discretization options` and then go to `Update and add new tests`. These are the only 2 main commits. There is an additional commit, but that is a bones update.

Any other remarks?
- The `get_config` of each of `LMUCell`, `LMU`, and `LMUFFT` serialises `theta_init` as the `theta` parameter and not the final value. Leaving it here to confirm this makes sense.
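To make that concrete, a small (hypothetical) example of the behaviour described above; the constructor arguments are placeholders:

```python
import tensorflow as tf
import keras_lmu

cell = keras_lmu.LMUCell(
    memory_d=1,
    order=8,
    theta=16,
    hidden_cell=tf.keras.layers.SimpleRNNCell(10),
    trainable_theta=True,
)
# ... train a model containing `cell`, so cell.theta may drift away from 16 ...
config = cell.get_config()
print(config["theta"])  # still the initial value (16), not the trained value
```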