Skip to content

Commit

Permalink
Bug fixes for cv: constant columns, group.multiplier
Browse files Browse the repository at this point in the history
  • Loading branch information
pbreheny committed Jun 7, 2017
1 parent d0b4024 commit 3f43669
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 25 deletions.
39 changes: 22 additions & 17 deletions R/cv.grpreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ cv.grpreg <- function(X, y, group=1:ncol(X), ..., nfolds=10, seed, cv.ind, retur

# Get standardized X, y
X <- fit$XG$X
g <- fit$XG$g
gm <- fit$XG$m
y <- fit$y
m <- attr(fit$y, "m")
returnX <- list(...)$returnX
Expand Down Expand Up @@ -42,22 +40,17 @@ cv.grpreg <- function(X, y, group=1:ncol(X), ..., nfolds=10, seed, cv.ind, retur
# Do cross-validation
E <- Y <- matrix(NA, nrow=length(y), ncol=length(fit$lambda))
if (fit$family=="binomial") PE <- E
cv.args <- list(...)
cv.args$lambda <- fit$lambda
cv.args$group <- fit$XG$g
cv.args$group.multiplier <- fit$XG$m
cv.args$warn <- FALSE
for (i in 1:nfolds) {
if (trace) cat("Starting CV fold #",i,sep="","\n")
X1 <- X[cv.ind!=i, , drop=FALSE]
y1 <- y[cv.ind!=i]
X2 <- X[cv.ind==i, , drop=FALSE]
y2 <- y[cv.ind==i]

args <- list(..., X=X1, y=y1, group=g, group.multiplier=gm)
args$lambda <- fit$lambda
args$warn <- FALSE
fit.i <- do.call('grpreg', args)

yhat <- matrix(predict(fit.i, X2, type="response"), length(y2))
E[cv.ind==i, 1:length(fit.i$lambda)] <- loss.grpreg(y2, yhat, fit$family)
if (fit$family=="binomial") PE[cv.ind==i, 1:length(fit.i$lambda)] <- (yhat < 0.5) == y2
Y[cv.ind==i, 1:length(fit.i$lambda)] <- yhat
res <- cvf(i, X, y, cv.ind, cv.args)
Y[cv.ind==i, 1:res$nl] <- res$yhat
E[cv.ind==i, 1:res$nl] <- res$loss
if (fit$family=="binomial") PE[cv.ind==i, 1:res$nl] <- res$pe
}

## Eliminate saturated lambda values, if any
Expand All @@ -70,7 +63,7 @@ cv.grpreg <- function(X, y, group=1:ncol(X), ..., nfolds=10, seed, cv.ind, retur
cve <- apply(E, 2, mean)
cvse <- apply(E, 2, sd) / sqrt(n)
min <- which.min(cve)
null.dev <- calcNullDev(X, y, group=g, family=fit$family)
null.dev <- calcNullDev(X, y, group=fit$XG$g, family=fit$family)

val <- list(cve=cve, cvse=cvse, lambda=lambda, fit=fit, min=min, lambda.min=lambda[min], null.dev=null.dev)
if (fit$family=="binomial") val$pe <- apply(PE[,ind], 2, mean)
Expand All @@ -80,3 +73,15 @@ cv.grpreg <- function(X, y, group=1:ncol(X), ..., nfolds=10, seed, cv.ind, retur
}
structure(val, class="cv.grpreg")
}
# Fit and evaluate a single cross-validation fold.
#
# Arguments:
#   i        Fold identifier; rows with cv.ind == i are held out, all other
#            rows are used for fitting.
#   X, y     Design matrix and response to be split into training/test parts.
#   cv.ind   Fold assignment, one entry per row of X.
#   cv.args  Named list of arguments forwarded to grpreg() via do.call(); its
#            X and y entries are overwritten here with the training subset.
#
# Returns a list with:
#   loss  Result of loss.grpreg() on the held-out rows.
#   pe    For the binomial family, the logical matrix (yhat < 0.5) == y2;
#         NULL for every other family.
#   nl    Number of lambda values in the fold's fit (may be fewer than
#         requested, so callers should index columns with 1:nl).
#   yhat  Held-out predictions on the response scale, as a matrix with
#         length(y2) rows.
cvf <- function(i, X, y, cv.ind, cv.args) {
  # Fit on every row that is not in fold i
  cv.args$X <- X[cv.ind != i, , drop=FALSE]
  cv.args$y <- y[cv.ind != i]
  fit.i <- do.call("grpreg", cv.args)

  # Evaluate on the held-out rows
  X2 <- X[cv.ind == i, , drop=FALSE]
  y2 <- y[cv.ind == i]

  # Force an (n_test x n_lambda) matrix, so that yhat, loss and pe keep a
  # consistent shape even if predict() hands back a dimensionless vector
  # (e.g. a single held-out row or a single lambda value).
  yhat <- matrix(predict(fit.i, X2, type="response"), length(y2))

  loss <- loss.grpreg(y2, yhat, fit.i$family)
  pe <- if (fit.i$family == "binomial") {
    (yhat < 0.5) == y2
  } else {
    NULL
  }
  list(loss=loss, pe=pe, nl=length(fit.i$lambda), yhat=yhat)
}
12 changes: 5 additions & 7 deletions R/reorderGroups.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,19 @@ reorderGroups <- function(group, m, bilevel) {
ord <- order(g)
ord.inv <- match(1:length(g), ord)
g <- g[ord]
gLevels <- setdiff(levels(gf), "0")
} else {
reorder <- FALSE
g <- group
ord <- ord.inv <- NULL
gLevels <- paste0("G", unique(g[g!=0]))
}
J <- max(g)
J <- length(gLevels)
if (missing(m)) {
m <- if (bilevel) rep(1, J) else sqrt(table(g[g!=0]))
}
if (length(m)!=max(g)) stop("Length of group.multiplier must equal number of penalized groups")
if (reorder) {
names(m) <- setdiff(levels(gf), "0")
} else {
names(m) <- paste0("G", unique(g[g!=0]))
}
names(m) <- gLevels
if (length(m) != J) stop("Length of group.multiplier must equal number of penalized groups")
if (storage.mode(m) != "double") storage.mode(m) <- "double"
list(g=g, m=m, ord=ord, ord.inv=ord.inv, reorder=reorder)
}
10 changes: 9 additions & 1 deletion inst/tests/extra-features.R
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ fit <- gBridge(X, yy, group, family="binomial"); plot(fit); fit$beta[,100]
fit <- grpreg(X, yy, group, penalty="grLasso", family="poisson"); plot(fit)
fit <- grpreg(X, yy, group, penalty="cMCP", family="poisson"); plot(fit); fit$beta[,100]
fit <- gBridge(X, yy, group, family="poisson"); plot(fit); fit$beta[,100]
cvfit <- cv.grpreg(X, y, group, penalty="grLasso")
cvfit <- cv.grpreg(X, y, group, penalty="gel")

.test = "grpreg handles groups of non-full rank"
n <- 50
Expand All @@ -61,6 +63,8 @@ fit <- grpreg(X, yy, group, penalty="grLasso", family="binomial"); plot(fit)
fit <- grpreg(X, yy, group, penalty="cMCP", family="binomial"); plot(fit)
fit <- grpreg(X, yy, group, penalty="grLasso", family="poisson"); plot(fit)
fit <- grpreg(X, yy, group, penalty="cMCP", family="poisson"); plot(fit)
cvfit <- cv.grpreg(X, y, group, penalty="grLasso")
cvfit <- cv.grpreg(X, y, group, penalty="gel")

.test = "grpreg out-of-order groups"
n <- 50
Expand All @@ -76,6 +80,7 @@ fit2 <- grpreg(X[,ind], y, group[ind], penalty="grLasso")
b1 <- coef(fit1)[-1,][ind,]
b2 <- coef(fit2)[-1,]
check(b1, b2, tol=0.01)
cvfit <- cv.grpreg(X, y, group, penalty="grLasso")

.test = "grpreg named groups"
n <- 50
Expand All @@ -89,18 +94,20 @@ yy <- y > 0
fit1 <- grpreg(X, y, group1, penalty="grLasso")
fit2 <- grpreg(X, y, group2, penalty="grLasso")
check(coef(fit1), coef(fit2), tol=0.001)
cvfit <- cv.grpreg(X, y, group, penalty="grLasso")

.test = "grpreg out-of-order groups with constant columns"
n <- 50
group <- rep(c(1,3,0,2),5:2)
p <- length(group)
X <- matrix(rnorm(n*p),ncol=p)
#X[,group==2] <- 0
X[,group==2] <- 0
y <- rnorm(n)
mle <- coef(lm(y~X))
mle[!is.finite(mle)] <- 0
grl <- coef(grpreg(X, y, group, penalty="grLasso", lambda.min=0, eps=1e-7), lambda=0)
check(mle, grl, tol=0.01)
cvfit <- cv.grpreg(X, y, group, penalty="grLasso")

.test = "group.multiplier works"
n <- 50
Expand All @@ -114,6 +121,7 @@ plot(fit <- gBridge(X, y, group, lambda.min=0, group.multiplier=gm), main=fit$pe
plot(fit <- grpreg(X, y, group, penalty="grLasso", lambda.min=0, group.multiplier=gm), main=fit$penalty)
plot(fit <- grpreg(X, y, group, penalty="grMCP", lambda.min=0, group.multiplier=gm), main=fit$penalty)
plot(fit <- grpreg(X, y, group, penalty="grSCAD", lambda.min=0, group.multiplier=gm), main=fit$penalty)
cvfit <- cv.grpreg(X, y, group, penalty="grLasso", group.multiplier=gm)

.test = "dfmax works"
n <- 100
Expand Down

0 comments on commit 3f43669

Please sign in to comment.