Skip to content

Commit

Permalink
merge pr #325: rep_sample_n() fix (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Jul 25, 2020
2 parents d68991d + 86150ad commit acdc676
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 89 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export(hypothesize)
export(p_value)
export(prop_test)
export(rep_sample_n)
export(rep_slice_sample)
export(shade_ci)
export(shade_confidence_interval)
export(shade_p_value)
Expand All @@ -28,7 +29,6 @@ export(visualise)
export(visualize)
importFrom(dplyr,bind_rows)
importFrom(dplyr,group_by)
importFrom(dplyr,inner_join)
importFrom(dplyr,mutate_if)
importFrom(dplyr,n)
importFrom(dplyr,one_of)
Expand Down
5 changes: 4 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# infer 0.5.3.9000 (development version)

To be released as 0.5.4.
- `rep_sample_n()` no longer errors when supplied a `prob` argument (#279)
- Added `rep_slice_sample()`, a light wrapper around `rep_sample_n()`, that
more closely resembles `dplyr::slice_sample()` (the function that supersedes
`dplyr::sample_n()`) (#325)

# infer 0.5.3

Expand Down
109 changes: 58 additions & 51 deletions R/rep_sample_n.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,86 +2,93 @@
#'
#' @description
#'
#' Perform repeated sampling of samples of size n. Useful for creating sampling
#' distributions.
#' These functions extend the functionality of [dplyr::sample_n()] and
#' [dplyr::slice_sample()] by allowing for repeated sampling of data.
#' This operation is especially helpful while creating sampling
#' distributions—see the examples below!
#'
#' @param tbl Data frame of population from which to sample.
#' @param size Sample size of each sample.
#' @param tbl,.data Data frame of population from which to sample.
#' @param size,n Sample size of each sample.
#' @param replace Should sampling be with replacement?
#' @param reps Number of samples of size n = `size` to take.
#' @param prob A vector of probability weights for obtaining the elements of the
#' vector being sampled.
#' @param prob,weight_by A vector of sampling weights for each of the rows in
#' `tbl`—must have length equal to `nrow(tbl)`.
#'
#' @return A tibble of size `rep` times `size` rows corresponding to `rep`
#' samples of size n = `size` from `tbl`.
#' @return A tibble with `reps * size` rows corresponding to `reps`
#' samples of size `size` from `tbl`, grouped by `replicate`.
#'
#' @examples
#' suppressPackageStartupMessages(library(dplyr))
#' suppressPackageStartupMessages(library(ggplot2))
#' @details The [dplyr::sample_n()] function (to which `rep_sample_n()` was
#' originally a supplement) has been superseded by [dplyr::slice_sample()].
#' `rep_slice_sample()` provides a light wrapper around `rep_sample_n()` that
#' has a more similar interface to `slice_sample()`.
#'
#' # A virtual population of N = 10,010, of which 3091 are hurricanes
#' population <- dplyr::storms %>%
#' select(status)
#' @examples
#' library(dplyr)
#' library(ggplot2)
#'
#' # Take samples of size n = 50 storms without replacement; do this 1000 times
#' samples <- population %>%
#' # take 1000 samples of size n = 50, without replacement
#' slices <- gss %>%
#' rep_sample_n(size = 50, reps = 1000)
#' samples
#'
#' # Compute p_hats for all 1000 samples = proportion hurricanes
#' p_hats <- samples %>%
#' slices
#'
#' # compute the proportion of respondents with a college
#' # degree in each replicate
#' p_hats <- slices %>%
#' group_by(replicate) %>%
#' summarize(prop_hurricane = mean(status == "hurricane"))
#' p_hats
#' summarize(prop_college = mean(college == "degree"))
#'
#' # Plot sampling distribution
#' ggplot(p_hats, aes(x = prop_hurricane)) +
#' # plot sampling distribution
#' ggplot(p_hats, aes(x = prop_college)) +
#' geom_density() +
#' labs(x = "p_hat", y = "Number of samples",
#' title = "Sampling distribution of p_hat from 1000 samples of size 50")
#'
#' @importFrom dplyr pull
#' @importFrom dplyr inner_join
#' @importFrom dplyr group_by
#' labs(
#' x = "p_hat", y = "Number of samples",
#' title = "Sampling distribution of p_hat"
#' )
#'
#' # sampling with probability weights. Note probabilities are automatically
#' # renormalized to sum to 1
#' library(tibble)
#' df <- tibble(
#' id = 1:5,
#' letter = factor(c("a", "b", "c", "d", "e"))
#' )
#' rep_sample_n(df, size = 2, reps = 5, prob = c(.5, .4, .3, .2, .1))
#' @export
# Draw `reps` samples of `size` rows each from `tbl` and return them stacked
# in one tibble, labelled by a leading `replicate` column and grouped by it.
#
# NOTE(review): this span reads like a rendered diff whose removed and added
# lines are interleaved without +/- markers (two consecutive length checks on
# `prob`, two assignments of `n`, and two different ways of assembling the
# result). Confirm which lines are current before relying on it.
rep_sample_n <- function(tbl, size, replace = FALSE, reps = 1, prob = NULL) {
n <- nrow(tbl)

# Validate argument types up front; `prob` is optional and is only checked
# when the caller supplied it.
check_type(tbl, is.data.frame)
check_type(size, is.numeric)
check_type(replace, is.logical)
check_type(reps, is.numeric)
if (!is.null(prob)) {
check_type(prob, is.numeric)
}

# assign non-uniform probabilities
# there should be a better way!!
# prob needs to be nrow(tbl) -- not just number of factor levels
if (!is.null(prob)) {
if (length(prob) != n) {
if (length(prob) != nrow(tbl)) {
# One weight per row is required; anything else is rejected with an error.
stop_glue(
"The argument `prob` must have length `nrow(tbl)` = {nrow(tbl)}"
)
}

# NOTE(review): this pipeline derives weights by pairing `prob` with the
# factor levels of the first column of `tbl` and joining back onto `tbl`;
# presumably the superseded implementation -- confirm.
prob <- tibble::tibble(vals = levels(dplyr::pull(tbl, 1))) %>%
dplyr::mutate(probs = prob) %>%
dplyr::inner_join(tbl) %>%
dplyr::select(probs) %>%
dplyr::pull()
}

# Generate row indexes for every future replicate (this way it respects
# possibility of `replace = FALSE`)
# `sample.int()` is called once per replicate, so `replace` applies within a
# replicate rather than across replicates; the per-replicate index vectors
# are then concatenated into one vector `i`.
n <- nrow(tbl)
i <- unlist(replicate(
reps,
sample.int(n, size, replace = replace, prob = prob),
simplify = FALSE
))
# Base-R assembly: prepend a `replicate` id (each id repeated `size` times)
# to the sampled rows, convert to a tibble, restore the original column
# names, and group by `replicate`.
rep_tbl <- cbind(
replicate = rep(1:reps, rep(size, reps)),
tbl[i, ]
)
rep_tbl <- tibble::as_tibble(rep_tbl)
names(rep_tbl)[-1] <- names(tbl)
dplyr::group_by(rep_tbl, replicate)

# dplyr assembly: take the sampled rows, add the `replicate` id, move it to
# the first column, and return a tibble grouped by `replicate`.
tbl %>%
dplyr::slice(i) %>%
dplyr::mutate(replicate = rep(seq_len(reps), each = size)) %>%
dplyr::select(replicate, dplyr::everything()) %>%
tibble::as_tibble() %>%
dplyr::group_by(replicate)
}

#' @rdname rep_sample_n
#' @export
rep_slice_sample <- function(.data, n = 1, replace = FALSE, weight_by = NULL,
                             reps = 1) {
  # Map the `slice_sample()`-style argument names onto the corresponding
  # `rep_sample_n()` arguments and delegate all the work to it.
  rep_sample_n(
    tbl = .data,
    size = n,
    replace = replace,
    reps = reps,
    prob = weight_by
  )
}
2 changes: 1 addition & 1 deletion man/get_confidence_interval.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 44 additions & 26 deletions man/rep_sample_n.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

96 changes: 87 additions & 9 deletions tests/testthat/test-rep_sample_n.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,103 @@ context("rep_sample_n")

# Fixture shared by the tests below: a population of N balls, the first three
# red and the remaining N - 3 white.
N <- 5
population <- tibble::tibble(
# NOTE(review): both `ball_ID` and `ball_id` appear here, while the tests
# below only read `ball_id` -- this looks like the before/after lines of a
# rename shown together; confirm which one is intended.
ball_ID = 1:N,
ball_id = 1:N,
color = factor(c(rep("red", 3), rep("white", N - 3)))
)

# NOTE(review): two `test_that(` headers open back to back here and only one
# is closed -- this looks like the removed "works" test and the added "size"
# test rendered together; confirm which lines are current.
test_that("rep_sample_n works", {
expect_silent(population %>% rep_sample_n(size = 2, reps = 10))
test_that("rep_sample_n is sensitive to the size argument", {
# Fix the RNG state, then sample with the same `reps` but two different
# sample sizes.
set.seed(1)
reps <- 10
s1 <- 2
s2 <- 3

res1 <- population %>% rep_sample_n(size = s1, reps = reps)
res2 <- population %>% rep_sample_n(size = s2, reps = reps)

# The number of columns must not depend on `size`.
expect_equal(ncol(res1), ncol(res2))
expect_equal(ncol(res1), 3)

# The number of rows must scale with `size`: rows / size is the same for
# both results, and the total is reps * size.
expect_equal(nrow(res1) / s1, nrow(res2) / s2)
expect_equal(nrow(res1), reps * s1)
})

test_that("rep_sample_n is sensitive to the reps argument", {
  # Same sample size, two different replicate counts.
  set.seed(1)
  many_reps <- 10
  few_reps <- 5
  sample_size <- 2

  many <- rep_sample_n(population, size = sample_size, reps = many_reps)
  few <- rep_sample_n(population, size = sample_size, reps = few_reps)

  # The column layout must not depend on `reps`.
  expect_equal(ncol(many), ncol(few))
  expect_equal(ncol(many), 3)

  # The row count must scale with `reps`: rows per replicate agree, and the
  # total is reps * size.
  expect_equal(nrow(many) / many_reps, nrow(few) / few_reps)
  expect_equal(nrow(many), many_reps * sample_size)
})

test_that("rep_sample_n is sensitive to the replace argument", {
  # Draw from the same RNG state twice, once with and once without
  # replacement, so any difference is attributable to `replace`.
  set.seed(1)
  with_repl <- rep_sample_n(population, size = 5, reps = 100, replace = TRUE)

  set.seed(1)
  without_repl <- rep_sample_n(
    population,
    size = 5,
    reps = 100,
    replace = FALSE
  )

  # Replicate labels line up, but the sampled rows themselves must differ.
  expect_true(all(with_repl$replicate == without_repl$replicate))
  expect_false(all(with_repl$ball_id == without_repl$ball_id))
  expect_false(all(with_repl$color == without_repl$color))

  expect_equal(ncol(with_repl), ncol(without_repl))
  expect_equal(ncol(with_repl), 3)

  # Without replacement, no ball may appear twice inside a single replicate.
  first_dup_index <- tapply(
    without_repl$ball_id,
    without_repl$replicate,
    anyDuplicated
  )
  no_duplicates <- all(first_dup_index == 0)
  expect_true(no_duplicates)
})

test_that("rep_sample_n is sensitive to the prob argument", {
  set.seed(1)

  # Put all the sampling weight on the first row of the population.
  first_row_only <- c(1, rep(0, 4))
  draws <- rep_sample_n(
    population,
    size = 5,
    reps = 100,
    replace = TRUE,
    prob = first_row_only
  )

  # Every sampled row must therefore be ball 1, which is red.
  expect_true(all(draws$ball_id == 1))
  expect_true(all(draws$color == "red"))
})

# NOTE(review): several `expect_error()` calls below contain two consecutive
# `rep_sample_n(...)` lines after a single pipe -- this looks like removed and
# added diff lines rendered together; confirm which lines are current.
test_that("rep_sample_n errors with bad arguments", {
# `prob` with 100 weights for a 5-row population is expected to error.
expect_error(
population %>%
rep_sample_n(size = 2, reps = 10, prob = rep(x = 1 / 5, times = 100))
)

# `prob` with only 2 weights for a 5-row population is expected to error.
expect_error(
population %>%
rep_sample_n(size = 2, reps = 10, prob = rep(x = 1/5, times = 100))
rep_sample_n(size = 2, reps = 10, prob = c(1 / 2, 1 / 2))
)

# A non-numeric `size` is expected to error.
expect_error(
population %>%
rep_sample_n(size = 2, reps = 10, prob = c(1/2, 1/2))
rep_sample_n(size = "a lot", reps = 10)
)

# A non-numeric `reps` is expected to error.
expect_error(
population %>%
rep_sample_n(size = 2, reps = 10, prob = c(0.25, 1/5, 1/5, 1/5, 0.15))
rep_sample_n(size = 2, reps = "a lot")
)

# The result's columns should be `replicate` followed by the population's
# own columns, in their original order.
test_rep <- population %>% rep_sample_n(size = 2, reps = 10)
expect_equal(c("replicate", names(population)), names(test_rep))
})

test_that("rep_slice_sample works", {
  # The wrapper should reproduce `rep_sample_n()` exactly when both are given
  # equivalent arguments and start from the same RNG state.
  set.seed(1)
  via_sample_n <- rep_sample_n(
    population,
    size = 2,
    reps = 5,
    prob = rep(1 / N, N)
  )

  set.seed(1)
  via_slice_sample <- rep_slice_sample(
    population,
    n = 2,
    reps = 5,
    weight_by = rep(1 / N, N)
  )

  expect_equal(via_sample_n, via_slice_sample)
})

0 comments on commit acdc676

Please sign in to comment.