diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 1941bbc2..46a76103 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -26,7 +26,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-2019, macOS-latest] + os: [ubuntu-latest, macOS-latest] rust: [stable, nightly] steps: - uses: actions/checkout@v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 15a30276..a5062c0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Automated conversion of list arguments: all the generated functions that take as input a slice of int or float can now be used directly with int values or fixed length arrays [682](https://github.com/LaurentMazare/tch-rs/pull/682). +- Replace the `From` traits with some `TryFrom` versions, + [683](https://github.com/LaurentMazare/tch-rs/pull/683). This is a breaking + change, note that also the old version would flatten the tensor if needed to + reduce the number of dimensions, this has to be done explicitely with the new + version. ## v0.11.0 - 2023-03-20 ### Added diff --git a/examples/char-rnn/main.rs b/examples/char-rnn/main.rs index 5862b380..d9020421 100644 --- a/examples/char-rnn/main.rs +++ b/examples/char-rnn/main.rs @@ -31,7 +31,7 @@ fn sample(data: &TextData, lstm: &LSTM, linear: &Linear, device: Device) -> Stri .squeeze_dim(0) .softmax(-1, Kind::Float) .multinomial(1, false); - last_label = i64::from(sampled_y); + last_label = i64::try_from(sampled_y).unwrap(); result.push(data.label_to_char(last_label)) } result @@ -58,7 +58,7 @@ pub fn main() -> Result<()> { .view([BATCH_SIZE * SEQ_LEN, labels]) .cross_entropy_for_logits(&ys.to_device(device).view([BATCH_SIZE * SEQ_LEN])); opt.backward_step_clip(&loss, 0.5); - sum_loss += f64::from(loss); + sum_loss += f64::try_from(loss)?; cnt_loss += 1.0; } println!("Epoch: {} loss: {:5.3}", epoch, sum_loss / cnt_loss); diff --git a/examples/custom-optimizer/main.rs b/examples/custom-optimizer/main.rs index 88cfcaba..8e40d27c 100644 --- a/examples/custom-optimizer/main.rs +++ b/examples/custom-optimizer/main.rs @@ -38,8 +38,8 @@ pub fn run() -> Result<()> { println!( "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%", epoch, - f64::from(&loss), - 100. * f64::from(&test_accuracy), + f64::try_from(&loss)?, + 100. * f64::try_from(&test_accuracy)?, ); } Ok(()) diff --git a/examples/min-gpt/main.rs b/examples/min-gpt/main.rs index cf40a2be..43546b19 100644 --- a/examples/min-gpt/main.rs +++ b/examples/min-gpt/main.rs @@ -130,7 +130,7 @@ fn sample(data: &TextData, gpt: &impl ModuleT, input: Tensor) -> String { for _index in 0..SAMPLING_LEN { let logits = input.apply_t(gpt, false).i((0, -1, ..)); let sampled_y = logits.softmax(-1, Kind::Float).multinomial(1, true); - let last_label = i64::from(&sampled_y); + let last_label = i64::try_from(&sampled_y).unwrap(); result.push(data.label_to_char(last_label)); input = Tensor::cat(&[input, sampled_y.view([1, 1])], 1).narrow(1, 1, BLOCK_SIZE); } @@ -175,7 +175,7 @@ pub fn main() -> Result<()> { .view([BATCH_SIZE * BLOCK_SIZE, labels]) .cross_entropy_for_logits(&ys.view([BATCH_SIZE * BLOCK_SIZE])); opt.backward_step_clip(&loss, 0.5); - sum_loss += f64::from(loss); + sum_loss += f64::try_from(loss)?; cnt_loss += 1.0; idx += 1; if idx % 10000 == 0 { diff --git a/examples/mnist/mnist_nn.rs b/examples/mnist/mnist_nn.rs index 25d5826a..0c799633 100644 --- a/examples/mnist/mnist_nn.rs +++ b/examples/mnist/mnist_nn.rs @@ -26,8 +26,8 @@ pub fn run() -> Result<()> { println!( "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%", epoch, - f64::from(&loss), - 100. * f64::from(&test_accuracy), + f64::try_from(&loss)?, + 100. * f64::try_from(&test_accuracy)?, ); } Ok(()) diff --git a/examples/neural-style-transfer/main.rs b/examples/neural-style-transfer/main.rs index dbbf4383..0951e512 100644 --- a/examples/neural-style-transfer/main.rs +++ b/examples/neural-style-transfer/main.rs @@ -63,7 +63,7 @@ pub fn main() -> Result<()> { let loss = style_loss * STYLE_WEIGHT + content_loss; opt.backward_step(&loss); if step_idx % 1000 == 0 { - println!("{} {}", step_idx, f64::from(loss)); + println!("{} {}", step_idx, f64::try_from(loss)?); imagenet::save_image(&input_var, format!("out{step_idx}.jpg"))?; } } diff --git a/examples/reinforcement-learning/a2c.rs b/examples/reinforcement-learning/a2c.rs index a6e69f0a..a5b8ef53 100644 --- a/examples/reinforcement-learning/a2c.rs +++ b/examples/reinforcement-learning/a2c.rs @@ -90,11 +90,11 @@ pub fn train() -> cpython::PyResult<()> { let (critic, actor) = tch::no_grad(|| model(&s_states.get(s))); let probs = actor.softmax(-1, Float); let actions = probs.multinomial(1, true).squeeze_dim(-1); - let step = env.step(Vec::::from(&actions))?; + let step = env.step(Vec::::try_from(&actions).unwrap())?; sum_rewards += &step.reward; - total_rewards += f64::from((&sum_rewards * &step.is_done).sum(Float)); - total_episodes += f64::from(step.is_done.sum(Float)); + total_rewards += f64::try_from((&sum_rewards * &step.is_done).sum(Float)).unwrap(); + total_episodes += f64::try_from(step.is_done.sum(Float)).unwrap(); let masks = Tensor::from(1f32) - step.is_done; sum_rewards *= &masks; @@ -162,7 +162,7 @@ pub fn sample>(weight_file: T) -> cpython::PyResult<() let (_critic, actor) = tch::no_grad(|| model(obs)); let probs = actor.softmax(-1, Float); let actions = probs.multinomial(1, true).squeeze_dim(-1); - let step = env.step(Vec::::from(&actions))?; + let step = env.step(Vec::::try_from(&actions).unwrap())?; let masks = Tensor::from(1f32) - step.is_done; obs = frame_stack.update(&step.obs, Some(&masks)); diff --git a/examples/reinforcement-learning/ddpg.rs b/examples/reinforcement-learning/ddpg.rs index 9d96bdea..743411bd 100644 --- a/examples/reinforcement-learning/ddpg.rs +++ b/examples/reinforcement-learning/ddpg.rs @@ -317,7 +317,7 @@ pub fn run() -> cpython::PyResult<()> { let mut total_reward = 0.0; for _ in 0..EPISODE_LENGTH { - let mut actions = 2.0 * f64::from(agent.actions(&obs)); + let mut actions = 2.0 * f64::try_from(agent.actions(&obs)).unwrap(); actions = actions.clamp(-2.0, 2.0); let action_vec = vec![actions]; diff --git a/examples/reinforcement-learning/policy_gradient.rs b/examples/reinforcement-learning/policy_gradient.rs index a8b7177a..6207eedb 100644 --- a/examples/reinforcement-learning/policy_gradient.rs +++ b/examples/reinforcement-learning/policy_gradient.rs @@ -48,7 +48,7 @@ pub fn run() -> cpython::PyResult<()> { let action = tch::no_grad(|| { obs.unsqueeze(0).apply(&model).softmax(1, Float).multinomial(1, true) }); - let action = i64::from(action); + let action = i64::try_from(action).unwrap(); let step = env.step(action)?; steps.push(step.copy_with_obs(&obs)); obs = if step.is_done { env.reset()? } else { step.obs }; diff --git a/examples/reinforcement-learning/ppo.rs b/examples/reinforcement-learning/ppo.rs index 07b599c3..d1b462ed 100644 --- a/examples/reinforcement-learning/ppo.rs +++ b/examples/reinforcement-learning/ppo.rs @@ -94,11 +94,12 @@ pub fn train() -> cpython::PyResult<()> { let (critic, actor) = tch::no_grad(|| model(&s_states.get(s))); let probs = actor.softmax(-1, Kind::Float); let actions = probs.multinomial(1, true).squeeze_dim(-1); - let step = env.step(Vec::::from(&actions))?; + let step = env.step(Vec::::try_from(&actions).unwrap())?; sum_rewards += &step.reward; - total_rewards += f64::from((&sum_rewards * &step.is_done).sum(Kind::Float)); - total_episodes += f64::from(step.is_done.sum(Kind::Float)); + total_rewards += + f64::try_from((&sum_rewards * &step.is_done).sum(Kind::Float)).unwrap(); + total_episodes += f64::try_from(step.is_done.sum(Kind::Float)).unwrap(); let masks = Tensor::from(1f32) - step.is_done; sum_rewards *= &masks; @@ -171,7 +172,7 @@ pub fn sample>(weight_file: T) -> cpython::PyResult<() let (_critic, actor) = tch::no_grad(|| model(obs)); let probs = actor.softmax(-1, Kind::Float); let actions = probs.multinomial(1, true).squeeze_dim(-1); - let step = env.step(Vec::::from(&actions))?; + let step = env.step(Vec::::try_from(&actions).unwrap())?; let masks = Tensor::from(1f32) - step.is_done; obs = frame_stack.update(&step.obs, Some(&masks)); diff --git a/examples/stable-diffusion/main.rs b/examples/stable-diffusion/main.rs index a2756725..8ca1aa4f 100644 --- a/examples/stable-diffusion/main.rs +++ b/examples/stable-diffusion/main.rs @@ -2418,7 +2418,7 @@ impl DDIMScheduler { ), }; let alphas: Tensor = 1.0 - betas; - let alphas_cumprod = Vec::::from(alphas.cumprod(0, Kind::Double)); + let alphas_cumprod = Vec::::try_from(alphas.cumprod(0, Kind::Double)).unwrap(); Self { alphas_cumprod, timesteps, step_ratio, config } } diff --git a/examples/transfer-learning/main.rs b/examples/transfer-learning/main.rs index adea2f4d..bb773e50 100644 --- a/examples/transfer-learning/main.rs +++ b/examples/transfer-learning/main.rs @@ -36,7 +36,7 @@ pub fn main() -> Result<()> { sgd.backward_step(&loss); let test_accuracy = test_images.apply(&linear).accuracy_for_logits(&dataset.test_labels); - println!("{} {:.2}%", epoch_idx, 100. * f64::from(test_accuracy)); + println!("{} {:.2}%", epoch_idx, 100. * f64::try_from(test_accuracy)?); } Ok(()) } diff --git a/examples/translation/main.rs b/examples/translation/main.rs index 9e10fc30..4bbca9e6 100644 --- a/examples/translation/main.rs +++ b/examples/translation/main.rs @@ -132,7 +132,7 @@ impl Model { let target_tensor = Tensor::of_slice(&[s as i64]).to_device(self.device); loss = loss + output.nll_loss(&target_tensor); let (_, output) = output.topk(1, -1, true, true); - if self.decoder_eos == i64::from(&output) as usize { + if self.decoder_eos == i64::try_from(&output).unwrap() as usize { break; } prev = if use_teacher_forcing { target_tensor } else { output }; @@ -155,7 +155,7 @@ impl Model { for _i in 0..MAX_LENGTH { let (output, state_) = self.decoder.forward(&prev, &state, &enc_outputs, true); let (_, output) = output.topk(1, -1, true, true); - let output_ = i64::from(&output) as usize; + let output_ = i64::try_from(&output).unwrap() as usize; output_seq.push(output_); if self.decoder_eos == output_ { break; @@ -208,7 +208,7 @@ pub fn main() -> Result<()> { let (input_, target) = pairs.choose(&mut rng).unwrap(); let loss = model.train_loss(input_, target, &mut rng); opt.backward_step(&loss); - loss_stats.update(f64::from(loss) / target.len() as f64); + loss_stats.update(f64::try_from(loss)? / target.len() as f64); if idx % 1000 == 0 { println!("{} {}", idx, loss_stats.avg_and_reset()); for _pred_index in 1..5 { diff --git a/examples/vae/main.rs b/examples/vae/main.rs index 8bd803b3..96b75627 100644 --- a/examples/vae/main.rs +++ b/examples/vae/main.rs @@ -83,7 +83,7 @@ pub fn main() -> Result<()> { let (recon_batch, mu, logvar) = vae.forward(&bimages); let loss = loss(&recon_batch, &bimages, &mu, &logvar); opt.backward_step(&loss); - train_loss += f64::from(&loss); + train_loss += f64::try_from(&loss)?; samples += bimages.size()[0] as f64; } println!("Epoch: {}, loss: {}", epoch, train_loss / samples); diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 034be712..690bc2b2 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -65,7 +65,7 @@ pub fn report(pred: &Tensor, img: &Tensor, w: i64, h: i64) -> Result { let mut bboxes: Vec> = (0..nclasses).map(|_| vec![]).collect(); // Extract the bounding boxes for which confidence is above the threshold. for index in 0..npreds { - let pred = Vec::::from(pred.get(index)); + let pred = Vec::::try_from(pred.get(index))?; let confidence = pred[4]; if confidence > CONFIDENCE_THRESHOLD { let mut class_index = 0; diff --git a/src/data.rs b/src/data.rs index 1d6e43e1..e46a7564 100644 --- a/src/data.rs +++ b/src/data.rs @@ -192,7 +192,7 @@ impl Iterator for TextDataIter { None } else { self.batch_index += 1; - let indexes = Vec::::from(&self.indexes.i(start..start + size)); + let indexes = Vec::::try_from(&self.indexes.i(start..start + size)).unwrap(); let batch: Vec<_> = indexes.iter().map(|&i| self.data.i(i..i + self.seq_len)).collect(); let batch: Vec<_> = batch.iter().collect(); Some(Tensor::stack(&batch, 0)) diff --git a/src/error.rs b/src/error.rs index 5deaf57d..1f7cc3e3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -56,6 +56,9 @@ pub enum TchError { #[error(transparent)] Zip(#[from] ZipError), + #[error(transparent)] + NdArray(#[from] ndarray::ShapeError), + /// Errors returned by the safetensors library. #[error("safetensors error {path}: {err}")] SafeTensorError { path: String, err: safetensors::SafeTensorError }, diff --git a/src/nn/linear.rs b/src/nn/linear.rs index cd6fd244..e5d385ba 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -65,8 +65,8 @@ fn matches_pytorch() { let linear = Linear { ws, bs }; let output = linear.forward(&input); - let delta_output: f32 = (&output - &expected_output).norm().into(); - let delta_original: f32 = (&original_output - &expected_output).norm().into(); + let delta_output: f32 = (&output - &expected_output).norm().try_into().unwrap(); + let delta_original: f32 = (&original_output - &expected_output).norm().try_into().unwrap(); // The `matmul()` implementation is close, but `linear()` is at least as close or closer. assert!(output.allclose(&expected_output, 1e-5, 1e-8, false)); diff --git a/src/nn/module.rs b/src/nn/module.rs index 964b3f09..6301cbfa 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -26,7 +26,7 @@ pub trait ModuleT: std::fmt::Debug + Send { for (xs, ys) in Iter2::new(xs, ys, batch_size).return_smaller_last_batch() { let acc = self.forward_t(&xs.to_device(d), false).accuracy_for_logits(&ys.to_device(d)); let size = xs.size()[0] as f64; - sum_accuracy += f64::from(&acc) * size; + sum_accuracy += f64::try_from(&acc).unwrap() * size; sample_count += size; } sum_accuracy / sample_count diff --git a/src/nn/optimizer.rs b/src/nn/optimizer.rs index 17dd3409..c0d89c2c 100644 --- a/src/nn/optimizer.rs +++ b/src/nn/optimizer.rs @@ -231,7 +231,7 @@ impl Optimizer { for var in v.trainable_variables.iter() { norms.push(var.tensor.grad().norm()); } - let total_norm = f64::from(Tensor::stack(&norms, 0).norm()); + let total_norm = f64::try_from(Tensor::stack(&norms, 0).norm()).unwrap(); let clip_coef = max / (total_norm + 1e-6); if clip_coef < 1.0 { for var in v.trainable_variables.iter() { diff --git a/src/tensor/convert.rs b/src/tensor/convert.rs index 4010d89b..a61519fd 100644 --- a/src/tensor/convert.rs +++ b/src/tensor/convert.rs @@ -4,62 +4,102 @@ use crate::{kind::Element, TchError}; use half::{bf16, f16}; use std::convert::{TryFrom, TryInto}; -impl From<&Tensor> for Vec { - fn from(tensor: &Tensor) -> Vec { - let numel = tensor.numel(); - let mut vec = vec![T::ZERO; numel]; - tensor.to_device(crate::Device::Cpu).to_kind(T::KIND).copy_data(&mut vec, numel); - vec +impl TryFrom<&Tensor> for Vec { + type Error = TchError; + fn try_from(tensor: &Tensor) -> Result { + let s1 = tensor.size1()? as usize; + let num_elem = s1; + let mut vec = vec![T::ZERO; num_elem]; + tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, num_elem)?; + Ok(vec) } } -impl From<&Tensor> for Vec> { - fn from(tensor: &Tensor) -> Vec> { - let first_dim = tensor.size()[0]; - (0..first_dim).map(|i| Vec::::from(tensor.get(i))).collect() +impl TryFrom<&Tensor> for Vec> { + type Error = TchError; + fn try_from(tensor: &Tensor) -> Result { + let (s1, s2) = tensor.size2()?; + let s1 = s1 as usize; + let s2 = s2 as usize; + let num_elem = s1 * s2; + // TODO: Try to remove this intermediary copy. + let mut all_elems = vec![T::ZERO; num_elem]; + tensor.f_to_kind(T::KIND)?.f_copy_data(&mut all_elems, num_elem)?; + let out = (0..s1).map(|i1| (0..s2).map(|i2| all_elems[i1 * s2 + i2]).collect()).collect(); + Ok(out) } } -impl From<&Tensor> for Vec>> { - fn from(tensor: &Tensor) -> Vec>> { - let first_dim = tensor.size()[0]; - (0..first_dim).map(|i| Vec::>::from(tensor.get(i))).collect() +impl TryFrom<&Tensor> for Vec>> { + type Error = TchError; + fn try_from(tensor: &Tensor) -> Result { + let (s1, s2, s3) = tensor.size3()?; + let s1 = s1 as usize; + let s2 = s2 as usize; + let s3 = s3 as usize; + let num_elem = s1 * s2 * s3; + // TODO: Try to remove this intermediary copy. + let mut all_elems = vec![T::ZERO; num_elem]; + tensor.f_to_kind(T::KIND)?.f_copy_data(&mut all_elems, num_elem)?; + let out = (0..s1) + .map(|i1| { + (0..s2) + .map(|i2| (0..s3).map(|i3| all_elems[i1 * s2 * s3 + i2 * s3 + i3]).collect()) + .collect() + }) + .collect(); + Ok(out) } } -impl From for Vec { - fn from(tensor: Tensor) -> Vec { - Vec::::from(&tensor) +impl TryFrom for Vec { + type Error = TchError; + fn try_from(tensor: Tensor) -> Result { + Vec::::try_from(&tensor) } } -impl From for Vec> { - fn from(tensor: Tensor) -> Vec> { - Vec::>::from(&tensor) +impl TryFrom for Vec> { + type Error = TchError; + fn try_from(tensor: Tensor) -> Result { + Vec::>::try_from(&tensor) } } -impl From for Vec>> { - fn from(tensor: Tensor) -> Vec>> { - Vec::>>::from(&tensor) +impl TryFrom for Vec>> { + type Error = TchError; + fn try_from(tensor: Tensor) -> Result { + Vec::>>::try_from(&tensor) } } macro_rules! from_tensor { ($typ:ident) => { - impl From<&Tensor> for $typ { - fn from(tensor: &Tensor) -> $typ { + impl TryFrom<&Tensor> for $typ { + type Error = TchError; + + fn try_from(tensor: &Tensor) -> Result { let numel = tensor.numel(); if numel != 1 { - panic!("expected exactly one element, got {}", numel) + return Err(TchError::Convert(format!( + "expected exactly one element, got {}", + numel + ))); } - Vec::from(tensor)[0] + let mut vec = [$typ::ZERO; 1]; + tensor + .f_to_device(crate::Device::Cpu)? + .f_to_kind($typ::KIND)? + .f_copy_data(&mut vec, numel)?; + Ok(vec[0]) } } - impl From for $typ { - fn from(tensor: Tensor) -> $typ { - $typ::from(&tensor) + impl TryFrom for $typ { + type Error = TchError; + + fn try_from(tensor: Tensor) -> Result { + $typ::try_from(&tensor) } } }; @@ -68,20 +108,20 @@ macro_rules! from_tensor { from_tensor!(f64); from_tensor!(f32); from_tensor!(f16); -from_tensor!(bf16); from_tensor!(i64); from_tensor!(i32); from_tensor!(i8); from_tensor!(u8); from_tensor!(bool); +from_tensor!(bf16); -impl TryInto> for &Tensor { - type Error = ndarray::ShapeError; +impl TryInto> for &Tensor { + type Error = TchError; fn try_into(self) -> Result, Self::Error> { - let v: Vec = self.into(); + let v: Vec = self.try_into()?; let shape: Vec = self.size().iter().map(|s| *s as usize).collect(); - ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shape), v) + Ok(ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shape), v)?) } } diff --git a/src/tensor/display.rs b/src/tensor/display.rs index f43961e4..9022fa6d 100644 --- a/src/tensor/display.rs +++ b/src/tensor/display.rs @@ -61,10 +61,14 @@ impl std::fmt::Debug for Tensor { | Kind::ComplexDouble => (false, false), }; match (self.size().as_slice(), is_int, is_float) { - ([], true, false) => write!(f, "[{}]", i64::from(self)), - ([s], true, false) if *s < 10 => write!(f, "{:?}", Vec::::from(self)), - ([], false, true) => write!(f, "[{}]", f64::from(self)), - ([s], false, true) if *s < 10 => write!(f, "{:?}", Vec::::from(self)), + ([], true, false) => write!(f, "[{}]", i64::try_from(self).unwrap()), + ([s], true, false) if *s < 10 => { + write!(f, "{:?}", Vec::::try_from(self).unwrap()) + } + ([], false, true) => write!(f, "[{}]", f64::try_from(self).unwrap()), + ([s], false, true) if *s < 10 => { + write!(f, "{:?}", Vec::::try_from(self).unwrap()) + } _ => write!(f, "Tensor[{:?}, {:?}]", self.size(), kind), } } @@ -263,7 +267,7 @@ impl FloatFormatter { t.masked_select(&t.isfinite().logical_and(&t.ne(0.))) }; - let values = Vec::::from(&nonzero_finite_vals); + let values = Vec::::try_from(&nonzero_finite_vals).unwrap(); if nonzero_finite_vals.numel() > 0 { let nonzero_finite_abs = nonzero_finite_vals.abs(); let nonzero_finite_min = nonzero_finite_abs.min().double_value(&[]); @@ -311,7 +315,7 @@ impl TensorFormatter for FloatFormatter { } fn values(tensor: &Tensor) -> Vec { - Vec::::from(tensor) + Vec::::try_from(tensor.reshape(-1)).unwrap() } } @@ -329,7 +333,7 @@ impl TensorFormatter for IntFormatter { } fn values(tensor: &Tensor) -> Vec { - Vec::::from(tensor) + Vec::::try_from(tensor.reshape(-1)).unwrap() } } @@ -348,7 +352,7 @@ impl TensorFormatter for BoolFormatter { } fn values(tensor: &Tensor) -> Vec { - Vec::::from(tensor) + Vec::::try_from(tensor.reshape(-1)).unwrap() } } diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 3d413a00..ccf87004 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -297,7 +297,10 @@ impl PartialEq for Tensor { Err(_) => false, Ok(v) => match v.f_all() { Err(_) => false, - Ok(v) => i64::from(v) > 0, + Ok(v) => match i64::try_from(v) { + Err(_) => false, + Ok(v) => v > 0, + }, }, } } diff --git a/src/vision/imagenet.rs b/src/vision/imagenet.rs index 629e3ad8..13f70f72 100644 --- a/src/vision/imagenet.rs +++ b/src/vision/imagenet.rs @@ -1160,8 +1160,8 @@ pub fn top(tensor: &Tensor, k: i64) -> Vec<(f64, String)> { _ => panic!("unexpected tensor shape {tensor:?}"), }; let (values, indexes) = tensor.topk(k, 0, true, true); - let values = Vec::::from(values); - let indexes = Vec::::from(indexes); + let values = Vec::::try_from(values).unwrap(); + let indexes = Vec::::try_from(indexes).unwrap(); values .iter() .zip(indexes.iter()) diff --git a/tests/data_tests.rs b/tests/data_tests.rs index 40c6204f..4d78ab5d 100644 --- a/tests/data_tests.rs +++ b/tests/data_tests.rs @@ -1,6 +1,9 @@ use std::io::Write; use tch::{data, IndexOp, Tensor}; +mod test_utils; +use test_utils::*; + #[test] fn iter2() { let bsize: usize = 4; @@ -8,8 +11,8 @@ fn iter2() { let xs = Tensor::of_slice(&vs); let ys = Tensor::of_slice(&vs.iter().map(|x| x * 2).collect::>()); for (batch_xs, batch_ys) in data::Iter2::new(&xs, &ys, bsize as i64) { - let xs = Vec::::from(&batch_xs); - let ys = Vec::::from(&batch_ys); + let xs = vec_i64_from(&batch_xs); + let ys = vec_i64_from(&batch_ys); assert_eq!(xs.len(), bsize); assert_eq!(ys.len(), bsize); for i in 0..bsize { @@ -21,8 +24,8 @@ fn iter2() { } let mut all_in_order = true; for (batch_xs, batch_ys) in data::Iter2::new(&xs, &ys, bsize as i64).shuffle() { - let xs = Vec::::from(&batch_xs); - let ys = Vec::::from(&batch_ys); + let xs = vec_i64_from(&batch_xs); + let ys = vec_i64_from(&batch_ys); assert_eq!(xs.len(), bsize); assert_eq!(ys.len(), bsize); for i in 0..bsize { @@ -49,8 +52,8 @@ fn text() { for xs in text_data.iter_shuffle(2, 5) { let first_column_plus_one = (xs.i((.., ..1)) + 1).fmod(10); let second_column = xs.i((.., 1..=1)); - let err = i64::from( - (first_column_plus_one - second_column).pow_tensor_scalar(2).sum(tch::Kind::Float), + let err: i64 = from( + &(first_column_plus_one - second_column).pow_tensor_scalar(2).sum(tch::Kind::Float), ); assert_eq!(err, 0) } diff --git a/tests/jit_tests.rs b/tests/jit_tests.rs index 73443db1..0a8206f0 100644 --- a/tests/jit_tests.rs +++ b/tests/jit_tests.rs @@ -1,6 +1,9 @@ use std::convert::{TryFrom, TryInto}; use tch::{IValue, Kind, Tensor}; +mod test_utils; +use test_utils::*; + #[test] fn jit() { let x = Tensor::of_slice(&[3, 1, 4, 1, 5]).to_kind(Kind::Float); @@ -9,7 +12,7 @@ fn jit() { let mod_ = tch::CModule::load("tests/foo.pt").unwrap(); let result = mod_.forward_ts(&[&x, &y]).unwrap(); let expected = x * 2.0 + y + 42.0; - assert_eq!(Vec::::from(&result), Vec::::from(&expected)); + assert_eq!(vec_f64_from(&result), vec_f64_from(&expected)); } #[test] @@ -20,16 +23,16 @@ fn jit_data() { let mod_ = tch::CModule::load_data(&mut file).unwrap(); let result = mod_.forward_ts(&[&x, &y]).unwrap(); let expected = x * 2.0 + y + 42.0; - assert_eq!(Vec::::from(&result), Vec::::from(&expected)); + assert_eq!(vec_f64_from(&result), vec_f64_from(&expected)); } #[test] fn jit1() { let mod_ = tch::CModule::load("tests/foo1.pt").unwrap(); let result = mod_.forward_ts(&[Tensor::from(42), Tensor::from(1337)]).unwrap(); - assert_eq!(i64::from(&result), 1421); + assert_eq!(from::(&result), 1421); let result = mod_.method_ts("forward", &[Tensor::from(42), Tensor::from(1337)]).unwrap(); - assert_eq!(i64::from(&result), 1421); + assert_eq!(from::(&result), 1421); } #[test] @@ -43,8 +46,8 @@ fn jit2() { assert_eq!(result, IValue::from((expected1, expected2))); // Destructure the tuple, using an option. let (v1, v2) = <(Tensor, Option)>::try_from(result).unwrap(); - assert_eq!(i64::from(v1), 1421); - assert_eq!(i64::from(v2.unwrap()), -1295); + assert_eq!(from::(&v1), 1421); + assert_eq!(from::(&v2.unwrap()), -1295); let result = mod_ .method_is("forward", &[IValue::from(Tensor::from(42)), IValue::from(Tensor::from(1337))]) .unwrap(); @@ -52,8 +55,8 @@ fn jit2() { let expected2 = Tensor::from(-1295); assert_eq!(result, IValue::from((expected1, expected2))); let (v1, v2) = <(Tensor, Tensor)>::try_from(result).unwrap(); - assert_eq!(i64::from(v1), 1421); - assert_eq!(i64::from(v2), -1295); + assert_eq!(from::(&v1), 1421); + assert_eq!(from::(&v2), -1295); } #[test] @@ -61,7 +64,7 @@ fn jit3() { let mod_ = tch::CModule::load("tests/foo3.pt").unwrap(); let xs = Tensor::of_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]); let result = mod_.forward_ts(&[xs]).unwrap(); - assert_eq!(f64::from(&result), 120.0); + assert_eq!(from::(&result), 120.0); } #[test] @@ -147,7 +150,10 @@ fn create_traced() { let xs = Tensor::of_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]); let ys = Tensor::of_slice(&[41.0, 1335.0, std::f64::consts::PI - 3., 4.0, 5.0]); let result = modl.method_ts("MyFn", &[xs, ys]).unwrap(); - assert_eq!(Vec::::from(&result), [42.0, 1337.0, std::f64::consts::PI, 8.0, 10.0]) + assert_eq!( + Vec::::try_from(&result).unwrap(), + [42.0, 1337.0, std::f64::consts::PI, 8.0, 10.0] + ) } // https://github.com/LaurentMazare/tch-rs/issues/475 @@ -166,7 +172,7 @@ fn jit_double_free() { IValue::Tensor(tensor) => tensor, result => panic!("expected a tensor got {result:?}"), }; - assert_eq!(Vec::::from(&result), [5.0, 7.0, 9.0]) + assert_eq!(Vec::::try_from(&result).unwrap(), [5.0, 7.0, 9.0]) } // https://github.com/LaurentMazare/tch-rs/issues/597 @@ -182,6 +188,6 @@ fn specialized_dict() { ]); let result = mod_.method_is("generate", &[input]).unwrap(); let result: (Tensor, Tensor) = result.try_into().unwrap(); - assert_eq!(Vec::::from(&result.0), [1.0, 2.0, 3.0]); - assert_eq!(Vec::::from(&result.1), [1.0, 7.0]) + assert_eq!(Vec::::try_from(&result.0).unwrap(), [1.0, 2.0, 3.0]); + assert_eq!(Vec::::try_from(&result.1).unwrap(), [1.0, 7.0]) } diff --git a/tests/nn_tests.rs b/tests/nn_tests.rs index 2a22fc9d..7aa747f9 100644 --- a/tests/nn_tests.rs +++ b/tests/nn_tests.rs @@ -2,6 +2,9 @@ use tch::nn::{group_norm, layer_norm}; use tch::nn::{Module, OptimizerConfig}; use tch::{kind, nn, Device, Kind, Reduction, Tensor}; +mod test_utils; +use test_utils::*; + #[test] fn optimizer_test() { tch::manual_seed(42); @@ -20,7 +23,7 @@ fn optimizer_test() { let mut linear = nn::linear(vs.root(), 1, 1, cfg); let loss = xs.apply(&linear).mse_loss(&ys, Reduction::Mean); - let initial_loss = f64::from(&loss); + let initial_loss: f64 = from(&loss); assert!(initial_loss > 1.0, "{}", "initial loss {initial_loss}"); opt.set_lr(1e-2); @@ -30,7 +33,7 @@ fn optimizer_test() { opt.backward_step(&loss); } let loss = xs.apply(&linear).mse_loss(&ys, Reduction::Mean); - let final_loss = f64::from(loss); + let final_loss: f64 = from(&loss); assert!(final_loss < 0.25, "{}", "final loss {final_loss}"); // Reset the weights to their initial values. @@ -38,7 +41,7 @@ fn optimizer_test() { linear.ws.init(nn::Init::Const(0.)); linear.bs.as_mut().unwrap().init(nn::Init::Const(0.)); }); - let initial_loss2 = f64::from(xs.apply(&linear).mse_loss(&ys, Reduction::Mean)); + let initial_loss2: f64 = from(&xs.apply(&linear).mse_loss(&ys, Reduction::Mean)); assert_eq!(initial_loss, initial_loss2); // Set the learning-rate to be very small and check that the loss does not change @@ -49,7 +52,7 @@ fn optimizer_test() { opt.backward_step(&loss); } let loss = xs.apply(&linear).mse_loss(&ys, Reduction::Mean); - let final_loss = f64::from(loss); + let final_loss: f64 = from(&loss); assert!((final_loss - initial_loss) < 1e-5, "{}", "final loss {final_loss}") } @@ -86,7 +89,7 @@ fn gradient_descent_test_clip_norm() { } fn round4(t: Tensor) -> Vec { - let v = Vec::::from(t); + let v = vec_f64_from(&t); v.iter().map(|x| (10000. * x).round() / 10000.).collect() } @@ -108,9 +111,9 @@ fn gradient_clip_test() { let g1 = var1.grad(); let g2 = var2.grad(); let g3 = var3.grad(); - assert_eq!(Vec::::from(&g1), [2.0, 2.0]); - assert_eq!(Vec::::from(&g2), [8.0]); - assert_eq!(Vec::::from(&g3), [-8.0, -8.0]); + assert_eq!(vec_f64_from(&g1), [2.0, 2.0]); + assert_eq!(vec_f64_from(&g2), [8.0]); + assert_eq!(vec_f64_from(&g3), [-8.0, -8.0]); // Test clipping the gradient by value. let loss = all.pow_tensor_scalar(2).sum(Kind::Float); opt.zero_grad(); @@ -119,9 +122,9 @@ fn gradient_clip_test() { let g1 = var1.grad(); let g2 = var2.grad(); let g3 = var3.grad(); - assert_eq!(Vec::::from(&g1), [2.0, 2.0]); - assert_eq!(Vec::::from(&g2), [4.0]); - assert_eq!(Vec::::from(&g3), [-4.0, -4.0]); + assert_eq!(vec_f64_from(&g1), [2.0, 2.0]); + assert_eq!(vec_f64_from(&g2), [4.0]); + assert_eq!(vec_f64_from(&g3), [-4.0, -4.0]); // Test clipping the gradient norm. let loss = all.pow_tensor_scalar(2).sum(Kind::Float); opt.zero_grad(); @@ -187,7 +190,7 @@ fn layer_norm_parameters_test() { let mut ln = layer_norm(vs.root(), vec![2], Default::default()); let loss = xs.apply(&ln).mse_loss(&ys, Reduction::Mean); - let initial_loss = f64::from(&loss); + let initial_loss: f64 = from(&loss); assert!(initial_loss > 1.0, "{}", "initial loss {initial_loss}"); // Optimization loop. @@ -196,7 +199,7 @@ fn layer_norm_parameters_test() { opt.backward_step(&loss); } let loss = xs.apply(&ln).mse_loss(&ys, Reduction::Mean); - let final_loss = f64::from(loss); + let final_loss: f64 = from(&loss); assert!(final_loss < 0.25, "{}", "final loss {final_loss:?}"); // Reset the weights to their initial values. @@ -208,7 +211,7 @@ fn layer_norm_parameters_test() { bs.init(nn::Init::Const(0.)); } }); - let initial_loss2 = f64::from(xs.apply(&ln).mse_loss(&ys, Reduction::Mean)); + let initial_loss2: f64 = from(&xs.apply(&ln).mse_loss(&ys, Reduction::Mean)); assert_eq!(initial_loss, initial_loss2) } @@ -348,26 +351,26 @@ fn linear() { fn pad() { let xs = Tensor::of_slice(&[1., 2., 3.]); let padded = nn::PaddingMode::Zeros.pad(&xs, &[1, 1]); - assert_eq!(Vec::::from(&padded), [0., 1., 2., 3., 0.]); + assert_eq!(vec_f32_from(&padded), [0., 1., 2., 3., 0.]); let xs = Tensor::of_slice(&[1., 2., 3.]).view([1, 3]); let padded = nn::PaddingMode::Zeros.pad(&xs, &[1, 1]); - assert_eq!(Vec::::from(&padded), [0., 1., 2., 3., 0.]); + assert_eq!(vec_f32_from(&padded.reshape(-1)), [0., 1., 2., 3., 0.]); let xs = Tensor::of_slice(&[1., 2., 3., 4.]).view([1, 2, 2]); let padded = nn::PaddingMode::Reflect.pad(&xs, &[1, 1, 1, 1]); assert_eq!( - Vec::::from(&padded), + vec_f32_from(&padded.reshape(-1)), &[4.0, 3.0, 4.0, 3.0, 2.0, 1.0, 2.0, 1.0, 4.0, 3.0, 4.0, 3.0, 2.0, 1.0, 2.0, 1.0] ); let padded = nn::PaddingMode::Reflect.pad(&xs, &[1, 1, 1, 1]); assert_eq!( - Vec::::from(&padded), + vec_f32_from(&padded.reshape(-1)), &[4.0, 3.0, 4.0, 3.0, 2.0, 1.0, 2.0, 1.0, 4.0, 3.0, 4.0, 3.0, 2.0, 1.0, 2.0, 1.0] ); let padded = nn::PaddingMode::Reflect.pad(&xs, &[1, 1, 1, 1]); assert_eq!( - Vec::::from(&padded), + vec_f32_from(&padded.reshape(-1)), &[4.0, 3.0, 4.0, 3.0, 2.0, 1.0, 2.0, 1.0, 4.0, 3.0, 4.0, 3.0, 2.0, 1.0, 2.0, 1.0] ); } @@ -387,14 +390,14 @@ fn conv() { let xs = Tensor::of_slice(&[1f32, 2., 3., 4.]).view([1, 1, 2, 2]); // NCHW let conved = apply_conv(&xs, nn::PaddingMode::Zeros); - assert_eq!(Vec::::from(&conved), &[10.0, 10.0, 10.0, 10.0]); + assert_eq!(vec_f32_from(&conved.reshape(-1)), &[10.0, 10.0, 10.0, 10.0]); let conved = apply_conv(&xs, nn::PaddingMode::Reflect); - assert_eq!(Vec::::from(&conved), &[27.0, 24.0, 21.0, 18.0]); + assert_eq!(vec_f32_from(&conved.reshape(-1)), &[27.0, 24.0, 21.0, 18.0]); let conved = apply_conv(&xs, nn::PaddingMode::Circular); - assert_eq!(Vec::::from(&conved), &[27.0, 24.0, 21.0, 18.0]); + assert_eq!(vec_f32_from(&conved.reshape(-1)), &[27.0, 24.0, 21.0, 18.0]); let conved = apply_conv(&xs, nn::PaddingMode::Replicate); - assert_eq!(Vec::::from(&conved), &[18.0, 21.0, 24.0, 27.0]); + assert_eq!(vec_f32_from(&conved.reshape(-1)), &[18.0, 21.0, 24.0, 27.0]); } diff --git a/tests/serialization_tests.rs b/tests/serialization_tests.rs index 93d3cb31..7695e88a 100644 --- a/tests/serialization_tests.rs +++ b/tests/serialization_tests.rs @@ -1,5 +1,8 @@ use tch::{Kind, Tensor}; +mod test_utils; +use test_utils::*; + struct TmpFile(std::path::PathBuf); impl TmpFile { @@ -33,7 +36,7 @@ fn save_and_load() { let t1 = Tensor::of_slice(&vec); t1.save(&tmp_file).unwrap(); let t2 = Tensor::load(&tmp_file).unwrap(); - assert_eq!(Vec::::from(&t2), vec) + assert_eq!(vec_f64_from(&t2), vec) } #[test] @@ -43,7 +46,7 @@ fn save_to_stream_and_load() { let t1 = Tensor::of_slice(&vec); t1.save_to_stream(std::fs::File::create(&tmp_file).unwrap()).unwrap(); let t2 = Tensor::load(&tmp_file).unwrap(); - assert_eq!(Vec::::from(&t2), vec) + assert_eq!(vec_f64_from(&t2), vec) } #[test] @@ -54,7 +57,7 @@ fn save_and_load_from_stream() { t1.save(&tmp_file).unwrap(); let reader = std::io::BufReader::new(std::fs::File::open(&tmp_file).unwrap()); let t2 = Tensor::load_from_stream(reader).unwrap(); - assert_eq!(Vec::::from(&t2), vec) + assert_eq!(vec_f64_from(&t2), vec) } #[test] @@ -67,7 +70,7 @@ fn save_and_load_multi() { assert_eq!(named_tensors.len(), 2); assert_eq!(named_tensors[0].0, "pi"); assert_eq!(named_tensors[1].0, "e"); - assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57); + assert_eq!(from::(&named_tensors[1].1.sum(tch::Kind::Float)), 57); } #[test] @@ -84,7 +87,7 @@ fn save_to_stream_and_load_multi() { assert_eq!(named_tensors.len(), 2); assert_eq!(named_tensors[0].0, "pi"); assert_eq!(named_tensors[1].0, "e"); - assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57); + assert_eq!(from::(&named_tensors[1].1.sum(tch::Kind::Float)), 57); } #[test] @@ -98,7 +101,7 @@ fn save_and_load_multi_from_stream() { assert_eq!(named_tensors.len(), 2); assert_eq!(named_tensors[0].0, "pi"); assert_eq!(named_tensors[1].0, "e"); - assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57); + assert_eq!(from::(&named_tensors[1].1.sum(tch::Kind::Float)), 57); } #[test] @@ -111,7 +114,7 @@ fn save_and_load_npz() { assert_eq!(named_tensors.len(), 2); assert_eq!(named_tensors[0].0, "pi"); assert_eq!(named_tensors[1].0, "e"); - assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57); + assert_eq!(from::(&named_tensors[1].1.sum(tch::Kind::Float)), 57); } #[test] @@ -125,7 +128,7 @@ fn save_and_load_npz_half() { assert_eq!(named_tensors.len(), 2); assert_eq!(named_tensors[0].0, "pi"); assert_eq!(named_tensors[1].0, "e"); - assert_eq!(i64::from(&named_tensors[1].1.sum(tch::Kind::Float)), 57); + assert_eq!(from::(&named_tensors[1].1.sum(tch::Kind::Float)), 57); } #[test] @@ -139,7 +142,7 @@ fn save_and_load_npz_byte() { assert_eq!(named_tensors.len(), 2); assert_eq!(named_tensors[0].0, "pi"); assert_eq!(named_tensors[1].0, "e"); - assert_eq!(i8::from(&named_tensors[1].1.sum(tch::Kind::Int8)), 57); + assert_eq!(from::(&named_tensors[1].1.sum(tch::Kind::Int8)), 57); } #[test] @@ -148,12 +151,12 @@ fn save_and_load_npy() { let pi = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0, 9.0]); pi.write_npy(&tmp_file).unwrap(); let pi = Tensor::read_npy(&tmp_file).unwrap(); - assert_eq!(Vec::::from(&pi), [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]); + assert_eq!(vec_f64_from(&pi), [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]); let pi = pi.reshape([3, 1, 2]); pi.write_npy(&tmp_file).unwrap(); let pi = Tensor::read_npy(&tmp_file).unwrap(); assert_eq!(pi.size(), [3, 1, 2]); - assert_eq!(Vec::::from(pi.flatten(0, -1)), [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]); + assert_eq!(vec_f64_from(&pi.flatten(0, -1)), [3.0, 1.0, 4.0, 1.0, 5.0, 9.0]); } #[test] @@ -166,8 +169,8 @@ fn save_and_load_safetensors() { assert_eq!(named_tensors.len(), 2); for (name, tensor) in named_tensors { match name.as_str() { - "pi" => assert_eq!(i64::from(&tensor.sum(tch::Kind::Float)), 14), - "e" => assert_eq!(i64::from(&tensor.sum(tch::Kind::Float)), 57), + "pi" => assert_eq!(from::(&tensor.sum(tch::Kind::Float)), 14), + "e" => assert_eq!(from::(&tensor.sum(tch::Kind::Float)), 57), _ => panic!("unknow name tensors"), } } @@ -184,8 +187,8 @@ fn save_and_load_safetensors_half() { assert_eq!(named_tensors.len(), 2); for (name, tensor) in named_tensors { match name.as_str() { - "pi" => assert_eq!(i64::from(&tensor.sum(tch::Kind::Float)), 14), - "e" => assert_eq!(i64::from(&tensor.sum(tch::Kind::Float)), 57), + "pi" => assert_eq!(from::(&tensor.sum(tch::Kind::Float)), 14), + "e" => assert_eq!(from::(&tensor.sum(tch::Kind::Float)), 57), _ => panic!("unknow name tensors"), } } diff --git a/tests/tensor_indexing.rs b/tests/tensor_indexing.rs index acc49a6e..083648b5 100644 --- a/tests/tensor_indexing.rs +++ b/tests/tensor_indexing.rs @@ -1,6 +1,9 @@ use tch::{Device, Kind, Tensor}; use tch::{IndexOp, NewAxis}; +mod test_utils; +use test_utils::*; + #[test] fn integer_index() { let opt = (Kind::Float, Device::Cpu); @@ -8,16 +11,16 @@ fn integer_index() { let tensor = Tensor::arange_start(0, 2 * 3, opt).view([2, 3]); let result = tensor.i(1); assert_eq!(result.size(), &[3]); - assert_eq!(Vec::::from(result), &[3, 4, 5]); + assert_eq!(vec_i64_from(&result), &[3, 4, 5]); let tensor = Tensor::arange_start(0, 2 * 3, opt).view([2, 3]); let result = tensor.i((.., 2)); assert_eq!(result.size(), &[2]); - assert_eq!(Vec::::from(result), &[2, 5]); + assert_eq!(vec_i64_from(&result), &[2, 5]); let result = tensor.i((.., -2)); assert_eq!(result.size(), &[2]); - assert_eq!(Vec::::from(result), &[1, 4]); + assert_eq!(vec_i64_from(&result), &[1, 4]); } #[test] @@ -28,43 +31,43 @@ fn range_index() { let tensor = Tensor::arange_start(0, 4 * 3, opt).view([4, 3]); let result = tensor.i(1..3); assert_eq!(result.size(), &[2, 3]); - assert_eq!(Vec::::from(result), &[3, 4, 5, 6, 7, 8]); + assert_eq!(vec_i64_from(&result), &[3, 4, 5, 6, 7, 8]); // RangeFull let tensor = Tensor::arange_start(0, 2 * 3, opt).view([2, 3]); let result = tensor.i(..); assert_eq!(result.size(), &[2, 3]); - assert_eq!(Vec::::from(result), &[0, 1, 2, 3, 4, 5]); + assert_eq!(vec_i64_from(&result), &[0, 1, 2, 3, 4, 5]); // RangeFrom let tensor = Tensor::arange_start(0, 4 * 3, opt).view([4, 3]); let result = tensor.i(2..); assert_eq!(result.size(), &[2, 3]); - assert_eq!(Vec::::from(result), &[6, 7, 8, 9, 10, 11]); + assert_eq!(vec_i64_from(&result), &[6, 7, 8, 9, 10, 11]); // RangeTo let tensor = Tensor::arange_start(0, 4 * 3, opt).view([4, 3]); let result = tensor.i(..2); assert_eq!(result.size(), &[2, 3]); - assert_eq!(Vec::::from(result), &[0, 1, 2, 3, 4, 5]); + assert_eq!(vec_i64_from(&result), &[0, 1, 2, 3, 4, 5]); // RangeInclusive let tensor = Tensor::arange_start(0, 4 * 3, opt).view([4, 3]); let result = tensor.i(1..=2); assert_eq!(result.size(), &[2, 3]); - assert_eq!(Vec::::from(result), &[3, 4, 5, 6, 7, 8]); + assert_eq!(vec_i64_from(&result), &[3, 4, 5, 6, 7, 8]); // RangeTo let tensor = Tensor::arange_start(0, 4 * 3, opt).view([4, 3]); let result = tensor.i(..1); assert_eq!(result.size(), &[1, 3]); - assert_eq!(Vec::::from(result), &[0, 1, 2]); + assert_eq!(vec_i64_from(&result), &[0, 1, 2]); // RangeToInclusive let tensor = Tensor::arange_start(0, 4 * 3, opt).view([4, 3]); let result = tensor.i(..=1); assert_eq!(result.size(), &[2, 3]); - assert_eq!(Vec::::from(result), &[0, 1, 2, 3, 4, 5]); + assert_eq!(vec_i64_from(&result), &[0, 1, 2, 3, 4, 5]); } #[test] @@ -75,13 +78,13 @@ fn slice_index() { let index: &[_] = &[1, 3, 5]; let result = tensor.i(index); assert_eq!(result.size(), &[3, 2]); - assert_eq!(Vec::::from(result), &[2, 3, 6, 7, 10, 11]); + assert_eq!(vec_i64_from(&result), &[2, 3, 6, 7, 10, 11]); let tensor = Tensor::arange_start(0, 3 * 4, opt).view([3, 4]); let index: &[_] = &[3, 0]; let result = tensor.i((.., index)); assert_eq!(result.size(), &[3, 2]); - assert_eq!(Vec::::from(result), &[3, 0, 7, 4, 11, 8]); + assert_eq!(vec_i64_from(&result), &[3, 0, 7, 4, 11, 8]); } #[test] @@ -91,17 +94,17 @@ fn new_index() { let tensor = Tensor::arange_start(0, 2 * 3, opt).view([2, 3]); let result = tensor.i((NewAxis,)); assert_eq!(result.size(), &[1, 2, 3]); - assert_eq!(Vec::::from(result), &[0, 1, 2, 3, 4, 5]); + assert_eq!(vec_i64_from(&result), &[0, 1, 2, 3, 4, 5]); let tensor = Tensor::arange_start(0, 2 * 3, opt).view([2, 3]); let result = tensor.i((.., NewAxis)); assert_eq!(result.size(), &[2, 1, 3]); - assert_eq!(Vec::::from(result), &[0, 1, 2, 3, 4, 5]); + assert_eq!(vec_i64_from(&result), &[0, 1, 2, 3, 4, 5]); let tensor = Tensor::arange_start(0, 2 * 3, opt).view([2, 3]); let result = tensor.i((.., .., NewAxis)); assert_eq!(result.size(), &[2, 3, 1]); - assert_eq!(Vec::::from(result), &[0, 1, 2, 3, 4, 5]); + assert_eq!(vec_i64_from(&result), &[0, 1, 2, 3, 4, 5]); } #[cfg(target_os = "linux")] @@ -113,7 +116,7 @@ fn complex_index() { let result = tensor.i((1, 1..2, vec![2, 3, 0].as_slice(), NewAxis, 3..)); assert_eq!(result.size(), &[1, 3, 1, 4]); assert_eq!( - Vec::::from(result), + vec_i64_from(&result), &[157, 158, 159, 160, 164, 165, 166, 167, 143, 144, 145, 146] ); } @@ -122,9 +125,9 @@ fn complex_index() { fn index_3d() { let values: Vec = (0..24).collect(); let tensor = tch::Tensor::of_slice(&values).view((2, 3, 4)); - assert_eq!(Vec::::from(tensor.i((0, 0, 0))), &[0]); - assert_eq!(Vec::::from(tensor.i((1, 0, 0))), &[12]); - assert_eq!(Vec::::from(tensor.i((0..2, 0, 0))), &[0, 12]); + assert_eq!(vec_i64_from(&tensor.i((0, 0, 0))), &[0]); + assert_eq!(vec_i64_from(&tensor.i((1, 0, 0))), &[12]); + assert_eq!(vec_i64_from(&tensor.i((0..2, 0, 0))), &[0, 12]); } #[test] @@ -135,7 +138,7 @@ fn tensor_index() { let selected = t.index(&[Some(rows_select), Some(column_select)]); assert_eq!(selected.size(), &[3]); - assert_eq!(Vec::::from(selected), &[1, 5, 2]); + assert_eq!(vec_i64_from(&selected), &[1, 5, 2]); } #[test] @@ -152,7 +155,7 @@ fn tensor_index2() { // 115 = 0 * 200 + 1 * 100 + 1 * 10 + 5 // 126 = 0 * 200 + 1 * 100 + 2 * 10 + 6 // 137 = 0 * 200 + 1 * 100 + 3 * 10 + 7 - assert_eq!(Vec::::from(selected), &[115, 126, 137]); + assert_eq!(vec_i64_from(&selected), &[115, 126, 137]); } #[test] @@ -166,7 +169,7 @@ fn tensor_multi_index() { let selected = t.index(&[Some(select_final)]); // index only rows assert_eq!(selected.size(), &[2, 3, 3]); - assert_eq!(Vec::::from(selected), &[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 0, 1, 2]); + assert_eq!(vec_i64_from(&selected), &[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 0, 1, 2]); // after flattening } @@ -178,30 +181,30 @@ fn tensor_put() { let values = Tensor::of_slice(&[10i64, 12, 24]); let updated = t.index_put(&[Some(rows_select), Some(column_select)], &values, false); - assert_eq!(Vec::::from(updated), &[0i64, 10, 24, 3, 4, 12]); // after flattening + assert_eq!(vec_i64_from(&updated), &[0i64, 10, 24, 3, 4, 12]); // after flattening } #[test] fn indexing_doc() { let tensor = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).view((2, 3)); let t = tensor.i(1); - assert_eq!(Vec::::from(t), [4, 5, 6]); + assert_eq!(vec_i64_from(&t), [4, 5, 6]); let t = tensor.i((.., -2)); - assert_eq!(Vec::::from(t), [2, 5]); + assert_eq!(vec_i64_from(&t), [2, 5]); let tensor = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).view((2, 3)); let t = tensor.i((.., 1..)); assert_eq!(t.size(), [2, 2]); - assert_eq!(Vec::::from(t.contiguous().view(-1)), [2, 3, 5, 6]); + assert_eq!(vec_i64_from(&t.contiguous().view(-1)), [2, 3, 5, 6]); let t = tensor.i((..1, ..)); assert_eq!(t.size(), [1, 3]); - assert_eq!(Vec::::from(t.contiguous().view(-1)), [1, 2, 3]); + assert_eq!(vec_i64_from(&t.contiguous().view(-1)), [1, 2, 3]); let t = tensor.i((.., 1..2)); assert_eq!(t.size(), [2, 1]); - assert_eq!(Vec::::from(t.contiguous().view(-1)), [2, 5]); + assert_eq!(vec_i64_from(&t.contiguous().view(-1)), [2, 5]); let t = tensor.i((.., 1..=2)); assert_eq!(t.size(), [2, 2]); - assert_eq!(Vec::::from(t.contiguous().view(-1)), [2, 3, 5, 6]); + assert_eq!(vec_i64_from(&t.contiguous().view(-1)), [2, 3, 5, 6]); let tensor = Tensor::of_slice(&[1, 2, 3, 4, 5, 6]).view((2, 3)); let t = tensor.i((NewAxis,)); diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index bdb5819f..f5bc6233 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -4,6 +4,9 @@ use std::convert::{TryFrom, TryInto}; use std::f32; use tch::{Device, Tensor}; +mod test_utils; +use test_utils::*; + #[test] #[cfg(feature = "cuda-tests")] fn amp_non_finite_check_and_unscale() { @@ -11,18 +14,18 @@ fn amp_non_finite_check_and_unscale() { let mut found_inf = Tensor::of_slice(&[0f32]).to_device(Device::Cuda(0)); let inv_scale = Tensor::of_slice(&[0.1f32]).to_device(Device::Cuda(0)); u.internal_amp_non_finite_check_and_unscale(&mut found_inf, &inv_scale); - assert_eq!(Vec::::from(&u), &[1f32, 2f32]); - assert_eq!(Vec::::from(&found_inf), [0f32]); + assert_eq!(vec_f32_from(&u), &[1f32, 2f32]); + assert_eq!(vec_f32_from(&found_inf), [0f32]); let mut v = Tensor::of_slice(&[1f32, f32::INFINITY]).to_device(Device::Cuda(0)); v.internal_amp_non_finite_check_and_unscale(&mut found_inf, &inv_scale); - assert_eq!(Vec::::from(&v), &[0.1, f32::INFINITY]); - assert_eq!(Vec::::from(&found_inf), [1f32]); + assert_eq!(vec_f32_from(&v), &[0.1, f32::INFINITY]); + assert_eq!(vec_f32_from(&found_inf), [1f32]); u.internal_amp_non_finite_check_and_unscale(&mut found_inf, &inv_scale); - assert_eq!(Vec::::from(&u), &[0.1, 0.2]); + assert_eq!(vec_f32_from(&u), &[0.1, 0.2]); // found_inf is sticky - assert_eq!(Vec::::from(&found_inf), [1f32]); + assert_eq!(vec_f32_from(&found_inf), [1f32]); } #[test] @@ -31,22 +34,22 @@ fn assign_ops() { t += 1; t *= 2; t -= 1; - assert_eq!(Vec::::from(&t), [7, 3, 9, 3, 11]); + assert_eq!(vec_i64_from(&t), [7, 3, 9, 3, 11]); } #[test] fn constant_ops() { let mut t = Tensor::of_slice(&[7i64, 3, 9, 3, 11]); t = -t; - assert_eq!(Vec::::from(&t), [-7, -3, -9, -3, -11]); + assert_eq!(vec_i64_from(&t), [-7, -3, -9, -3, -11]); t = 1 - t; - assert_eq!(Vec::::from(&t), [8, 4, 10, 4, 12]); + assert_eq!(vec_i64_from(&t), [8, 4, 10, 4, 12]); t = 2 * t; - assert_eq!(Vec::::from(&t), [16, 8, 20, 8, 24]); + assert_eq!(vec_i64_from(&t), [16, 8, 20, 8, 24]); let mut t = Tensor::of_slice(&[0.2f64, 0.1]); t = 2 / t; - assert_eq!(Vec::::from(&t), [10.0, 20.0]); + assert_eq!(vec_f64_from(&t), [10.0, 20.0]); } #[test] @@ -63,18 +66,18 @@ fn iter() { fn array_conversion() { let vec: Vec<_> = (0..6).map(|x| (x * x) as f64).collect(); let t = Tensor::of_slice(&vec); - assert_eq!(Vec::::from(&t), [0.0, 1.0, 4.0, 9.0, 16.0, 25.0]); + assert_eq!(vec_f64_from(&t), [0.0, 1.0, 4.0, 9.0, 16.0, 25.0]); let t = t.view([3, 2]); - assert_eq!(Vec::>::from(&t), [[0.0, 1.0], [4.0, 9.0], [16.0, 25.0]]); + assert_eq!(from::>>(&t), [[0.0, 1.0], [4.0, 9.0], [16.0, 25.0]]); let t = t.view([2, 3]); - assert_eq!(Vec::>::from(&t), [[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]) + assert_eq!(from::>>(&t), [[0.0, 1.0, 4.0], [9.0, 16.0, 25.0]]) } #[test] fn binary_ops() { let t = Tensor::of_slice(&[3.0, 1.0, 4.0, 1.0, 5.0]); let t = (&t * &t) + &t - 1.5; - assert_eq!(Vec::::from(&t), [10.5, 0.5, 18.5, 0.5, 28.5]); + assert_eq!(vec_f64_from(&t), [10.5, 0.5, 18.5, 0.5, 28.5]); } #[test] @@ -84,7 +87,7 @@ fn grad() { x.zero_grad(); y.backward(); let dy_over_dx = x.grad(); - assert_eq!(Vec::::from(&dy_over_dx), [5.0]); + assert_eq!(vec_f64_from(&dy_over_dx), [5.0]); } #[test] @@ -98,7 +101,7 @@ fn grad_grad() { let dy_over_dx = &dy_over_dx[0]; dy_over_dx.backward(); let dy_over_dx2 = x.grad(); - assert_eq!(f64::from(&dy_over_dx2), 254.0); + assert_eq!(from::(&dy_over_dx2), 254.0); } #[test] @@ -114,17 +117,17 @@ fn cat_and_stack() { let t = Tensor::of_slice(&[13.0, 37.0]); let t = Tensor::cat(&[&t, &t, &t], 0); assert_eq!(t.size(), [6]); - assert_eq!(Vec::::from(&t), [13.0, 37.0, 13.0, 37.0, 13.0, 37.0]); + assert_eq!(vec_f64_from(&t), [13.0, 37.0, 13.0, 37.0, 13.0, 37.0]); let t = Tensor::of_slice(&[13.0, 37.0]); let t = Tensor::stack(&[&t, &t, &t], 0); assert_eq!(t.size(), [3, 2]); - assert_eq!(Vec::::from(&t), [13.0, 37.0, 13.0, 37.0, 13.0, 37.0]); + assert_eq!(vec_f64_from(&t), [13.0, 37.0, 13.0, 37.0, 13.0, 37.0]); let t = Tensor::of_slice(&[13.0, 37.0]); let t = Tensor::stack(&[&t, &t, &t], 1); assert_eq!(t.size(), [2, 3]); - assert_eq!(Vec::::from(&t), [13.0, 13.0, 13.0, 37.0, 37.0, 37.0]); + assert_eq!(vec_f64_from(&t), [13.0, 13.0, 13.0, 37.0, 37.0, 37.0]); } #[test] @@ -132,7 +135,7 @@ fn onehot() { let xs = Tensor::of_slice(&[0, 1, 2, 3]); let onehot = xs.onehot(4); assert_eq!( - Vec::::from(&onehot), + vec_f64_from(&onehot), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0] ); assert_eq!(onehot.size(), vec![4, 4]) @@ -152,9 +155,9 @@ fn chunk() { let xs = Tensor::of_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); let tensors = xs.chunk(3, 0); assert_eq!(tensors.len(), 3); - assert_eq!(Vec::::from(&tensors[0]), vec![0, 1, 2, 3]); - assert_eq!(Vec::::from(&tensors[1]), vec![4, 5, 6, 7]); - assert_eq!(Vec::::from(&tensors[2]), vec![8, 9]); + assert_eq!(vec_i64_from(&tensors[0]), vec![0, 1, 2, 3]); + assert_eq!(vec_i64_from(&tensors[1]), vec![4, 5, 6, 7]); + assert_eq!(vec_i64_from(&tensors[2]), vec![8, 9]); } #[test] @@ -163,8 +166,8 @@ fn broadcast() { let ys = Tensor::from(42); let tensors = Tensor::broadcast_tensors(&[xs, ys]); assert_eq!(tensors.len(), 2); - assert_eq!(Vec::::from(&tensors[0]), vec![4, 5, 3]); - assert_eq!(Vec::::from(&tensors[1]), vec![42, 42, 42]); + assert_eq!(vec_i64_from(&tensors[0]), vec![4, 5, 3]); + assert_eq!(vec_i64_from(&tensors[1]), vec![42, 42, 42]); } #[test] @@ -198,89 +201,83 @@ fn values_at_index() { #[test] fn into_ndarray_f64() { - let tensor = Tensor::of_slice(&[1., 2., 3., 4.]).reshape([2, 2]); + let tensor = Tensor::of_slice(&[1., 2., 3., 4.]); let nd: ndarray::ArrayD = (&tensor).try_into().unwrap(); - assert_eq!(Vec::::from(tensor).as_slice(), nd.as_slice().unwrap()); + assert_eq!(vec_f64_from(&tensor).as_slice(), nd.as_slice().unwrap()); } #[test] fn into_ndarray_i64() { - let tensor = Tensor::of_slice(&[1, 2, 3, 4]).reshape([2, 2]); + let tensor = Tensor::of_slice(&[1, 2, 3, 4]); let nd: ndarray::ArrayD = (&tensor).try_into().unwrap(); - assert_eq!(Vec::::from(tensor).as_slice(), nd.as_slice().unwrap()); + assert_eq!(vec_i64_from(&tensor).as_slice(), nd.as_slice().unwrap()); } #[test] fn from_ndarray_f64() { let nd = ndarray::arr2(&[[1f64, 2.], [3., 4.]]); let tensor = Tensor::try_from(nd.clone()).unwrap(); - assert_eq!(Vec::::from(tensor).as_slice(), nd.as_slice().unwrap()); + assert_eq!(vec_f64_from(&tensor).as_slice(), nd.as_slice().unwrap()); } #[test] fn from_ndarray_i64() { let nd = ndarray::arr2(&[[1i64, 2], [3, 4]]); let tensor = Tensor::try_from(nd.clone()).unwrap(); - assert_eq!(Vec::::from(tensor).as_slice(), nd.as_slice().unwrap()); + assert_eq!(vec_i64_from(&tensor).as_slice(), nd.as_slice().unwrap()); } #[test] fn from_ndarray_bool() { let nd = ndarray::arr2(&[[true, false], [true, true]]); let tensor = Tensor::try_from(nd.clone()).unwrap(); - assert_eq!(Vec::::from(tensor).as_slice(), nd.as_slice().unwrap()); + assert_eq!(vec_bool_from(&tensor).as_slice(), nd.as_slice().unwrap()); } #[test] fn from_primitive() -> Result<()> { - assert_eq!(Vec::::from(Tensor::try_from(1_i32)?), vec![1]); - assert_eq!(Vec::::from(Tensor::try_from(1_i64)?), vec![1]); - assert_eq!(Vec::::from(Tensor::try_from(f16::from_f64(1.0))?), vec![f16::from_f64(1.0)]); - assert_eq!(Vec::::from(Tensor::try_from(1_f32)?), vec![1.0]); - assert_eq!(Vec::::from(Tensor::try_from(1_f64)?), vec![1.0]); - assert_eq!(Vec::::from(Tensor::try_from(true)?), vec![true]); + assert_eq!(vec_i32_from(&Tensor::try_from(1_i32)?), vec![1]); + assert_eq!(vec_i64_from(&Tensor::try_from(1_i64)?), vec![1]); + assert_eq!(vec_f16_from(&Tensor::try_from(f16::from_f64(1.0))?), vec![f16::from_f64(1.0)]); + assert_eq!(vec_f32_from(&Tensor::try_from(1_f32)?), vec![1.0]); + assert_eq!(vec_f64_from(&Tensor::try_from(1_f64)?), vec![1.0]); + assert_eq!(vec_bool_from(&Tensor::try_from(true)?), vec![true]); Ok(()) } #[test] fn from_vec() -> Result<()> { - assert_eq!(Vec::::from(Tensor::try_from(vec![-1_i32, 0, 1])?), vec![-1, 0, 1]); - assert_eq!(Vec::::from(Tensor::try_from(vec![-1_i64, 0, 1])?), vec![-1, 0, 1]); + assert_eq!(vec_i32_from(&Tensor::try_from(vec![-1_i32, 0, 1])?), vec![-1, 0, 1]); + assert_eq!(vec_i64_from(&Tensor::try_from(vec![-1_i64, 0, 1])?), vec![-1, 0, 1]); assert_eq!( - Vec::::from(Tensor::try_from(vec![ + from::>(&Tensor::try_from(vec![ f16::from_f64(-1.0), f16::from_f64(0.0), f16::from_f64(1.0) ])?), vec![f16::from_f64(-1.0), f16::from_f64(0.0), f16::from_f64(1.0)] ); - assert_eq!(Vec::::from(Tensor::try_from(vec![-1_f32, 0.0, 1.0])?), vec![-1.0, 0.0, 1.0]); - assert_eq!(Vec::::from(Tensor::try_from(vec![-1_f64, 0.0, 1.0])?), vec![-1.0, 0.0, 1.0]); - assert_eq!(Vec::::from(Tensor::try_from(vec![true, false])?), vec![true, false]); + assert_eq!(vec_f32_from(&Tensor::try_from(vec![-1_f32, 0.0, 1.0])?), vec![-1.0, 0.0, 1.0]); + assert_eq!(vec_f64_from(&Tensor::try_from(vec![-1_f64, 0.0, 1.0])?), vec![-1.0, 0.0, 1.0]); + assert_eq!(vec_bool_from(&Tensor::try_from(vec![true, false])?), vec![true, false]); Ok(()) } #[test] fn from_slice() -> Result<()> { - assert_eq!(Vec::::from(Tensor::try_from(&[-1_i32, 0, 1] as &[_])?), vec![-1, 0, 1]); - assert_eq!(Vec::::from(Tensor::try_from(&[-1_i64, 0, 1] as &[_])?), vec![-1, 0, 1]); + assert_eq!(vec_i32_from(&Tensor::try_from(&[-1_i32, 0, 1] as &[_])?), vec![-1, 0, 1]); + assert_eq!(vec_i64_from(&Tensor::try_from(&[-1_i64, 0, 1] as &[_])?), vec![-1, 0, 1]); assert_eq!( - Vec::::from(Tensor::try_from(&[ + vec_f16_from(&Tensor::try_from(&[ f16::from_f64(-1.0), f16::from_f64(0.0), f16::from_f64(1.0) ] as &[_])?), vec![f16::from_f64(-1.0), f16::from_f64(0.0), f16::from_f64(1.0)] ); - assert_eq!( - Vec::::from(Tensor::try_from(&[-1_f32, 0.0, 1.0] as &[_])?), - vec![-1.0, 0.0, 1.0] - ); - assert_eq!( - Vec::::from(Tensor::try_from(&[-1_f64, 0.0, 1.0] as &[_])?), - vec![-1.0, 0.0, 1.0] - ); - assert_eq!(Vec::::from(Tensor::try_from(&[true, false] as &[_])?), vec![true, false]); + assert_eq!(vec_f32_from(&Tensor::try_from(&[-1_f32, 0.0, 1.0] as &[_])?), vec![-1.0, 0.0, 1.0]); + assert_eq!(vec_f64_from(&Tensor::try_from(&[-1_f64, 0.0, 1.0] as &[_])?), vec![-1.0, 0.0, 1.0]); + assert_eq!(vec_bool_from(&Tensor::try_from(&[true, false] as &[_])?), vec![true, false]); Ok(()) } @@ -303,24 +300,24 @@ fn where_() { let t1 = Tensor::of_slice(&[3, 1, 4, 1, 5, 9]); let t2 = Tensor::of_slice(&[2, 7, 1, 8, 2, 8]); let t = t1.where_self(&t1.lt(4), &t2); - assert_eq!(Vec::::from(&t), [3, 1, 1, 1, 2, 8]); + assert_eq!(vec_i64_from(&t), [3, 1, 1, 1, 2, 8]); } #[test] fn bool_tensor() { let t1 = Tensor::of_slice(&[true, true, false]); - assert_eq!(Vec::::from(&t1), [1, 1, 0]); - assert_eq!(Vec::::from(&t1), [true, true, false]); + assert_eq!(vec_i64_from(&t1), [1, 1, 0]); + assert_eq!(vec_bool_from(&t1), [true, true, false]); let t1 = Tensor::of_slice(&[0, 1, 0]).to_kind(tch::Kind::Bool); let t2 = Tensor::of_slice(&[1, 1, 1]).to_kind(tch::Kind::Bool); let t1_any = t1.any(); let t2_any = t2.any(); let t1_all = t1.all(); let t2_all = t2.all(); - assert!(bool::from(&t1_any)); - assert!(!bool::from(&t1_all)); - assert!(bool::from(&t2_any)); - assert!(bool::from(&t2_all)); + assert!(from::(&t1_any)); + assert!(!from::(&t1_all)); + assert!(from::(&t2_any)); + assert!(from::(&t2_all)); } #[test] @@ -352,20 +349,20 @@ fn einsum() { // Element-wise squaring of a vector. let t = Tensor::of_slice(&[1.0, 2.0, 3.0]); let t = Tensor::einsum("i, i -> i", &[&t, &t], None::); - assert_eq!(Vec::::from(&t), [1.0, 4.0, 9.0]); + assert_eq!(vec_f64_from(&t), [1.0, 4.0, 9.0]); // Matrix transpose let t = Tensor::of_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape([2, 3]); let t = Tensor::einsum("ij -> ji", &[t], None::); - assert_eq!(Vec::::from(&t), [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); + assert_eq!(vec_f64_from(&t), [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]); // Sum all elements let t = Tensor::einsum("ij -> ", &[t], None::); - assert_eq!(Vec::::from(&t), [21.0]); + assert_eq!(vec_f64_from(&t), [21.0]); } #[test] fn vec2() { let tensor = Tensor::of_slice(&[1., 2., 3., 4., 5., 6.]).reshape([2, 3]); - assert_eq!(Vec::>::from(tensor), [[1., 2., 3.], [4., 5., 6.]]) + assert_eq!(Vec::>::try_from(tensor).unwrap(), [[1., 2., 3.], [4., 5., 6.]]) } #[test] @@ -374,22 +371,22 @@ fn upsample1d() { let up1 = tensor.upsample_linear1d([2], false, 1.); assert_eq!( // Exclude the last element because of some numerical instability. - Vec::::from(up1)[0..11], + vec_f64_from(&up1)[0..11], [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0] ); let up1 = tensor.upsample_linear1d([2], false, None); - assert_eq!(Vec::::from(up1), [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0]); + assert_eq!(vec_f64_from(&up1), [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0]); } #[test] fn argmax() { let tensor = Tensor::of_slice(&[7., 2., 3., 4., 5., 6.]).reshape([2, 3]); let argmax = tensor.argmax(None, false); - assert_eq!(Vec::::from(argmax), [0],); + assert_eq!(vec_i64_from(&argmax), [0],); let argmax = tensor.argmax(0, false); - assert_eq!(Vec::::from(argmax), [0, 1, 1],); + assert_eq!(vec_i64_from(&argmax), [0, 1, 1],); let argmax = tensor.argmax(-1, false); - assert_eq!(Vec::::from(argmax), [0, 2],); + assert_eq!(vec_i64_from(&argmax), [0, 2],); } #[test] @@ -420,7 +417,7 @@ fn nested_tensor() { let vec: Vec> = vec![vec![1, 2], vec![1, 2], vec![4, 5]]; let t = Tensor::of_slice2(&vec); assert_eq!(t.size(), [3, 2]); - assert_eq!(Vec::::from(t.view([-1])), [1, 2, 1, 2, 4, 5]); + assert_eq!(vec_i32_from(&t.view([-1])), [1, 2, 1, 2, 4, 5]); } #[test] @@ -428,7 +425,7 @@ fn quantized() { let t = Tensor::of_slice(&[-1f32, 0., 1., 2., 120., 0.42]); let t = t.quantize_per_tensor(0.1, 10, tch::Kind::QUInt8); let t = t.dequantize(); - assert_eq!(Vec::::from(&t), [-1f32, 0., 1., 2., 24.5, 0.4]); + assert_eq!(vec_f32_from(&t), [-1f32, 0., 1., 2., 24.5, 0.4]); } #[test] diff --git a/tests/test_utils.rs b/tests/test_utils.rs new file mode 100644 index 00000000..29f7fe2c --- /dev/null +++ b/tests/test_utils.rs @@ -0,0 +1,44 @@ +use tch::Tensor; + +pub fn from<'a, T>(t: &'a Tensor) -> T +where + >::Error: std::fmt::Debug, + T: TryFrom<&'a Tensor>, +{ + T::try_from(t).unwrap() +} + +#[allow(dead_code)] +pub fn f64_from(t: &Tensor) -> f64 { + from::(t) +} + +#[allow(dead_code)] +pub fn vec_f64_from(t: &Tensor) -> Vec { + from::>(&t.reshape(-1)) +} + +#[allow(dead_code)] +pub fn vec_f32_from(t: &Tensor) -> Vec { + from::>(&t.reshape(-1)) +} + +#[allow(dead_code)] +pub fn vec_f16_from(t: &Tensor) -> Vec { + from::>(&t.reshape(-1)) +} + +#[allow(dead_code)] +pub fn vec_i64_from(t: &Tensor) -> Vec { + from::>(&t.reshape(-1)) +} + +#[allow(dead_code)] +pub fn vec_i32_from(t: &Tensor) -> Vec { + from::>(&t.reshape(-1)) +} + +#[allow(dead_code)] +pub fn vec_bool_from(t: &Tensor) -> Vec { + from::>(&t.reshape(-1)) +} diff --git a/tests/var_store.rs b/tests/var_store.rs index f656d3d0..6d5fb890 100644 --- a/tests/var_store.rs +++ b/tests/var_store.rs @@ -2,6 +2,9 @@ use std::fs; use tch::nn::OptimizerConfig; use tch::{nn, nn::linear, nn::Init, nn::VarStore, Device, Kind, TchError, Tensor}; +mod test_utils; +use test_utils::*; + #[test] fn path_components() { let vs = VarStore::new(Device::Cpu); @@ -42,15 +45,15 @@ fn save_and_load_var_store() { u1 += 42.0; v1 *= 2.0; }); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v1.mean(Kind::Float)), 2.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v1.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 0.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); vs1.save(&filename).unwrap(); vs2.load(&filename).unwrap(); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 2.0); fs::remove_file(filename).unwrap(); } @@ -73,15 +76,15 @@ fn save_to_stream_and_load_var_store() { u1 += 42.0; v1 *= 2.0; }); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v1.mean(Kind::Float)), 2.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v1.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 0.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); vs1.save_to_stream(std::fs::File::create(&filename).unwrap()).unwrap(); vs2.load(&filename).unwrap(); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 2.0); fs::remove_file(filename).unwrap(); } @@ -104,15 +107,15 @@ fn save_and_load_from_stream_var_store() { u1 += 42.0; v1 *= 2.0; }); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v1.mean(Kind::Float)), 2.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v1.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 0.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); vs1.save(&filename).unwrap(); vs2.load_from_stream(std::fs::File::open(&filename).unwrap()).unwrap(); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 2.0); fs::remove_file(filename).unwrap(); } @@ -135,15 +138,15 @@ fn save_and_load_partial_var_store() { u1 += 42.0; v1 *= 2.0; }); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v1.mean(Kind::Float)), 2.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v1.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 0.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); vs1.save(&filename).unwrap(); let missing_variables = vs2.load_partial(&filename).unwrap(); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 2.0); assert!(missing_variables.is_empty()); fs::remove_file(filename).unwrap(); } @@ -171,14 +174,14 @@ fn save_and_load_var_store_incomplete_file() { tch::no_grad(|| { u1 += 42.0; }); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 0.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); vs1.save(&filename).unwrap(); vs2.load(&filename).unwrap(); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); fs::remove_file(filename).unwrap(); } @@ -204,14 +207,14 @@ fn save_and_load_partial_var_store_incomplete_file() { tch::no_grad(|| { u1 += 42.0; }); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 0.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); vs1.save(&filename).unwrap(); let missing_variables = vs2.load_partial(&filename).unwrap(); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); assert_eq!(missing_variables, vec!(String::from("a.b.t2"))); fs::remove_file(filename).unwrap(); } @@ -221,32 +224,32 @@ fn init_test() { tch::manual_seed(42); let vs = VarStore::new(Device::Cpu); let zeros = vs.root().zeros("t1", &[3]); - assert_eq!(Vec::::from(&zeros), [0., 0., 0.]); + assert_eq!(vec_f64_from(&zeros), [0., 0., 0.]); let zeros = vs.root().var("t2", &[3], Init::Const(0.)); - assert_eq!(Vec::::from(&zeros), [0., 0., 0.]); + assert_eq!(vec_f64_from(&zeros), [0., 0., 0.]); let ones = vs.root().var("t3", &[3], Init::Const(1.)); - assert_eq!(Vec::::from(&ones), [1., 1., 1.]); + assert_eq!(vec_f64_from(&ones), [1., 1., 1.]); let ones = vs.root().var("t4", &[3], Init::Const(0.5)); - assert_eq!(Vec::::from(&ones), [0.5, 0.5, 0.5]); + assert_eq!(vec_f64_from(&ones), [0.5, 0.5, 0.5]); let forty_two = vs.root().var("t4", &[2], Init::Const(42.)); - assert_eq!(Vec::::from(&forty_two), [42., 42.]); + assert_eq!(vec_f64_from(&forty_two), [42., 42.]); let uniform = vs.root().var("t5", &[100], Init::Uniform { lo: 1.0, up: 2.0 }); - let uniform_min = f64::from(&uniform.min()); - let uniform_max = f64::from(&uniform.max()); + let uniform_min = f64_from(&uniform.min()); + let uniform_max = f64_from(&uniform.max()); assert!(uniform_min >= 1., "{}", "min {uniform_min}"); assert!(uniform_max <= 2., "{}", "max {uniform_max}"); - let uniform_std = f64::from(&uniform.std(true)); + let uniform_std = f64_from(&uniform.std(true)); assert!(uniform_std > 0.15 && uniform_std < 0.35, "{}", "std {uniform_std}"); let normal = vs.root().var("normal", &[100], Init::Randn { mean: 0., stdev: 0.02 }); - let normal_std = f64::from(&normal.std(true)); + let normal_std = f64_from(&normal.std(true)); assert!(normal_std <= 0.03, "{}", "std {normal_std}"); let mut vs2 = VarStore::new(Device::Cpu); let ones = vs2.root().ones("t1", &[3]); - assert_eq!(Vec::::from(&ones), [1., 1., 1.]); + assert_eq!(vec_f64_from(&ones), [1., 1., 1.]); vs2.copy(&vs).unwrap(); - assert_eq!(Vec::::from(&ones), [0., 0., 0.]); + assert_eq!(vec_f64_from(&ones), [0., 0., 0.]); let ortho = vs.root().var("orthogonal", &[100, 100], Init::Orthogonal { gain: 2.0 }); - let ortho_norm = f64::from(ortho.linalg_norm_ord_str("fro", None::, true, Kind::Float)); + let ortho_norm = f64_from(&ortho.linalg_norm_ord_str("fro", None::, true, Kind::Float)); assert!( f64::abs(ortho_norm - 20.) < 1e-5, "{}", @@ -255,12 +258,12 @@ fn init_test() { let ortho_shape_fail = tch::nn::f_init(Init::Orthogonal { gain: 1.0 }, &[10], Device::Cpu); assert!(ortho_shape_fail.is_err()); let kaiming_u = vs.root().var("kaiming_u", &[20, 100], nn::init::DEFAULT_KAIMING_UNIFORM); - assert!(f64::abs(f64::from(kaiming_u.mean(Kind::Float))) < 5e-3); + assert!(f64::abs(f64_from(&kaiming_u.mean(Kind::Float))) < 5e-3); // The expected stdev is sqrt(2 / 100) - assert!(f64::abs(f64::from(kaiming_u.std(true)) - (0.02f64).sqrt()) < 2e-3); + assert!(f64::abs(f64_from(&kaiming_u.std(true)) - (0.02f64).sqrt()) < 2e-3); let kaiming_n = vs.root().var("kaiming_n", &[20, 100], nn::init::DEFAULT_KAIMING_NORMAL); - assert!(f64::abs(f64::from(kaiming_n.mean(Kind::Float))) < 5e-3); - assert!(f64::abs(f64::from(kaiming_n.std(true)) - (0.02f64).sqrt()) < 3e-3); + assert!(f64::abs(f64_from(&kaiming_n.mean(Kind::Float))) < 5e-3); + assert!(f64::abs(f64_from(&kaiming_n.std(true)) - (0.02f64).sqrt()) < 3e-3); } fn check_param_group(mut opt: tch::nn::Optimizer, var_foo: Tensor, var_bar: Tensor) { @@ -270,37 +273,37 @@ fn check_param_group(mut opt: tch::nn::Optimizer, var_foo: Tensor, var_bar: Tens let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(0.42f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "0.00"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.42"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "0.00"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.42"); opt.set_lr_group(0, 0.1); for _idx in 1..100 { let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(0f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "-0.21"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.21"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "-0.21"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.21"); opt.set_lr_group(7, 0.); for _idx in 1..100 { let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(0.22f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "0.01"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.21"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "0.01"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.21"); // The following sets the learning rate for both groups. opt.set_lr(0.); for _idx in 1..100 { let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(0.42f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "0.01"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.21"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "0.01"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.21"); opt.set_lr(0.1); for _idx in 1..100 { let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(0.42f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "0.11"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.31"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "0.11"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.31"); } #[test] @@ -332,15 +335,15 @@ fn save_and_load_with_group() { u1 += 42.0; v1 *= 2.0; }); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v1.mean(Kind::Float)), 2.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 0.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 1.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v1.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 0.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 1.0); vs1.save(&filename).unwrap(); vs2.load(&filename).unwrap(); - assert_eq!(f64::from(&u1.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&u2.mean(Kind::Float)), 42.0); - assert_eq!(f64::from(&v2.mean(Kind::Float)), 2.0); + assert_eq!(f64_from(&u1.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&u2.mean(Kind::Float)), 42.0); + assert_eq!(f64_from(&v2.mean(Kind::Float)), 2.0); fs::remove_file(filename).unwrap(); } @@ -357,22 +360,22 @@ fn param_group_weight_decay() { let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(1f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "0.50"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.50"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "0.50"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.50"); opt.set_weight_decay(0.1); for _idx in 1..100 { let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(1f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "0.49"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.49"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "0.49"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.49"); opt.set_weight_decay_group(7, 0.); for _idx in 1..100 { let loss = (&var_foo + &var_bar).mse_loss(&Tensor::from(1f32), tch::Reduction::Mean); opt.backward_step(&loss); } - assert_eq!(format!("{:.2}", f64::from(&var_foo)), "0.30"); - assert_eq!(format!("{:.2}", f64::from(&var_bar)), "0.69"); + assert_eq!(format!("{:.2}", f64_from(&var_foo)), "0.30"); + assert_eq!(format!("{:.2}", f64_from(&var_bar)), "0.69"); } #[test] @@ -410,9 +413,9 @@ fn half_precision_conversion_entire_varstore() { assert_eq!(vs.root().get("zeros").unwrap().kind(), Kind::Float); assert_eq!(vs.root().get("ones").unwrap().kind(), Kind::Float); assert_eq!(vs.root().get("forty_two").unwrap().kind(), Kind::Float); - assert_eq!(format!("{:.2}", f64::from(vs.root().get("zeros").unwrap())), "0.00"); - assert_eq!(format!("{:.2}", f64::from(vs.root().get("ones").unwrap())), "1.00"); - assert_eq!(format!("{:.2}", f64::from(vs.root().get("forty_two").unwrap())), "42.00"); + assert_eq!(format!("{:.2}", f64_from(&vs.root().get("zeros").unwrap())), "0.00"); + assert_eq!(format!("{:.2}", f64_from(&vs.root().get("ones").unwrap())), "1.00"); + assert_eq!(format!("{:.2}", f64_from(&vs.root().get("forty_two").unwrap())), "42.00"); } #[test] @@ -476,13 +479,13 @@ fn path_free_type_conversion() { assert_eq!(vs.root().sub("convert").sub("group_1").get("ones").unwrap().kind(), Kind::Float); assert_eq!(vs.root().sub("convert").sub("group_2").get("zeros").unwrap().kind(), Kind::Float); - assert_eq!(format!("{:.2}", f64::from(vs.root().sub("ignore").get("zeros").unwrap())), "0.00"); + assert_eq!(format!("{:.2}", f64_from(&vs.root().sub("ignore").get("zeros").unwrap())), "0.00"); assert_eq!( - format!("{:.2}", f64::from(vs.root().sub("convert").sub("group_1").get("ones").unwrap())), + format!("{:.2}", f64_from(&vs.root().sub("convert").sub("group_1").get("ones").unwrap())), "1.00" ); assert_eq!( - format!("{:.2}", f64::from(vs.root().sub("convert").sub("group_2").get("zeros").unwrap())), + format!("{:.2}", f64_from(&vs.root().sub("convert").sub("group_2").get("zeros").unwrap())), "0.00" ); }