Skip to content

Commit

Permalink
lintr, fix
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed May 19, 2024
1 parent 6554f8c commit 525e3d6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
14 changes: 7 additions & 7 deletions R/diagnostic_posterior.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ diagnostic_posterior.stanreg <- function(posterior, diagnostic = "all", effects

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

diagnostic <- match.arg(diagnostic, c("ESS", "Rhat", "MCSE", "all"), several.ok = TRUE)
Expand Down Expand Up @@ -136,7 +136,7 @@ diagnostic_posterior.stanmvreg <- function(posterior,

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

diagnostic <- match.arg(diagnostic, c("ESS", "Rhat", "MCSE", "all"), several.ok = TRUE)
Expand Down Expand Up @@ -197,15 +197,15 @@ diagnostic_posterior.brmsfit <- function(posterior,

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

# Get diagnostic
diagnostic <- match.arg(diagnostic, c("ESS", "Rhat", "MCSE", "all"), several.ok = TRUE)
if ("all" %in% diagnostic) {
diagnostic <- c("ESS", "Rhat", "MCSE", "khat") # Add MCSE
} else {
if ("Rhat" %in% diagnostic) diagnostic <- c(diagnostic, "khat")
} else if ("Rhat" %in% diagnostic) {
diagnostic <- c(diagnostic, "khat")
}

insight::check_if_installed("rstan")
Expand Down Expand Up @@ -241,7 +241,7 @@ diagnostic_posterior.stanfit <- function(posterior, diagnostic = "all", effects

# If no diagnostic
if (is.null(diagnostic)) {
return(data.frame("Parameter" = params))
return(data.frame(Parameter = params))
}

# Get diagnostic
Expand Down Expand Up @@ -288,7 +288,7 @@ diagnostic_posterior.blavaan <- function(posterior, diagnostic = "all", ...) {
# Find parameters
params <- suppressWarnings(insight::find_parameters(posterior, flatten = TRUE))

out <- data.frame("Parameter" = params)
out <- data.frame(Parameter = params)

# If no diagnostic
if (is.null(diagnostic)) {
Expand Down
20 changes: 10 additions & 10 deletions R/estimate_density.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ estimate_density.data.frame <- function(x,
} else {
# Deal with at- grouping --------

groups <- insight::get_datagrid(x[, at, drop = FALSE], at = at) # Get combinations
groups <- insight::get_datagrid(x[, at, drop = FALSE], by = at) # Get combinations
out <- data.frame()
for (row in seq_len(nrow(groups))) {
subdata <- datawizard::data_match(x, groups[row, , drop = FALSE])
Expand Down Expand Up @@ -607,8 +607,8 @@ as.data.frame.density <- function(x, ...) {
#' density_at(posterior, c(0, 1))
#' @export
density_at <- function(posterior, x, precision = 2^10, method = "kernel", ...) {
density <- estimate_density(posterior, precision = precision, method = method, ...)
stats::approx(density$x, density$y, xout = x)$y
posterior_density <- estimate_density(posterior, precision = precision, method = method, ...)
stats::approx(posterior_density$x, posterior_density$y, xout = x)$y
}


Expand All @@ -620,30 +620,30 @@ density_at <- function(posterior, x, precision = 2^10, method = "kernel", ...) {
dots[c("effects", "component", "parameters")] <- NULL

# Get the kernel density estimation (KDE)
args <- c(dots, list(
my_args <- c(dots, list(
x = x,
n = precision,
bw = bw,
from = x_range[1],
to = x_range[2]
))
fun <- get("density", asNamespace("stats"))
kde <- suppressWarnings(do.call("fun", args))
df <- as.data.frame(kde)
kde <- suppressWarnings(do.call("fun", my_args))
my_df <- as.data.frame(kde)

# Get CI (https://bookdown.org/egarpor/NP-UC3M/app-kde-ci.html)
if (!is.null(ci)) {
h <- kde$bw # Selected bandwidth
# R(K) for a normal
Rk <- 1 / (2 * sqrt(pi))
# Estimate the SD
sd_kde <- sqrt(df$y * Rk / (length(x) * h))
sd_kde <- sqrt(my_df$y * Rk / (length(x) * h))
# CI with estimated variance
z_alpha <- stats::qnorm(ci)
df$CI_low <- df$y - z_alpha * sd_kde
df$CI_high <- df$y + z_alpha * sd_kde
my_df$CI_low <- my_df$y - z_alpha * sd_kde
my_df$CI_high <- my_df$y + z_alpha * sd_kde
}
df
my_df
}


Expand Down

0 comments on commit 525e3d6

Please sign in to comment.