From 45825747ac3db7d5a918aebe4e62e9e8bff1466f Mon Sep 17 00:00:00 2001
From: lochhh <changhuan.lo@ucl.ac.uk>
Date: Thu, 25 Jan 2024 18:27:14 +0000
Subject: [PATCH] Vectorise kinematic functions

---
 movement/analysis/kinematics.py    | 99 +++++++++++++++++++-----------
 tests/test_unit/test_kinematics.py | 90 +++++++++++++++++++--------
 2 files changed, 128 insertions(+), 61 deletions(-)

diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py
index 9602b5eb1..f5b176158 100644
--- a/movement/analysis/kinematics.py
+++ b/movement/analysis/kinematics.py
@@ -2,49 +2,58 @@
 import xarray as xr
 
 
-def displacement(data: xr.DataArray) -> np.ndarray:
+def displacement(data: xr.DataArray) -> xr.Dataset:
     """Compute the displacement between consecutive locations
-    of a single keypoint from a single individual.
+    of each keypoint of each individual.
 
     Parameters
     ----------
     data : xarray.DataArray
-        The input data, assumed to be of shape (..., 2), where the last
-        dimension contains the x and y coordinates.
+        The input data, with ``time`` and ``space`` dimensions;
+        ``space`` holds the x and y coordinates, in that order.
 
     Returns
     -------
-    numpy.ndarray
-        A numpy array containing the computed magnitude and
+    xarray.Dataset
+        An xarray Dataset containing the computed magnitude and
         direction of the displacement.
     """
