-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
t(emporal)PARAFAC2's temporal smoothness penalty #4
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @cchatzis,
Thank you for the PR! This looks like a great contribution -- a perfect fit for MatCoupLy!
I made a couple of small suggestions, including one that avoids a nested for loop for the A-assembly, which could speed it up for large I
.
It would also be great with a test for this penalty. To add this, you need to add a test class that inherits from matcouply.testing.BaseTestFactorMatricesPenalty
(it should already be imported in tests/test_penalties.py
) and add three methods: test_penalty
, get_invariant_matrices
and get_non_invariant_matrices
:
test_penalty
should test that the penalty is computed correctly, either based on a hard-coded example where you know what the penalty should be, or by computing it completely manually in the testget_invariant_matrices(rng, shapes)
should generate a random list of matrices that won't be changed by the proximal operator (I guess it should be a stack of the same matrix repeatedlen(shapes)
times?).get_non_invariant_matrices(rng, shapes)
should generate a random list of matrices that will be changed by the proximal operator (I guess it could just be a list of random matrices).
It might be useful to set the class attribute min_rows = max_rows = 10
or something to ensure that all shapes in the shapes
list are the same. The tests for the GeneralizedL2Penalty
might also be a useful place to start: https://github.com/MarieRoald/matcouply/blob/main/tests/test_penalties.py#L247
Don't be afraid to ask if you have any other questions or if anything I said wasn't clear :)
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Thank you very much for the thorough suggestions and help @MarieRoald! I think I have addressed all of the raised issues and I added a small test Kindly let me know what you think! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks really good, just a couple of minor comments. The PyTorch are failing, but addressing the comments will hopefully make them pass :)
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Co-authored-by: Marie Roald <roald.marie@gmail.com>
Hi, are there any updates on this? Thanks in advance :) |
Looks good, but we need to know if the tests run. For some reason, the CI/CD pipeline doesn't want to trigger, so I made some changes to the I tried to push the changes directly on top of your fork, but I wasn't able to. So I think that either, you need to merge main into your feature branch or make it so I can push to your branch. To merge, I think it should work if you:
|
Implementation of temporal smoothness penalty, as presented in:
This PR essentially adds the custom penalty class, (
TemporalSmoothnessPenalty
) topenalties.py
.