Skip to content

Commit

Permalink
lots of new functions and test data moved
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikkel Roald-Arbøl committed Dec 14, 2024
1 parent 4d752ea commit 9e47f16
Show file tree
Hide file tree
Showing 49 changed files with 499 additions and 98 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ Imports:
cli,
collapse,
dplyr,
ggplot2,
ggtext,
janitor,
lifecycle,
methods,
patchwork,
rhdf5,
rlang,
roll,
signal,
stinepack,
stringi,
tidyr,
tidyselect,
Expand All @@ -44,7 +47,6 @@ Suggests:
stringr,
tinyplot,
tinytable,
ggplot2,
performance,
see,
markdown
Expand Down
34 changes: 34 additions & 0 deletions R/calculate_speed.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#' Calculate Speed from Position Data
#'
#' @description
#' Calculates the instantaneous speed from x, y coordinates and time data.
#' Speed is computed as the absolute magnitude of velocity (change in position over time).
#'
#' @param x Numeric vector of x coordinates
#' @param y Numeric vector of y coordinates
#' @param time Numeric vector of time values
#'
#' @return Numeric vector of speeds. The first value will be NA since speed
#' requires two positions to calculate.
#'
#' @examples
#' \dontrun{
#' # Inside dplyr pipeline
#' data %>%
#' group_by(keypoint) %>%
#' mutate(speed = calculate_speed(x, y, time))
#' }
#'
#' @export
calculate_speed <- function(x, y, time) {
# Calculate position differences
dx <- diff(x)
dy <- diff(y)
dt <- diff(time)

# Calculate speed: sqrt((dx/dt)^2 + (dy/dt)^2)
speed <- sqrt((dx/dt)^2 + (dy/dt)^2)

# Add NA at the start since we can't calculate speed for first point
c(NA, speed)
}
143 changes: 143 additions & 0 deletions R/check_timeseries.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#' Plot Time Series of Keypoint Position
#'
#' @description
#' Creates a multi-panel visualization of keypoint position data over time.
#' Each keypoint gets its own panel showing its x and/or y coordinates,
#' with different colors distinguishing between x (orange) and y (blue) coordinates.
#' Useful for visually inspecting movement patterns and identifying potential tracking issues.
#'
#' @param data A data frame containing tracked keypoint data with the following columns:
#' - `time`: Numeric time values
#' - `keypoint`: Factor specifying the keypoint names
#' - `x`: x-coordinates
#' - `y`: y-coordinates
#' @param reference_keypoint Optional character string. If provided, all coordinates will
#' be translated relative to this keypoint's position. Must match one of the keypoint
#' levels in the data.
#' @param dimension Character string specifying which coordinates to plot.
#' Options are:
#' - `"xy"`: Plot both x and y coordinates (default)
#' - `"x"`: Plot only x coordinates
#' - `"y"`: Plot only y coordinates
#'
#' @return A ggplot object combining individual time series plots for each keypoint
#' using patchwork. The plots are stacked vertically with shared axes and legends.
#'
#' @examples
#' \dontrun{
#' # Plot all coordinates
#' check_timeseries(movement_data)
#'
#' # Plot coordinates relative to "head" keypoint
#' check_timeseries(movement_data, reference_keypoint = "head")
#'
#' # Plot only x coordinates
#' check_timeseries(movement_data, dimension = "x")
#' }
#'
#' @seealso
#' `translate_coords()` for the coordinate translation functionality used when
#' `reference_keypoint` is specified.
#'
#' @export
plot_position_timeseries <- function(data, reference_keypoint=NULL, dimension = "xy"){
n_keypoints <- nlevels(data$keypoint)
keypoints <- levels(data$keypoint)
plot_ts <- list()
orange <- "#FFA500"
blue <- "#1f77b4"

if (!is.null(reference_keypoint) && reference_keypoint %in% keypoints){
data <- data |>
translate_coords(to_keypoint = reference_keypoint)
}

for (j in 1:length(keypoints)){
df <- data |>
dplyr::ungroup() |>
dplyr::filter(.data$keypoint == keypoints[j])

plot_ts[[j]] <- df |>
subplot_position_timeseries(keypoint = keypoints[j], dimension = dimension)
}

output_plot <- patchwork::wrap_plots(plot_ts, ncol = 1) +
patchwork::plot_annotation(title = "Time Series of Keypoint Position",
theme = theme(plot.subtitle = ggtext::element_markdown(lineheight = 1.1))) +
patchwork::plot_layout(axes = "collect",
axis_titles = "collect",
guides = "collect")
if (dimension == "xy"){
output_plot <- output_plot +
patchwork::plot_annotation(subtitle = paste0(
"Timeseries for <b style='color:", orange, ";' >X</b> and <b style='color:", blue, "' >Y</b> coordinates
over time"))
} else if (dimension == "x"){
output_plot <- output_plot +
patchwork::plot_annotation(subtitle = paste0(
"Timeseries for <b style='color:", orange, ";' >X</b> coordinates
over time"))
} else if (dimension == "y"){
output_plot <- output_plot +
patchwork::plot_annotation(subtitle = paste0(
"Timeseries for <b style='color:", blue, "' >Y</b> coordinates
over time"))
}

return(output_plot)
}

#' Create Individual Time Series Plot for a Keypoint
#'
#' @description
#' Internal helper function that creates a single time series plot for one keypoint's
#' coordinates.
#'
#' @param data Data frame containing coordinate data for a single keypoint
#' @param keypoint Character string of the keypoint name (used for subtitle)
#' @param dimension Character string: "xy", "x", or "y"
#'
#' @return A ggplot object showing the time series for the specified coordinates
#'
#' @keywords internal
subplot_position_timeseries <- function(data, keypoint, dimension = "xy") {
if (dimension == "xy") {
# Get ranges for scaling
x_range <- range(data$x, na.rm = TRUE)
y_range <- range(data$y, na.rm = TRUE)

# Create plot with dual axes
p <- data |>
ggplot2::ggplot(aes(x = .data$time)) +
# X coordinate on primary y-axis
ggplot2::geom_line(aes(y = .data$x), colour = "#FFA500") +
# Y coordinate on secondary y-axis (scaled)
ggplot2::geom_line(aes(y = scales::rescale(.data$y, to = x_range, from = y_range)),
colour = "#1f77b4") +
# Primary axis label (X coordinate)
ggplot2::scale_y_continuous(
name = "X coordinate",
sec.axis = ggplot2::sec_axis(
# Transform function to convert back to original y-coordinate scale
trans = ~ scales::rescale(., to = y_range, from = x_range),
name = "Y coordinate"
)
)
} else {
# Single axis plots remain the same
p <- data |>
ggplot2::ggplot(aes(x = .data$time))

if (dimension == "x") {
p <- p + ggplot2::geom_line(aes(y = .data$x), colour = "#FFA500")
} else if (dimension == "y") {
p <- p + ggplot2::geom_line(aes(y = .data$y), colour = "#1f77b4")
}
}

p <- p +
ggplot2::ggtitle("", subtitle = keypoint) +
ggplot2::theme_linedraw()

return(p)
}
31 changes: 31 additions & 0 deletions R/check_trajectory.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
check_trajectory <- function(data, reference_keypoint=NULL, coord_fixed=FALSE){
n_keypoints <- nlevels(data$keypoint)
keypoints <- levels(data$keypoint)
plot_trajectories <- list()

if (!is.null(reference_keypoint) && reference_keypoint %in% keypoints){
data <- data |>
translate_coords(to_keypoint = reference_keypoint)
}

for (j in 1:length(keypoints)){
df <- data |>
dplyr::ungroup() |>
dplyr::filter(.data$keypoint == keypoints[j])

plot_trajectories[[j]] <- df |>
subplot_trajectory(keypoint = keypoints[j], coord_fixed=coord_fixed)
}

output_plot <- patchwork::wrap_plots(plot_trajectories) +
patchwork::plot_annotation(title = "Trajectory of keypoints",
subtitle = "Trajecotry of individual keypoints over time",
theme = theme(plot.subtitle = ggtext::element_markdown(lineheight = 1.1),
legend.position="right")
) +
patchwork::plot_layout(axes = "collect",
axis_titles = "collect",
guides = "collect")

return(output_plot)
}
65 changes: 65 additions & 0 deletions R/filter_by_roi.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
filter_by_roi <- function(data,
x_min=NULL,
x_max=NULL,
y_min=NULL,
y_max=NULL,
x_center=NULL,
y_center=NULL,
radius=NULL){
if (all(is.null(x_center), is.null(y_center), is.null(radius))){
if (all(is.null(x_min), is.null(x_max), is.null(y_min), is.null(y_max))){
cli::cli_abort("To use a square ROI, at least one of the following must be provided: x_min, x_max, y_min & y_max.")
}
data <- filter_by_roi_square(data, x_min, x_max, y_min, y_max)
} else {
if (any(is.null(x_center), is.null(y_center), is.null(radius))){
cli::cli_abort("To use a square ROI, all following must be provided: x_min, x_max, y_min & y_max.")
}
data <- filter_by_roi_circle(data, x_center, y_center, radius)
}
return(data)
}

filter_by_roi_square <- function(data, x_min, x_max, y_min, y_max){
if (!is.null(x_min)){
data <- data |>
dplyr::mutate(x = case_when(.data$x < x_min ~ NA,
.default = .data$x),
y = case_when(.data$x < x_min ~ NA,
.default = .data$y))
}

if (!is.null(x_max)){
data <- data |>
dplyr::mutate(x = case_when(.data$x > x_max ~ NA,
.default = .data$x),
y = case_when(.data$x > x_max ~ NA,
.default = .data$y))
}

if (!is.null(y_min)){
data <- data |>
dplyr::mutate(x = case_when(.data$y < y_min ~ NA,
.default = .data$x),
y = case_when(.data$y < y_min ~ NA,
.default = .data$y))
}

if (!is.null(y_max)){
data <- data |>
dplyr::mutate(x = case_when(.data$y > y_max ~ NA,
.default = .data$x),
y = case_when(.data$y > y_max ~ NA,
.default = .data$y))
}
return(data)
}

filter_by_roi_circle <- function(data, x_center, y_center, radius){
data <- data |>
dplyr::mutate(x = case_when(((.data$x - x_center)^2 + (.data$y - y_center)^2) > radius^2 ~ NA,
.default = .data$x),
y = case_when(((.data$x - x_center)^2 + (.data$y - y_center)^2) > radius^2 ~ NA,
.default = .data$y))
return(data)
}
5 changes: 2 additions & 3 deletions R/filter_by_speed.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
#' @param data A data frame containing at least the columns `x` and `y`.
#' @param threshold A numeric value specifying the speed threshold. Observations
#' with speeds greater than this value will have their `x`, `y`, and `confidence`
#' values replaced with `NA`. If set to `"auto"`, the function will throw an error
#' as automatic threshold determination is not yet implemented.
#' values replaced with `NA`. If set to `"auto"`, the function will set the threshold at the mean speed + 3 standard deviations (SD).
#'
#' @return A data frame with the same columns as the input `data`, but with
#' values replaced by `NA` where the speed exceeds the threshold.
Expand Down Expand Up @@ -42,7 +41,7 @@ filter_by_speed <- function(data, threshold = "auto"){
d <- data |>
calculate_kinematics()
if (threshold == "auto"){
cli::cli_abort("Not implemented yet.")
threshold <- mean(d$v_translation) + 3*sd(d$v_translation)
} else {
d <- d |>
dplyr::mutate(x = dplyr::if_else(abs(.data$v_translation) < threshold, NA, .data$x),
Expand Down
3 changes: 0 additions & 3 deletions R/filter_outside_roi.R

This file was deleted.

24 changes: 24 additions & 0 deletions R/manage_metadata.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,30 @@ set_start_datetime <- function(data, start_datetime){
return(data)
}

set_individual <- function(data, individual){
new_id <- individual
data <- data |>
dplyr::ungroup() |>
dplyr::mutate(individual = factor(new_id))
return(data)
}

set_framerate <- function(data, framerate, old_framerate=1){
scaling_factor <- old_framerate / framerate

# Ensure frame numbers start at zero
if (is.integer(data$time)){
data <- data |>
dplyr::mutate(time = .data$time - min(.data$time, na.rm = TRUE))
}
data <- data |>
dplyr::mutate(time = .data$time * scaling_factor)

attributes(data)$metadata$framerate <- framerate

return(data)
}

#' @keywords internal
check_class <- function(x, class){
class %in% class(x)
Expand Down
Loading

0 comments on commit 9e47f16

Please sign in to comment.