Unexpected gradient shape #38

Closed
malu opened this issue Dec 28, 2020 · 3 comments

malu (Contributor) commented Dec 28, 2020

The following code panics with `ndarray: could not broadcast array from shape: [2] to: []`.

```rust
use autograd::ndarray::{Array, IxDyn};
use autograd::optimizers::adam::{Adam, AdamState};
use autograd::tensor::{Constant, Variable};
use autograd::with;
use std::sync::{Arc, RwLock};

fn main() {
    // A scalar (shape []) variable, shared with the optimizer state.
    let v: Arc<RwLock<Array<f64, IxDyn>>> =
        autograd::ndarray_ext::into_shared(autograd::array_gen::zeros(&[]));
    let adam_state = AdamState::new(&[&v]);
    let adam = Adam::default();

    with(|graph| {
        let c = graph.constant(autograd::array_gen::ones(&[2])); // shape [2]
        let v = graph.variable(v.clone()); // shape []

        // c * v broadcasts v to shape [2]; y is reduced back to a scalar.
        let y = graph.reduce_sum_to_scalar(c * v);
        let grads = graph.grad(&[y], &[v]);

        // grads[0] unexpectedly has shape [2] instead of [], which leads to the broadcast panic.
        let updates = adam.compute_updates(&[v], &grads, &adam_state, graph);
        graph.eval(&updates, &[]);
    })
}
```

It seems that `grads[0]` has shape `[2]`, although `y` and `v` are scalars. I tried to find out why, but couldn't find an answer.
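For reference, with `c` equal to ones of shape `[2]` and a scalar `v`, the gradient one would expect is itself a scalar. The per-element partials $\partial (c_i v)/\partial v = c_i$ form a vector of shape `[2]` only because `v` was broadcast, and they have to be summed back to `v`'s shape:

$$
y = \sum_{i=1}^{2} c_i\, v, \qquad
\frac{\partial y}{\partial v} = \sum_{i=1}^{2} c_i = 2 \quad \text{for } c = (1, 1).
$$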

Is this a bug? Is there a workaround?

By the way, if `c * v` is replaced with `c + v`, it doesn't panic.
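To make the expected behaviour concrete, here is a framework-independent sketch in plain Rust of what a reverse-mode engine generally has to do for a broadcast binary op: compute the elementwise gradient at the broadcast shape `[n]`, then sum it back to the scalar input's shape `[]`. It uses no autograd calls, and `grads_wrt_scalar`, `sum_to_scalar` and `ScalarGrads` are hypothetical names; it is an illustration only, not autograd's implementation or the contents of the `bingrad-hotfix` patch.

```rust
/// Gradients of y = sum(c * v) and y = sum(c + v) with respect to a scalar v
/// that is broadcast to the shape of c ([n]).
struct ScalarGrads {
    mul_raw: Vec<f64>, // elementwise gradient of c * v w.r.t. v, shape [n]
    mul: f64,          // the same gradient reduced to v's shape []
    add_raw: Vec<f64>, // elementwise gradient of c + v w.r.t. v, shape [n]
    add: f64,          // the same gradient reduced to v's shape []
}

/// Hypothetical helper: undo the broadcast of a scalar input by summing
/// the gradient over every broadcast element.
fn sum_to_scalar(grad: &[f64]) -> f64 {
    grad.iter().sum()
}

fn grads_wrt_scalar(c: &[f64]) -> ScalarGrads {
    // The gradient flowing back from the final sum is all ones, shape [n].
    let upstream = vec![1.0_f64; c.len()];
    // d(c_i * v)/dv = c_i, so the elementwise gradient is upstream_i * c_i.
    let mul_raw: Vec<f64> = upstream.iter().zip(c).map(|(g, ci)| g * ci).collect();
    // d(c_i + v)/dv = 1, so the elementwise gradient is the upstream gradient itself.
    let add_raw: Vec<f64> = upstream;
    // Reduce both back to the scalar shape of v.
    let mul = sum_to_scalar(&mul_raw);
    let add = sum_to_scalar(&add_raw);
    ScalarGrads { mul_raw, mul, add_raw, add }
}

fn main() {
    // Same values as ones(&[2]) in the issue.
    let g = grads_wrt_scalar(&[1.0, 1.0]);
    println!("mul: unreduced len {}, reduced {}", g.mul_raw.len(), g.mul);
    println!("add: unreduced len {}, reduced {}", g.add_raw.len(), g.add);
}
```

The shape reported for `grads[0]` above matches the unreduced `mul_raw`, whereas an optimizer updating a scalar variable needs the reduced value.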

Thanks in advance for any help you can provide!

raskr added the bug label Dec 30, 2020

raskr (Owner) commented Dec 31, 2020

@malu Sorry, this is indeed a bug. I made a patch in the `bingrad-hotfix` branch. Could you test it?

```toml
[dependencies]
autograd = { git = "https://github.com/raskr/rust-autograd", branch = "bingrad-hotfix" }
```

malu (Contributor, Author) commented Dec 31, 2020

It works!

Thanks for checking this 👍

raskr (Owner) commented Dec 31, 2020

Good to hear that!
v1.1.0 will include this fix.

raskr closed this as completed Dec 31, 2020