Bulk derivative (big Jacobian) in MLX #1927
Comments
@sck-at-ucy take a look at this discussion: #154. Summary there:
import mlx.core as mx

def jacrev(f):
    def jacfn(x):
        # Needed for the size of the output
        y = f(x)
        def vjpfn(cotan):
            return mx.vjp(f, (x,), (cotan,))[1][0]
        return mx.vmap(vjpfn, in_axes=0)(mx.eye(len(y)))
    return jacfn

def jacfwd(f):
    def jacfn(x):
        def jvpfn(tan):
            return mx.jvp(f, (x,), (tan,))[1][0]
        return mx.vmap(jvpfn, in_axes=0)(mx.eye(len(x)))
    return jacfn

def hessian(f):
    def hessfn(x):
        def hvp(tan):
            return mx.jvp(mx.grad(f), (x,), (tan,))[1][0]
        return mx.vmap(hvp, in_axes=0)(mx.eye(len(x)))
    return hessfn

print(jacrev(mx.sin)(mx.array([1.0, 2.0, 3.0])))
print(jacfwd(mx.sin)(mx.array([1.0, 2.0, 3.0])))
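As a quick sanity check of the helpers above (an aside, assuming mx.diag, mx.cos, and mx.allclose are available as in current MLX): for an elementwise function like sin, both Jacobians should equal the diagonal matrix with cos(x) on the diagonal:

x = mx.array([1.0, 2.0, 3.0])
expected = mx.diag(mx.cos(x))                    # analytic Jacobian of elementwise sin
print(mx.allclose(jacrev(mx.sin)(x), expected))  # expected: True
print(mx.allclose(jacfwd(mx.sin)(x), expected))  # expected: True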
This seems like a bug somewhere to me, possibly in MLX. Are you able to share something that reproduces it?
Yes, I would be happy to share the code. Will do that after finishing teaching tonight and cleaning the code. If this is indeed a bug (either mine or possibly in MLX), it would make me so happy, because it would remove a main obstacle I have been facing with autograd for PDEs.
So the code is a bit long. It attempts to use the implementations you suggested above. The motivation is to compute partial differential operators over the whole domain instead of point-wise, which opens up the possibility of replacing the MLP with something a bit more sophisticated. The example I share, however, uses only the MLP to keep things simple. The PDE is for fluid flow, so it involves the Laplacian for the viscous terms, and that is where things go wrong: in computing the second derivatives. I might be doing something stupid that I cannot see, but it also looks like there might be an internal MLX issue in how shapes are unified with the internal reshape logic during mx.vjp/mx.vmap.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
#from Lamb_optimizer import Lamb
###############################################################################
# 1) PDE parameters
###############################################################################
r_in = 1.0
r_out= 2.0
L = 30.0
rho = 1.0
nu = 1.0e-2
###############################################################################
# 2) jacrev and double_jacrev with shape checks
###############################################################################
def jacrev(f):
    """
    Reverse-mode Jacobian for f: R^n -> R^m => shape (m, n).
    """
    def jacfn(x):
        print(f"[jacfn] jacrev => x.shape={x.shape}")
        y = f(x)  # expect shape (m,)
        # shape check or debug print
        print(f"[jacfn] jacrev => y.shape={y.shape}")
        m = y.shape[0]
        # we do vmap(...) over an (m, m) identity
        cotangent_eye = mx.eye(m)  # shape (m, m)
        def vjpfn(cotan):
            # shape(cotan) = (m,)
            # shape result => (n,)
            out = mx.vjp(f, (x,), (cotan,))[1][0]
            print(f"[jacfn] vjpfn => out.shape={out.shape}")
            return out
        J = mx.vmap(vjpfn, in_axes=0)(cotangent_eye)  # shape (m, n)
        # shape check
        n = x.shape[0]
        print(f"[jacfn] jacrev => J.shape={J.shape}")
        return J
    return jacfn
def double_jacrev(f):
    base_jac = jacrev(f)  # first derivative
    def h(x):
        # shape check in the intermediate function
        j = base_jac(x)  # shape (m, n) ideally
        print(f"[h] base_jac => j.shape={j.shape}")
        m, n = j.shape
        print(f'm,n: {m} {n}')
        out_flat = mx.reshape(j, (m * n,))
        print(f"[h] flattened => {out_flat.shape}")
        return out_flat
    def ddjf(x):
        # second derivative
        j2 = jacrev(h)(x)
        print(f"[ddjf] j2.shape={j2.shape}")
        # now we expect j2 => shape (m*n, n); then reshape => (m, n, n)
        m = f(x).shape[0]  # e.g. 3*N
        n = x.shape[0]     # e.g. 4*N
        out = mx.reshape(j2, (m, n, n))
        print(f"[ddjf] final Hessian => out.shape={out.shape}")
        return out
    return ddjf
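# (Hedged aside, not part of the original script:) as a shape reference, for an elementwise
# map such as f(v) = v * v on R^3, jacrev(f) would ideally return a (3, 3) diagonal Jacobian
# with 2*v on the diagonal, and double_jacrev(f) a (3, 3, 3) array equal to 2.0 at [i, i, i]
# and 0 elsewhere. Whether the second jacrev pass actually runs through vjp/vmap is what the
# rest of this script probes.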
###############################################################################
# 3) MLP model
###############################################################################
class MLPBatchPINN(nn.Module):
    def __init__(self, input_dim=4, hidden_dims=[64, 64, 64], output_dim=3, activation=nn.silu):
        super().__init__()
        self.layers = []
        in_dim = input_dim
        for h_dim in hidden_dims:
            self.layers.append(nn.Linear(in_dim, h_dim))
            in_dim = h_dim
        self.out_layer = nn.Linear(in_dim, output_dim)
        self.activation = activation

    def __call__(self, X):
        # X shape => (N, 4)
        for layer in self.layers:
            X = layer(X)
            X = self.activation(X)
        X = self.out_layer(X)
        # shape => (N, 3)
        print(f'model: {X.shape}')
        return X
###############################################################################
# 4) WeightedPINN_Batch with shape checks in field_fn, pde_fn
###############################################################################
class WeightedPINN_Batch(nn.Module):
    """
    PDE with second derivatives:
        field_fn: (4N,) -> (3N,)
        jac_fn   => shape (3N, 4N)
        hess_fn  => shape (3N, 4N, 4N)
    """
    def __init__(self, core_model, alpha_min=0.1, alpha_max=5.0):
        super().__init__()
        self.core_model = core_model
        self.logit_alpha_cont = mx.array(0.0)
        self.logit_alpha_mom = mx.array(0.0)
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max
        self.jac_fn = None
        self.hess_fn = None

    def alpha_cont(self):
        return self.alpha_min + (self.alpha_max - self.alpha_min) * mx.sigmoid(self.logit_alpha_cont)

    def alpha_mom(self):
        return self.alpha_min + (self.alpha_max - self.alpha_min) * mx.sigmoid(self.logit_alpha_mom)

    def forward_batch(self, X):
        return self.core_model(X)  # shape (N, 3)
    def field_fn(self, x_flat):
        """
        x_flat => shape (4N,)
        -> (N, 4) => forward => (N, 3) => flatten => (3N,)
        with shape checks
        """
        # 1) ensure x_flat is 1D
        assert x_flat.ndim == 1, f"[field_fn] x_flat must be 1D, got shape {x_flat.shape}"
        total_size = x_flat.shape[0]
        # must be a multiple of 4
        assert total_size % 4 == 0, f"[field_fn] x_flat size {total_size} not a multiple of 4"
        N = total_size // 4
        # reshape => (N, 4)
        X = mx.reshape(x_flat, (N, 4))
        # forward => (N, 3)
        out = self.core_model(X)
        print(f"[field_fn] out.shape={out.shape}")
        # assert out.shape == (N, 3), f"[field_fn] expected (N,3)=({N},3), got {out.shape}"
        # flatten => shape (3N,)
        out_flat = mx.reshape(out, (3 * N,))
        print(f"[field_fn] out_flat.shape={out_flat.shape}")
        # assert out_flat.shape == (3*N,), f"[field_fn] expected (3N,) => {(3*N,)}, got {out_flat.shape}"
        return out_flat
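    # (Hedged note, my addition:) assuming the input columns are [r, z, ., .] and the output
    # columns are [u_r, u_z, p], this flattening lays x_flat out as
    # [r_0, z_0, 0, 0, r_1, z_1, 0, 0, ...] and out_flat as [u_r0, u_z0, p_0, u_r1, ...],
    # which is why pde_fn below uses row indices with stride 3 and column indices with stride 4.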
    def build_derivatives(self):
        """Construct first & second derivative closures with shape checks."""
        self.jac_fn = jacrev(self.field_fn)
        self.hess_fn = double_jacrev(self.field_fn)
    def pde_fn(self, x_flat):
        """
        Returns shape (3N,) PDE residuals,
        with shape checks for each step.
        """
        total_size = x_flat.shape[0]
        assert total_size % 4 == 0, f"[pde_fn] x_flat size {total_size} not a multiple of 4"
        N = total_size // 4
        # 1) field => shape (3N,)
        out_flat = self.field_fn(x_flat)
        print(f'out_flat {out_flat.shape}')
        assert out_flat.shape == (3 * N,), f"[pde_fn] out_flat must be (3N,) => {(3*N,)}, got {out_flat.shape}"
        # 2) jac => shape (3N, 4N)
        J = self.jac_fn(x_flat)
        print(f'J {J.shape}')
        expected_jac_shape = (3 * N, 4 * N)
        assert J.shape == expected_jac_shape, f"[pde_fn] J => expect {expected_jac_shape}, got {J.shape}"
        # 3) hess => shape (3N, 4N, 4N)
        H = self.hess_fn(x_flat)
        print(f"[pde_fn] H.shape={H.shape}")
        expected_hess_shape = (3 * N, 4 * N, 4 * N)
        assert H.shape == expected_hess_shape, f"[pde_fn] H => expect {expected_hess_shape}, got {H.shape}"
        # PDE logic => continuity + momentum
        # i_r_rows => (0, 3, 6, ...)
        i_r_rows = mx.arange(0, 3 * N, 3)  # shape (N,)
        i_z_rows = i_r_rows + 1
        i_p_rows = i_r_rows + 2
        r_cols = mx.arange(0, 4 * N, 4)    # shape (N,)
        z_cols = r_cols + 1
        # gather field values
        u_r_val = mx.take(out_flat, i_r_rows, axis=0)
        u_z_val = mx.take(out_flat, i_z_rows, axis=0)
        r_vals = mx.take(x_flat, r_cols, axis=0)
        # gather first derivatives => shape (N,)
        # partial_J_ur => shape (N, 4N)
        partial_J_ur = mx.take(J, i_r_rows, axis=0)
        # du_r/dr => shape (N,)
        du_r_dr = mx.take_along_axis(partial_J_ur, mx.expand_dims(r_cols, 1), axis=1)
        du_r_dr = mx.reshape(du_r_dr, (N,))
        # du_r/dz
        du_r_dz = mx.take_along_axis(partial_J_ur, mx.expand_dims(z_cols, 1), axis=1)
        du_r_dz = mx.reshape(du_r_dz, (N,))
        partial_J_uz = mx.take(J, i_z_rows, axis=0)
        du_z_dr = mx.take_along_axis(partial_J_uz, mx.expand_dims(r_cols, 1), axis=1)
        du_z_dr = mx.reshape(du_z_dr, (N,))
        du_z_dz = mx.take_along_axis(partial_J_uz, mx.expand_dims(z_cols, 1), axis=1)
        du_z_dz = mx.reshape(du_z_dz, (N,))
        partial_J_p = mx.take(J, i_p_rows, axis=0)
        dp_dr = mx.take_along_axis(partial_J_p, mx.expand_dims(r_cols, 1), axis=1)
        dp_dr = mx.reshape(dp_dr, (N,))
        dp_dz = mx.take_along_axis(partial_J_p, mx.expand_dims(z_cols, 1), axis=1)
        dp_dz = mx.reshape(dp_dz, (N,))
        # gather second derivatives => shape (N,)
        partial_H_ur = mx.take(H, i_r_rows, axis=0)  # (N, 4N, 4N)
        d2ur_dr2 = gather_2D(partial_H_ur, r_cols, r_cols)
        d2ur_dz2 = gather_2D(partial_H_ur, z_cols, z_cols)
        partial_H_uz = mx.take(H, i_z_rows, axis=0)
        d2uz_dr2 = gather_2D(partial_H_uz, r_cols, r_cols)
        d2uz_dz2 = gather_2D(partial_H_uz, z_cols, z_cols)
        # Continuity => (1/r)(u_r + r du_r/dr) + du_z/dz
        cont = (1.0 / r_vals) * (u_r_val + r_vals * du_r_dr) + du_z_dz
        # radial => adv + press + nu*(lapl(u_r) - u_r/r^2)
        lapl_ur = d2ur_dr2 + (1.0 / r_vals) * du_r_dr + d2ur_dz2
        minus_ur_r2 = -u_r_val / (r_vals * r_vals)
        adv_r = u_r_val * du_r_dr + u_z_val * du_r_dz
        press_r = -(1.0 / rho) * dp_dr
        visc_r = nu * (lapl_ur + minus_ur_r2)
        r_mom = adv_r + press_r + visc_r
        # axial => adv + press + nu*lapl(u_z)
        lapl_uz = d2uz_dr2 + (1.0 / r_vals) * du_z_dr + d2uz_dz2
        adv_z = u_r_val * du_z_dr + u_z_val * du_z_dz
        press_z = -(1.0 / rho) * dp_dz
        visc_z = nu * lapl_uz
        z_mom = adv_z + press_z + visc_z
        pde_array = mx.stack([cont, r_mom, z_mom], axis=1)
        pde_flat = mx.reshape(pde_array, (3 * N,))
        print(f'pde_flat {pde_flat.shape}')
        return pde_flat
    def pde_loss(self, X):
        N = X.shape[0]
        x_flat = mx.reshape(X, (4 * N,))
        PDEvals = self.pde_fn(x_flat)  # (3N,)
        PDEvals_resh = mx.reshape(PDEvals, (N, 3))
        cont_part = PDEvals_resh[:, 0]
        r_part = PDEvals_resh[:, 1]
        z_part = PDEvals_resh[:, 2]
        cont_loss = mx.mean(cont_part**2)
        mom_loss = mx.mean(r_part**2 + z_part**2)
        return self.alpha_cont() * cont_loss + self.alpha_mom() * mom_loss

    def boundary_loss(self, X_bc):
        print("[boundary_loss] X_bc.shape=", X_bc.shape)
        out_bc = self.core_model(X_bc)
        print("[boundary_loss] out_bc.shape=", out_bc.shape)
        bc_vals = out_bc[:, 0:2]  # (u_r, u_z)
        return mx.mean(bc_vals**2)
###############################################################################
# 5) gather_2D => shape(N,) picking row_idx[i],col_idx[i] from (N,A,B).
###############################################################################
def gather_2D(tensor_3d, row_idx, col_idx):
    N, A, B = tensor_3d.shape
    tens_flat = mx.reshape(tensor_3d, (N, A * B))
    linear_idx = row_idx * B + col_idx
    lin_idx_2d = mx.expand_dims(linear_idx, 1)  # shape (N, 1)
    out_2d = mx.take_along_axis(tens_flat, lin_idx_2d, axis=1)
    out_1d = mx.reshape(out_2d, (N,))
    return out_1d
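# (Hedged usage sketch, my addition, not in the original script:) gather_2D picks one scalar per
# leading index, i.e. out_1d[i] == tensor_3d[i, row_idx[i], col_idx[i]]. For example:
#   T = mx.reshape(mx.arange(12), (2, 2, 3))
#   gather_2D(T, mx.array([0, 1]), mx.array([2, 0]))  # -> array([2, 9]), i.e. T[0, 0, 2] and T[1, 1, 0]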
###############################################################################
# 6) The Train Function w/ shape checks
###############################################################################
def train_example(num_epochs=200):
    # mx.disable_compile()
    # MLP
    net_core = MLPBatchPINN(input_dim=4, hidden_dims=[64, 64, 64], output_dim=3)
    pinn = WeightedPINN_Batch(net_core)
    # build first & second derivatives
    pinn.build_derivatives()
    # domain
    N = 4  # let's keep it small to avoid a big Hessian
    r_ = mx.random.uniform(r_in, r_out, (N,))
    z_ = mx.random.uniform(0.0, L, (N,))
    X_interior = mx.stack([r_, z_, mx.zeros_like(r_), mx.zeros_like(r_)], axis=1)
    print(f'[X_interior] {X_interior.shape}')
    M = 5
    r_bc = mx.full((M,), r_in)
    z_bc = mx.random.uniform(0.0, L, (M,))
    X_bc = mx.stack([r_bc, z_bc, mx.zeros_like(r_bc), mx.zeros_like(r_bc)], axis=1)
    print(f'[X_bc] {X_bc.shape}')
    # optimizer = Lamb(weight_decay=0.005, learning_rate=1e-3, eps=1e-12)
    optimizer = optim.Adam(learning_rate=1e-3)
    # state => [pinn.state, optimizer.state, X_interior, X_bc]
    state = [pinn.state, optimizer.state, X_interior, X_bc]
    mx.eval(state)

    @partial(mx.compile, inputs=state, outputs=state)
    def train_step():
        pinn_state, opt_state, X_int, X_bc_ = state
        print(f'X_int.shape: {X_int.shape}, X_bc.shape {X_bc.shape}')
        def loss_fn():
            pde_l = pinn.pde_loss(X_int)  # <<<--- commenting this out & returning only '10.0 * bc_l' runs
            bc_l = pinn.boundary_loss(X_bc_)
            return pde_l + 10.0 * bc_l
            # return 10.0 * bc_l  # you also need to comment out the pde_l line to run without issue
        loss_val, grads = nn.value_and_grad(pinn, loss_fn)()
        optimizer.update(pinn, grads)
        new_pinn_state = pinn.state
        new_opt_state = optimizer.state
        return loss_val, (new_pinn_state, new_opt_state, X_int, X_bc_)

    for epoch in range(num_epochs):
        loss_val, new_state = train_step()
        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}, loss={float(loss_val):.6f}")
        mx.eval(new_state)
    print("Done!")
    return pinn
if __name__ == "__main__":
    trained = train_example(num_epochs=200)
I'm developing a physics-informed neural network (PINN) in MLX that requires computing second derivatives with respect to the input coordinates (e.g., (r, z)) across a batch of collocation points. Right now, I'm using the "index trick": calling mx.grad(...) inside a Python loop for each point and dimension, which works with compilation globally disabled but fails under MLX's compiled (JIT) mode with "Cannot vjp primitive" or similar errors.
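Schematically, the per-point loop looks something like the sketch below (a paraphrase for illustration rather than my exact code; net, du_dxj_loop, out_idx, and j are placeholder names), and the second derivatives repeat the same pattern with another mx.grad:

import mlx.core as mx

def du_dxj_loop(net, X, out_idx, j):
    # net: (d,) -> (k,); X: (N, d). Returns [ d net(x_i)[out_idx] / d x_i[j] for each point i ].
    scalar_fn = lambda x: net(x)[out_idx]  # pick one scalar output so mx.grad applies
    grad_fn = mx.grad(scalar_fn)
    vals = []
    for i in range(X.shape[0]):
        g = grad_fn(X[i])                  # gradient w.r.t. the d inputs of point i
        vals.append(g[j])
    return vals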
My question: does MLX have (or plan to add) a "bulk derivative" or "big Jacobian" function, akin to JAX's jacfwd / jacrev, that computes the full Jacobian ∂Y/∂X in one shot? For example, if X is (N, d) and the network outputs Y with shape (N, k), we'd want a direct call that returns the Jacobian of all Y entries w.r.t. all X entries in a compiled-friendly manner (without the repeated single-point loops). That might avoid the "Cannot vjp sum" issues I see, since a single large derivative pass could preserve the compiled graph more cleanly.
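For concreteness, the kind of call I have in mind could be sketched roughly as below (an illustration only, not an existing MLX API; per_sample_jacobian is a hypothetical name that simply composes the vjp-based jacrev from the earlier comment with an outer mx.vmap over the batch). Whether this nested vmap/vjp composition works, and survives mx.compile, is exactly the question:

import mlx.core as mx

def per_sample_jacobian(net, X):
    # net: (d,) -> (k,) applied to a single point; X: (N, d) batch of collocation points.
    def jac_single(x):                                        # x: (d,)
        y = net(x)                                            # y: (k,)
        def vjpfn(cotan):
            return mx.vjp(net, (x,), (cotan,))[1][0]          # one row of the (k, d) Jacobian
        return mx.vmap(vjpfn, in_axes=0)(mx.eye(y.shape[0]))  # (k, d)
    return mx.vmap(jac_single, in_axes=0)(X)                  # intended shape: (N, k, d)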
In short:
Use case: PDE derivatives in PINNs (or other operator methods) that require partial derivatives wrt each input coordinate across many points.
Problem (suspected): repeated “for i in range(N): mx.grad(...)” triggers “Cannot vjp” or “Not implemented” errors under compiled mode.
Potential solution: a built‐in “bulk” or “batched” derivative function that merges the loops internally and preserves the AD graph for compilation.
Question: Is there a function or planned feature in MLX that does this (or can approximate it)? And if so, is it likely to fix the “Cannot vjp sum” errors we see in compiled mode?
@awni I look forward to your wisdom and advice (or roadmap info) on full PDE autodiff in MLX's compiled environment.