Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use insight and dw remotes #648

Merged
merged 4 commits into from
May 19, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
lintr, fix
  • Loading branch information
strengejacke committed May 19, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 525e3d604624b43b2f04de3569f767221ed3402f
14 changes: 7 additions & 7 deletions R/diagnostic_posterior.R
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ diagnostic_posterior.stanreg <- function(posterior, diagnostic = "all", effects

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

diagnostic <- match.arg(diagnostic, c("ESS", "Rhat", "MCSE", "all"), several.ok = TRUE)
@@ -136,7 +136,7 @@ diagnostic_posterior.stanmvreg <- function(posterior,

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

diagnostic <- match.arg(diagnostic, c("ESS", "Rhat", "MCSE", "all"), several.ok = TRUE)
@@ -197,15 +197,15 @@ diagnostic_posterior.brmsfit <- function(posterior,

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

# Get diagnostic
diagnostic <- match.arg(diagnostic, c("ESS", "Rhat", "MCSE", "all"), several.ok = TRUE)
if ("all" %in% diagnostic) {
diagnostic <- c("ESS", "Rhat", "MCSE", "khat") # Add MCSE
} else {
if ("Rhat" %in% diagnostic) diagnostic <- c(diagnostic, "khat")
} else if ("Rhat" %in% diagnostic) {
diagnostic <- c(diagnostic, "khat")
}

insight::check_if_installed("rstan")
@@ -241,7 +241,7 @@ diagnostic_posterior.stanfit <- function(posterior, diagnostic = "all", effects

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

# Get diagnostic
@@ -288,7 +288,7 @@ diagnostic_posterior.blavaan <- function(posterior, diagnostic = "all", ...) {
# Find parameters
params <- suppressWarnings(insight::find_parameters(posterior, flatten = TRUE))

out <- data.frame("Parameter" = params)
out <- data.frame(Parameter = params)

# If no diagnostic
if (is.null(diagnostic)) {
20 changes: 10 additions & 10 deletions R/estimate_density.R
Original file line number Diff line number Diff line change
@@ -258,7 +258,7 @@ estimate_density.data.frame <- function(x,
} else {
# Deal with at- grouping --------

groups <- insight::get_datagrid(x[, at, drop = FALSE], at = at) # Get combinations
groups <- insight::get_datagrid(x[, at, drop = FALSE], by = at) # Get combinations
out <- data.frame()
for (row in seq_len(nrow(groups))) {
subdata <- datawizard::data_match(x, groups[row, , drop = FALSE])
@@ -607,8 +607,8 @@ as.data.frame.density <- function(x, ...) {
#' density_at(posterior, c(0, 1))
#' @export
density_at <- function(posterior, x, precision = 2^10, method = "kernel", ...) {
density <- estimate_density(posterior, precision = precision, method = method, ...)
stats::approx(density$x, density$y, xout = x)$y
posterior_density <- estimate_density(posterior, precision = precision, method = method, ...)
stats::approx(posterior_density$x, posterior_density$y, xout = x)$y
}


@@ -620,30 +620,30 @@ density_at <- function(posterior, x, precision = 2^10, method = "kernel", ...) {
dots[c("effects", "component", "parameters")] <- NULL

# Get the kernel density estimation (KDE)
args <- c(dots, list(
my_args <- c(dots, list(
x = x,
n = precision,
bw = bw,
from = x_range[1],
to = x_range[2]
))
fun <- get("density", asNamespace("stats"))
kde <- suppressWarnings(do.call("fun", args))
df <- as.data.frame(kde)
kde <- suppressWarnings(do.call("fun", my_args))
my_df <- as.data.frame(kde)

# Get CI (https://bookdown.org/egarpor/NP-UC3M/app-kde-ci.html)
if (!is.null(ci)) {
h <- kde$bw # Selected bandwidth
# R(K) for a normal
Rk <- 1 / (2 * sqrt(pi))
# Estimate the SD
sd_kde <- sqrt(df$y * Rk / (length(x) * h))
sd_kde <- sqrt(my_df$y * Rk / (length(x) * h))
# CI with estimated variance
z_alpha <- stats::qnorm(ci)
df$CI_low <- df$y - z_alpha * sd_kde
df$CI_high <- df$y + z_alpha * sd_kde
my_df$CI_low <- my_df$y - z_alpha * sd_kde
my_df$CI_high <- my_df$y + z_alpha * sd_kde
}
df
my_df
}


Loading