diff --git a/extensions/tokenizers/rust/src/ndarray/reduce.rs b/extensions/tokenizers/rust/src/ndarray/reduce.rs index 1c5739a44bb..d35bb97baa8 100644 --- a/extensions/tokenizers/rust/src/ndarray/reduce.rs +++ b/extensions/tokenizers/rust/src/ndarray/reduce.rs @@ -31,10 +31,17 @@ pub extern "system" fn Java_ai_djl_engine_rust_RustLibrary_sumWithAxis<'local>( keep_dims: jboolean, ) -> jlong { let tensor = cast_handle::(handle); + let rank = tensor.shape().rank() as i32; let axes = unsafe { env.get_array_elements(&axes, ReleaseMode::NoCopyBack) }.unwrap(); let dims = axes .into_iter() - .map(|i| *i as usize) + .map(|i| { + let mut dim = *i as i32; + if dim < 0 { + dim = rank + dim; + } + return dim as usize; + }) .collect::>(); let ret = if keep_dims == JNI_TRUE {