
Inaccurate Autograd Gradients #2042

Closed
WillE9a opened this issue Oct 25, 2024 · 10 comments

WillE9a commented Oct 25, 2024

Description
The grad values from autograd are quite different from those produced by JAX.

From Autograd:
[screenshot: gradient values]

From Adjoint plugin with JAX:
[screenshot: gradient values]

Simulation Setup
I am using tidy3d version 2.7.5 and running the gradient-checking example notebooks, comparing the JAX version against the autograd version. Note: I've modified some simulation parameters so the two notebooks are identical (e.g., so the TMM and FDTD transmissions match between notebooks).

Other information
Similar issue: I've converted the JAX-based topology optimization notebook here (https://www.flexcompute.com/tidy3d/examples/notebooks/AdjointPlugin12LightExtractor/) to an autograd-compatible version and have seen some strange behavior in the optimization. Shown below is the first iteration, which gives very similar simulation results but a large difference in the gradients. I am using tidy3d's value_and_grad(obj_func, has_aux=True) so I can access the objective function components separately (a sketch of that pattern is included after the screenshots).

Using autograd:
[screenshots: first-iteration results and gradients]

Using adjoint plugin with JAX:
[screenshots: first-iteration results and gradients]

The optimizers and input parameters are identical in these two optimizations. Any ideas what might be responsible for these differences? As long as the gradient directions are the same I suppose the learning_rate could be changed to adjust the step size taken in the gradient ascent/descent. Is the autograd functionality still considered "experimental" in 2.7.5? It seems like there are several working examples for autograd-based topology optimization, so any insight is appreciated.
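Roughly, the objective pattern looks like the sketch below (simplified: the import path and the ((value, aux), grad) return structure are assumptions that mirror jax.value_and_grad, and the two FDTD runs are replaced with analytic stand-ins so it runs offline).

```python
import autograd.numpy as np

# assumed import location for tidy3d's has_aux-capable wrapper
from tidy3d.plugins.autograd import value_and_grad

def obj_func(params):
    extracted_power = np.sum(params ** 2)      # stand-in for flux through the field monitors
    dipole_power = 1.0 + np.sum(params ** 4)   # stand-in for the dipole normalization power
    fom = extracted_power / dipole_power       # normalized figure of merit
    aux = {"extracted": extracted_power, "dipole": dipole_power}
    return fom, aux

params0 = np.linspace(0.0, 1.0, 16)
# return structure assumed to mirror jax.value_and_grad(..., has_aux=True)
(fom, aux), grad = value_and_grad(obj_func, has_aux=True)(params0)
print(fom, aux["extracted"], grad.shape)
```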

@tylerflex
Collaborator

Hi @WillE9a, thanks for bringing this up.

Regarding the gradient accuracy checking notebook, this was indeed an issue that I investigated over the past two weeks. I made PR #2020 to fix this, which has been merged; it's available in the develop branch but not in the 2.7.5 release. We'll release 2.7.6 soon.

The gradient was off due to some changes in how we calculated the inside and outside permittivity of the Box, which turned out to have a small issue: the derivative of the objective w.r.t. the box thickness was not properly computed when the Box was next to another medium (like in that notebook). Effectively, it was not sampling the permittivity far enough away from the interface and was therefore catching some of the subpixel-averaged permittivity values. This only affects Box objects in non-uniform backgrounds.

You'll also notice that the gradients in the autograd version of that notebook now have about 1-2% relative error, which is a little higher than the jax version. The reason is simply a difference in how we do the numerical surface integration between the two versions. The jax version is technically a bit more accurate, but the autograd version is much faster. The error is a bit higher but still low enough for good optimization, and the extra performance makes a huge difference when optimizing with several Box structures, which can otherwise get slow and ultimately limit the usefulness of the tool.
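(For context, those relative-error numbers come from comparing the adjoint gradients against numerical derivatives. A generic central-difference check of this kind, not the notebook's exact code, looks roughly like the following.)

```python
import numpy as np

def finite_difference_grad(f, x0, step=1e-4):
    """Central-difference gradient of a scalar function f at x0 (illustrative only)."""
    x0 = np.asarray(x0, dtype=float)
    grad = np.zeros_like(x0)
    for i in range(x0.size):
        dx = np.zeros_like(x0)
        dx.flat[i] = step
        grad.flat[i] = (f(x0 + dx) - f(x0 - dx)) / (2 * step)
    return grad

def relative_error(grad_adjoint, grad_numerical):
    """Relative error between the adjoint gradient and the numerical reference."""
    return np.linalg.norm(grad_adjoint - grad_numerical) / np.linalg.norm(grad_numerical)
```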

To get the changes, feel free to either run from the develop branch or wait for the 2.7.6 release.

As for the light extractor notebook, it's interesting that the gradient norm is so high... I bet that's the source of the decreasing coupling efficiency. Does it work with a very low step size? If you wouldn't mind sending me the notebook(s), I could investigate further. We don't notice issues with our other topology optimization notebooks, including a new one added this week (https://docs.flexcompute.com/projects/tidy3d/en/latest/notebooks/Autograd18TopologyBend.html).

By the way, we had it on our to-do list to convert this notebook over. Would you be OK with me publishing the converted version (with some modifications)?

As a side note: the autograd and jax versions have diverged in their implementation of the gradient calculation in many respects. We made choices in the autograd version that improve performance, which sometimes result in a small decrease in accuracy. It's also worth noting that we've been continuously improving the autograd version for the past several months, whereas the jax version has been deprecated, so we have mostly left it alone since changing anything there is cumbersome.

Is the autograd functionality still considered "experimental" in 2.7.5? It seems like there are several working examples for autograd-based topology optimization, so any insight is appreciated.

Our plan is for autograd to no longer be experimental as of 2.7.6. In that release we will introduce the gradient fixes I mentioned, plus field projection support and greatly improved performance and usability for DataArray handling. Of course, if you still encounter issues please let me know, but I think it should be considered stable after that release.

Thanks for catching these issues. I'd recommend trying the develop branch to see if one or both of them resolve, and I will do my best to continue improving things in the meantime. I'll leave this issue open for a few days, but I would like to test the light extractor when I get a chance.

@WillE9a
Author

WillE9a commented Oct 26, 2024

Thanks for the quick reply! I'll play around with the develop branch for some optimizations to see if that resolves these issues.

Something I forgot to mention: the autograd version of the light extractor ended up being roughly 3x slower than the original JAX one, which might indicate where something could be going wrong. For reference, the original JAX optimization took ~3 hours for 100 iterations, while the autograd equivalent took ~9 hours. The changelog mentioned an improvement to the gradient calculation efficiency in v2.7.5, and the server-side simulation times seemed reasonably similar between JAX and autograd, so this is surprising.

Since things seem to be working for that waveguide bend optimization, could the issue be related to the adjoint simulation setup (e.g., something related to the custom current (adjoint) sources created from the field monitors that were included to calculate the varying dipole power)? I might just have to play with the step size and binarization scheme some more, since even in the JAX optimization the objective function starts to decrease toward the end.

And yes, the more examples we can include for inverse design the better! I have attached my version of the quantum light extractor notebook along with the autograd optimization history. There is some extra, unorganized post-optimization analysis that you can use and modify as you see fit.
opt_data.zip

I appreciate the help looking into this. Let me know if you discover anything (even if it's just a simple user error), or if you need more details (e.g., package versions, local hardware, etc.).

@tylerflex
Collaborator

Something I forgot to mention: the autograd version of the light extractor ended up being roughly 3x slower than the original JAX one, which might indicate where something could be going wrong. For reference, the original JAX optimization took ~3 hours for 100 iterations, while the autograd equivalent took ~9 hours. The changelog mentioned an improvement to the gradient calculation efficiency in v2.7.5, and the server-side simulation times seemed reasonably similar between JAX and autograd, so this is surprising.

Hm, this is something that I think could be resolved in develop now that the DataArray handling PR (#2025) has been merged. I will investigate with your notebook.

Since things seem to be working for that waveguide bend optimization, could the issue be related to the adjoint simulation setup (e.g., something related to the custom current (adjoint) sources created from the field monitors that were included to calculate the varying dipole power)? I might just have to play with the step size and binarization scheme some more, since even in the JAX optimization the objective function starts to decrease toward the end.

I'm not really sure to be honest. @e-g-melo made this example, maybe he has thoughts, but either way I will take a look.

And yes, the more examples we can include for inverse design the better! I have attached my version of the quantum light extractor notebook along with the autograd optimization history. There is some extra, unorganized post-optimization analysis that you can use and modify as you see fit.
opt_data.zip

Thanks a lot!

I appreciate the help looking into this. Let me know if you discover anything (even if it's just a simple user error), or if you need more details (e.g., package versions, local hardware, etc.).

No problem, I'll comment here when I have time to investigate more. In the meantime, in the interest of organization, I'll close this issue for now; feel free to comment or re-open if things do not resolve after running from develop.

@tylerflex
Collaborator

@WillE9a, the notebook was added in flexcompute/tidy3d-notebooks#185.

I timed it and it's 2.5 minutes per iteration using the latest frontend code. The objective function is staying steady for now.

Note: I needed to use the yaugenst-flex/xarray-versions branch, which fixes one remaining DataArray issue but is not merged into develop yet; it corresponds to PR #2041.

@tylerflex
Collaborator

@WillE9a I'm running the notebook again; it turns out that the learning rate was too high.

I'm getting good results with learning_rate = 0.02, whereas the original value was leading to some strong oscillations at the end of the optimization, which you observed.

I'm still wrapping up the final few iterations but will update the notebook on the tidy3d-notebooks branch when I'm done.
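For anyone following along, here's a bare-bones sketch of where the learning rate enters a plain gradient-ascent loop (the notebook uses its own optimizer setup; the objective below is a dummy stand-in for the FDTD-based figure of merit).

```python
import autograd.numpy as np
from autograd import value_and_grad

learning_rate = 0.02  # the value that avoided the late-stage oscillations here

def obj_func(params):
    # dummy objective standing in for the coupling-efficiency FOM
    return -np.sum((params - 0.5) ** 2)

params = np.full((32, 32), 0.3)  # density parameters in [0, 1]
for step in range(100):
    fom, grad = value_and_grad(obj_func)(params)
    params = np.clip(params + learning_rate * grad, 0.0, 1.0)  # ascent step, keep densities bounded
```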

@WillE9a
Author

WillE9a commented Oct 28, 2024

@tylerflex, perfect, just commenting with some of my own updates. Interestingly enough, 0.02 was the learning rate I settled on as well! Are the gradient norms you're seeing still on the 1e2-1e3 scale? I'm still seeing these large values after running a few iterations from the develop branch. The waveguide bend optimization as well as the gradient checking notebook (after PR #2020) show similar gradients between JAX and autograd, so I am still a little suspicious of optimizations involving electric dipoles and/or the field monitors required for FOM normalization. For the time being, though, it looks like the signs are correct, so just adjusting the step size or scaling the gradient works to get devices with good performance.
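By "scaling the gradient" I mean something along these lines (a simplified sketch, not the notebook code): normalize the gradient before the update so the step size is set entirely by the learning rate, which sidesteps the 1e2-1e3 magnitudes.

```python
import autograd.numpy as np

def scaled_step(params, grad, learning_rate=0.02):
    """One gradient-ascent step using a unit-norm gradient (illustrative only)."""
    grad_norm = np.linalg.norm(grad)
    if grad_norm > 0:
        grad = grad / grad_norm  # step size now set entirely by learning_rate
    return np.clip(params + learning_rate * grad, 0.0, 1.0)
```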

Also, the yaugenst-flex/xarray-versions branch (now deleted, with its changes merged into develop) sped up the iteration run-time by about 2x, which is consistent with your 2.5 min/iteration rate!

Thanks for the updates!

@tylerflex
Collaborator

Here are my final results. I updated the notebooks PR I referenced before, which has the pickle file if you want to save the credits.

[images: final optimization results]

@WillE9a
Author

WillE9a commented Oct 28, 2024

This looks great! I think my suspicions were ultimately unwarranted, given that the adjoint plugin is deprecated and is expected to be missing some of the changes/corrections that the autograd code received. Glad things are working well; I'm excited to resume inverse designing.

Thanks again for the help!

@e-g-melo
Collaborator

Wow... the optimization seems much better. A 0.96 coupling efficiency is exciting for this kind of problem.

@tylerflex
Collaborator

tylerflex commented Oct 29, 2024

I think ultimately what happened is that the jax gradients were about a factor of 30 too small (based on my recent numerical tests), so the old notebook was running fine with the old learning rate of 0.1. After I corrected this, the new gradients (in the 1e2 range) were correct, but the learning rate of 0.1 was simply too high for this problem.
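A quick back-of-the-envelope check of the step sizes involved (a sketch using the factor-of-30 estimate above, treating the effective step as learning_rate times the gradient scale):

```python
gradient_scale_error = 30.0   # old jax gradients estimated ~30x too small
old_lr, new_lr = 0.1, 0.02

old_effective_step = old_lr / gradient_scale_error  # ~0.003: what the old jax notebook actually took
kept_lr_step = old_lr                               # ~0.1: keeping lr=0.1 with corrected gradients -> oscillations
new_effective_step = new_lr                         # 0.02: corrected gradients with the reduced learning rate

print(old_effective_step, kept_lr_step, new_effective_step)
```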

Glad it worked out, and the optimization curve ended up even better than it was before with jax, so that's promising!
