Skip to content

Commit

Permalink
Fixes to regrid() for non-estimable cases
Browse files Browse the repository at this point in the history
Also removed disabled code in ref_grid, regrid
and corrected an error in FAQs
  • Loading branch information
rvlenth committed Jan 8, 2025
1 parent 254488e commit 8678372
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 103 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: emmeans
Type: Package
Title: Estimated Marginal Means, aka Least-Squares Means
Version: 1.10.6-090002
Date: 2024-12-28
Version: 1.10.6-090003
Date: 2025-01-08
Authors@R: c(person("Russell V.", "Lenth", role = c("aut", "cre", "cph"),
email = "russell-lenth@uiowa.edu"),
person("Balazs", "Banfai", role = "ctb"),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ title: "NEWS for the emmeans package"
reference grids. Whenever we average over a counterfactual `B`, we only
use the cases where `B == actual_B`, thus obtaining the same results as
would be obtained when `B` is not regarded as a counterfactual.
* Tweaks to `regrid()` so that the `@post.beta` slot is created correctly when
there are non-estimable cases.


## emmeans 1.10.6
Expand Down
132 changes: 48 additions & 84 deletions R/emmGrid-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,8 @@ regrid = function(object, transform = c("response", "mu", "unlink", "none", "pas

if (is.na(object@post.beta[1]) && !missing(N.sim)) {
message("Simulating a sample of size ", N.sim, " of regression coefficients.")
object@post.beta = sim(N.sim, object@bhat, object@V)
bh = object@bhat[!is.na(object@bhat)]
object@post.beta = sim(N.sim, bh, object@V)
}

if (transform == "pass")
Expand All @@ -926,65 +927,58 @@ regrid = function(object, transform = c("response", "mu", "unlink", "none", "pas
if ((transform == "response") && (!is.null(object@misc$tran2)))
object = regrid(object, transform = "mu")

.collapse = (\(.collapse = NULL, ...) .collapse)(...) # check counterfactuals hook

# Save post.beta stuff
PB = object@post.beta
NC = attr(PB, "n.chains")

if (!is.na(PB[1])) { # fix up post.beta BEFORE we overwrite parameters
PB = PB %*% t(object@linfct)
PB = PB %*% t(object@linfct[ , !is.na(object@bhat), drop = FALSE])
if (".offset." %in% names(object@grid))
PB = t(apply(PB, 1, function(.) . + object@grid[[".offset."]]))
}

est = .est.se.df(object, do.se = TRUE) ###FALSE)

if(is.null(.collapse)) {
estble = !(is.na(est[[1]]))
object@V = vcov(object)[estble, estble, drop = FALSE]
object@bhat = est[[1]]
object@linfct = diag(1, length(estble))
object@misc$regrid.flag = TRUE
pargs = object@grid[names(object@levels)]
lbls = do.call(paste, c(pargs, sep = "."))
if (!is.null(disp <- object@misc$display)) { # fix up for the bookkeeping in nested models
object@V = object@V[estble, estble, drop = FALSE]
object@linfct = matrix(0, nrow = length(disp), ncol = length(estble))
object@linfct[disp, ] = diag(1, length(estble))
lbls = lbls[disp]
}
colnames(object@linfct) = lbls
if(all(estble))
object@nbasis = estimability::all.estble
else
object@nbasis = object@linfct[, !estble, drop = FALSE]

# override the df function
df = est$df
edf = df[estble]
if (length(edf) == 0) edf = NA
# note both NA/NA and Inf/Inf test is.na() = TRUE
prev.df.msg = attr(object@dffun, "mesg")
if (any(is.na(edf/edf)) || (diff(range(edf)) < .01)) { # use common value
object@dfargs = list(df = mean(edf, na.rm = TRUE))
object@dffun = function(k, dfargs) dfargs$df
}
else { # use containment df
object@dfargs = list(df = df)
object@dffun = function(k, dfargs) {
idx = which(zapsmall(k) != 0)
ifelse(length(idx) == 0, NA, min(dfargs$df[idx], na.rm = TRUE))
}
}
if(!is.null(prev.df.msg))
attr(object@dffun, "mesg") = ifelse(
startsWith(prev.df.msg, "inherited"), prev.df.msg,
paste("inherited from", prev.df.msg, "when re-gridding"))
estble = !(is.na(est[[1]]))
object@V = vcov(object)[estble, estble, drop = FALSE]
object@bhat = est[[1]]
object@linfct = diag(1, length(estble))
object@misc$regrid.flag = TRUE
pargs = object@grid[names(object@levels)]
lbls = do.call(paste, c(pargs, sep = "."))
if (!is.null(disp <- object@misc$display)) { # fix up for the bookkeeping in nested models
object@V = object@V[estble, estble, drop = FALSE]
object@linfct = matrix(0, nrow = length(disp), ncol = length(estble))
object@linfct[disp, ] = diag(1, length(estble))
lbls = lbls[disp]
}
colnames(object@linfct) = lbls
if(all(estble))
object@nbasis = estimability::all.estble
else
object@nbasis = object@linfct[, !estble, drop = FALSE]

if(!is.null(.collapse) && is.null(object@misc$tran)) # need explicit link so we can collapse
object@misc$tran = attr(est, "link") = make.link("identity")
# override the df function
df = est$df
edf = df[estble]
if (length(edf) == 0) edf = NA
# note both NA/NA and Inf/Inf test is.na() = TRUE
prev.df.msg = attr(object@dffun, "mesg")
if (any(is.na(edf/edf)) || (diff(range(edf)) < .01)) { # use common value
object@dfargs = list(df = mean(edf, na.rm = TRUE))
object@dffun = function(k, dfargs) dfargs$df
}
else { # use containment df
object@dfargs = list(df = df)
object@dffun = function(k, dfargs) {
idx = which(zapsmall(k) != 0)
ifelse(length(idx) == 0, NA, min(dfargs$df[idx], na.rm = TRUE))
}
}
if(!is.null(prev.df.msg))
attr(object@dffun, "mesg") = ifelse(
startsWith(prev.df.msg, "inherited"), prev.df.msg,
paste("inherited from", prev.df.msg, "when re-gridding"))

if(transform %in% c("response", "mu", "unlink", links, "user") && !is.null(object@misc$tran)) {
flink = link = attr(est, "link")
Expand All @@ -1000,36 +994,11 @@ regrid = function(object, transform = c("response", "mu", "unlink", "none", "pas
}
if (!is.na(PB[1]))
PB = matrix(flink$linkinv(PB), ncol = ncol(PB))
if(is.null(.collapse)) {
D = flink$mu.eta(object@bhat[estble])
object@bhat = flink$linkinv(object@bhat)
# efficient repl for D'VD with D <- diag(D)
object@V = sweep(sweep(object@V, 1, D, "*"), 2, D, "*")
}
else { # we'll average over the levels of .collapse (assume it varies slowest)
est = est[[1]]
nobs = length(object@levels[[.collapse]])
idx = sapply(seq_len(nobs), \(i) which(object@grid[[.collapse]] == i))
X = sweep(object@linfct, 1, flink$mu.eta(est), "*")
wt.counter = (\(wt.counter, ...) wt.counter)(...)
if (length(wt.counter) == 1) wt.counter = rep(1, nobs)
wmn = function(x) sum(wt.counter*x) / sum(wt.counter)
est = flink$linkinv(est)
object@bhat = sapply(seq_len(nrow(idx)), \(i) wmn(est[idx[i,]]))
L = matrix(0, nrow = nrow(idx), ncol = ncol(X))
for (i in seq_len(nrow(idx)))
L[i, ] = apply(X[idx[i,], , drop = FALSE], 2, wmn)
object@V = L %*% tcrossprod(object@V, L)
object@linfct = diag(1, nrow(L))
object@misc$regrid.flag = TRUE
if(!is.na(PB[1])) {
pb = matrix(0, ncol = nrow(L), nrow = nrow(PB))
for (i in seq_len(nrow(idx)))
pb[, i] = apply(PB[, idx[i,], drop = FALSE], 1, wmn)
PB = pb
}
}

D = flink$mu.eta(object@bhat[estble])
object@bhat = flink$linkinv(object@bhat)
# efficient repl for D'VD with D <- diag(D)
object@V = sweep(sweep(object@V, 1, D, "*"), 2, D, "*")

inm = object@misc$inv.lbl
if (!is.null(inm)) {
object@misc$estName = inm
Expand Down Expand Up @@ -1077,6 +1046,7 @@ regrid = function(object, transform = c("response", "mu", "unlink", "none", "pas
}

if(!is.na(PB[1])) {
PB = PB[ , !is.na(object@bhat), drop = FALSE]
attr(PB, "n.chains") = NC
object@post.beta = PB
}
Expand All @@ -1087,13 +1057,7 @@ regrid = function(object, transform = c("response", "mu", "unlink", "none", "pas
object@model.info$model.matrix = "Submodels are not available with regridded objects"
if(!missing(predict.type))
object = update(object, predict.type = predict.type)
if(!is.null(.collapse)) {
object@grid = object@grid[idx[,1], , drop = TRUE]
object@grid[[.collapse]] = object@levels[[.collapse]] = NULL
object@roles$predictors = setdiff(object@roles$predictors, .collapse)
object@misc$famSize = nrow(object@linfct)
}


object
}

16 changes: 0 additions & 16 deletions R/ref-grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -607,17 +607,8 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c

# Now create the reference grid
if(no.nuis <- (length(nuisance) == 0)) {
# if (!missing(counterfactuals)) {
# cfac = intersect(counterfactuals, names(ref.levels))
# ref.levels = ref.levels[cfac]
# ref.levels$.obs.no. = seq_len(nrow(data))
# .check.grid(ref.levels, rg.limit)
# grid = .setup.cf(ref.levels, data)
# }
# ## else {
.check.grid(ref.levels, rg.limit)
grid = do.call(expand.grid, ref.levels)
##}
}
else {
nuis.info = .setup.nuis(nuisance, ref.levels, trms, rg.limit)
Expand Down Expand Up @@ -688,13 +679,6 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c
call. = TRUE)

collapse = NULL
# if (!missing(counterfactuals)) {
# grid = do.call(expand.grid, ref.levels)
# if (missing(regrid))
# regrid = "response"
# if (avg.counter) collapse = ".obs.no."
# }

if(!no.nuis) {
basis = .basis.nuis(basis, nuis.info, wt.nuis, ref.levels, data, grid, ref.levels)
grid = basis$grid
Expand Down
3 changes: 2 additions & 1 deletion vignettes/FAQs.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ argument (see help for `ref_grid`). For example, to reproduce the results
in Stata's Example 6 for `margins`, do:
```r
margex <- haven::read_dta("https://www.stata-press.com/data/r18/margex.dta")
margex.glm <- glm(outcome ~ sex * factor(group) + age, data = margex)
margex.glm <- glm(outcome ~ sex * factor(group) + age, data = margex,
family = binomial)
emmeans(margex.glm, "sex", counterfactuals = "sex")
```

Expand Down

0 comments on commit 8678372

Please sign in to comment.