-    displacement_vector = np.diff(data, axis=0, prepend=data[0:1])
-    magnitude = np.linalg.norm(displacement_vector, axis=1)
-    direction = np.arctan2(
-        displacement_vector[..., 1], displacement_vector[..., 0]
+    displacement_da = data.diff(dim="time")
+    magnitude = xr.apply_ufunc(
+        np.linalg.norm,
+        displacement_da,
+        input_core_dims=[["space"]],
+        kwargs={"axis": -1},
     )
-    return np.stack((magnitude, direction), axis=1)
+    magnitude = magnitude.reindex_like(data.sel(space="x"))
+    direction = xr.apply_ufunc(
+        np.arctan2,
+        displacement_da.isel(space=1),
+        displacement_da.isel(space=0),
+    )
+    direction = direction.reindex_like(data.sel(space="x"))
+    return xr.Dataset({"magnitude": magnitude, "direction": direction})
 
 
-def distance(data: xr.DataArray) -> np.ndarray:
-    """Compute the distances between consecutive locations of
-    a single keypoint from a single individual.
+def distance(data: xr.DataArray) -> xr.DataArray:
+    """Compute the Euclidean distances between consecutive
+    locations of each keypoint of each individual.
 
     Parameters
     ----------
     data : xarray.DataArray
-        The input data, assumed to be of shape (..., 2), where the last
-        dimension contains the x and y coordinates.
+        The input data, assumed to be of shape (..., 2), where
+        the last dimension contains the x and y coordinates.
 
     Returns
     -------
-    numpy.ndarray
-        A numpy array containing the computed distance.
+    xarray.DataArray
+        An xarray DataArray containing the magnitude of displacement.
     """
-    return displacement(data)[:, 0]
+    return displacement(data).magnitude
 
 
-def velocity(data: xr.DataArray) -> np.ndarray:
+def velocity(data: xr.DataArray) -> xr.Dataset:
     """Compute the velocity of a single keypoint from
     a single individual.
 
@@ -56,14 +65,15 @@ def velocity(data: xr.DataArray) -> np.ndarray:
 
     Returns
     -------
-    numpy.ndarray
-        A numpy array containing the computed velocity.
+    xarray.Dataset
+        An xarray Dataset containing the computed magnitude and
+        direction of the velocity.
     """
     return approximate_derivative(data, order=1)
 
 
-def speed(data: xr.DataArray) -> np.ndarray:
-    """Compute velocity based on the Euclidean norm (magnitude) of the
+def speed(data: xr.DataArray) -> xr.DataArray:
+    """Compute speed based on the Euclidean norm (magnitude) of the
     differences between consecutive points, i.e. the straight-line
     distance travelled, assuming equidistant time spacing.
 
@@ -75,10 +85,10 @@ def speed(data: xr.DataArray) -> np.ndarray:
 
     Returns
     -------
-    numpy.ndarray
-        A numpy array containing the computed velocity.
+    xarray.DataArray
+        An xarray DataArray containing the magnitude of velocity.
     """
-    return velocity(data)[:, 0]
+    return velocity(data).magnitude
 
 
 def acceleration(data: xr.DataArray) -> np.ndarray:
@@ -99,7 +109,7 @@ def acceleration(data: xr.DataArray) -> np.ndarray:
     return approximate_derivative(data, order=2)
 
 
-def approximate_derivative(data: xr.DataArray, order: int = 1) -> np.ndarray:
+def approximate_derivative(data: xr.DataArray, order: int = 1) -> xr.Dataset:
     """Compute velocity or acceleration using numerical differentiation,
     assuming equidistant time spacing.
 
@@ -114,9 +124,9 @@ def approximate_derivative(data: xr.DataArray, order: int = 1) -> np.ndarray:
 
     Returns
     -------
-    numpy.ndarray
-        A numpy array containing the computed magnitudes and directions of
-        the kinematic variable.
+    xarray.Dataset
+        An xarray Dataset containing the computed magnitudes and
+        directions of the derived variable.
     """
     if order <= 0:
         raise ValueError("order must be a positive integer.")
@@ -124,12 +134,27 @@ def approximate_derivative(data: xr.DataArray, order: int = 1) -> np.ndarray:
         result = data
         dt = data["time"].diff(dim="time").values[0]
         for _ in range(order):
-            result = np.gradient(result, dt, axis=0)
-        # Prepend with zeros to match match output to the input shape
-        result = np.pad(result[1:], ((1, 0), (0, 0)), "constant")
-    magnitude = np.linalg.norm(result, axis=-1)
-    direction = np.arctan2(result[..., 1], result[..., 0])
-    return np.stack((magnitude, direction), axis=1)
+            result = xr.apply_ufunc(
+                np.gradient,
+                result,
+                dt,
+                kwargs={"axis": result.get_axis_num("time")},
+            )
+        result = result.reindex_like(data)
+    magnitude = xr.apply_ufunc(
+        np.linalg.norm,
+        result,
+        input_core_dims=[["space"]],
+        kwargs={"axis": -1},
+    )
+    magnitude = magnitude.reindex_like(data.sel(space="x"))
+    direction = xr.apply_ufunc(
+        np.arctan2,
+        result.isel(space=1),
+        result.isel(space=0),
+    )
+    direction = direction.reindex_like(data.sel(space="x"))
+    return xr.Dataset({"magnitude": magnitude, "direction": direction})
 
 
 # Locomotion Features
diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py
index c9ce36647..cb8624ced 100644
--- a/tests/test_unit/test_kinematics.py
+++ b/tests/test_unit/test_kinematics.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+import xarray as xr
 
 from movement.analysis import kinematics
 
@@ -9,47 +10,88 @@ class TestKinematics:
 
     def test_distance(self, valid_pose_dataset):
         """Test distance calculation."""
-        # Select a single keypoint from a single individual
-        data = valid_pose_dataset.pose_tracks.isel(keypoints=0, individuals=0)
+        data = valid_pose_dataset.pose_tracks
         result = kinematics.distance(data)
-        expected = np.pad([5.0] * 9, (1, 0), "constant")
-        assert np.allclose(result, expected)
+        expected = np.full((10, 2, 2), 5.0)
+        expected[0, :, :] = np.nan
+        np.testing.assert_allclose(result.values, expected)
 
     def test_displacement(self, valid_pose_dataset):
         """Test displacement calculation."""
-        # Select a single keypoint from a single individual
-        data = valid_pose_dataset.pose_tracks.isel(keypoints=0, individuals=0)
+        data = valid_pose_dataset.pose_tracks
         result = kinematics.displacement(data)
-        expected_magnitude = np.pad([5.0] * 9, (1, 0), "constant")
-        expected_direction = np.concatenate(([0], np.full(9, 0.92729522)))
-        expected = np.stack((expected_magnitude, expected_direction), axis=1)
-        assert np.allclose(result, expected)
+        expected_magnitude = np.full((10, 2, 2), 5.0)
+        expected_magnitude[0, :, :] = np.nan
+        expected_direction = np.full((10, 2, 2), 0.92729522)
+        expected_direction[0, :, :] = np.nan
+        expected = xr.Dataset(
+            data_vars={
+                "magnitude": xr.DataArray(
+                    expected_magnitude, dims=data.dims[:-1]
+                ),
+                "direction": xr.DataArray(
+                    expected_direction, dims=data.dims[:-1]
+                ),
+            },
+            coords={
+                "time": data.time,
+                "keypoints": data.keypoints,
+                "individuals": data.individuals,
+            },
+        )
+        xr.testing.assert_allclose(result, expected)
 
     def test_velocity(self, valid_pose_dataset):
         """Test velocity calculation."""
-        # Select a single keypoint from a single individual
-        data = valid_pose_dataset.pose_tracks.isel(keypoints=0, individuals=0)
+        data = valid_pose_dataset.pose_tracks
         # Compute velocity
         result = kinematics.velocity(data)
-        expected_magnitude = np.pad([5.0] * 9, (1, 0), "constant")
-        expected_direction = np.concatenate(([0], np.full(9, 0.92729522)))
-        expected = np.stack((expected_magnitude, expected_direction), axis=1)
-        assert np.allclose(result, expected)
+        expected_magnitude = np.full((10, 2, 2), 5.0)
+        expected_direction = np.full((10, 2, 2), 0.92729522)
+        expected = xr.Dataset(
+            data_vars={
+                "magnitude": xr.DataArray(
+                    expected_magnitude, dims=data.dims[:-1]
+                ),
+                "direction": xr.DataArray(
+                    expected_direction, dims=data.dims[:-1]
+                ),
+            },
+            coords={
+                "time": data.time,
+                "keypoints": data.keypoints,
+                "individuals": data.individuals,
+            },
+        )
+        xr.testing.assert_allclose(result, expected)
 
     def test_speed(self, valid_pose_dataset):
-        """Test velocity calculation."""
-        # Select a single keypoint from a single individual
-        data = valid_pose_dataset.pose_tracks.isel(keypoints=0, individuals=0)
+        """Test speed calculation."""
+        data = valid_pose_dataset.pose_tracks
         result = kinematics.speed(data)
-        expected = np.pad([5.0] * 9, (1, 0), "constant")
-        assert np.allclose(result, expected)
+        expected = np.full((10, 2, 2), 5.0)
+        np.testing.assert_allclose(result.values, expected)
 
     def test_acceleration(self, valid_pose_dataset):
         """Test acceleration calculation."""
-        # Select a single keypoint from a single individual
-        data = valid_pose_dataset.pose_tracks.isel(keypoints=0, individuals=0)
+        data = valid_pose_dataset.pose_tracks
         result = kinematics.acceleration(data)
-        assert np.allclose(result, np.zeros((10, 2)))
+        expected = xr.Dataset(
+            data_vars={
+                "magnitude": xr.DataArray(
+                    np.zeros((10, 2, 2)), dims=data.dims[:-1]
+                ),
+                "direction": xr.DataArray(
+                    np.zeros((10, 2, 2)), dims=data.dims[:-1]
+                ),
+            },
+            coords={
+                "time": data.time,
+                "keypoints": data.keypoints,
+                "individuals": data.individuals,
+            },
+        )
+        xr.testing.assert_allclose(result, expected)
 
     def test_approximate_derivative_with_nonpositive_order(self):
         """Test that an error is raised when the order is non-positive."""