Skip to content

Commit 840a016

Browse files
authored
Merge pull request #27 from stephenslab/eweine/add_subset_option
Eweine/add subset option
2 parents ed33367 + 49ebe86 commit 840a016

File tree

4 files changed

+165
-5
lines changed

4 files changed

+165
-5
lines changed

R/fit.R

+102-5
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,21 @@
7171
#' control argument for \code{\link[daarem]{daarem}}. This setting
7272
#' determines to what extent the monotonicity condition can be
7373
#' violated.}
74+
#'
75+
#' \item{\code{training_frac}}{Fraction of the columns of input data \code{Y}
76+
#' to fit initial model on. If set to \code{1} (default), the model is fit
77+
#' by optimizing the parameters on the entire dataset. If set between \code{0}
78+
#' and \code{1}, the model is optimized by first fitting a model on a randomly
79+
#' selected fraction of the columns of \code{Y}, and then projecting the
80+
#' remaining columns of \code{Y} onto the solution. Setting this to a smaller
81+
#' value will increase speed but decrease accuracy.
82+
#' }
83+
#'
84+
#' \item{\code{num_projection_ccd_iter}}{Number of co-ordinate descent updates
85+
#' be made to elements of \code{V} if and when a subset of \code{Y} is
86+
#' projected onto \code{U}. Only used if \code{training_frac} is less than
87+
#' \code{1}.
88+
#' }
7489
#'
7590
#' \item{\code{num_ccd_iter}}{Number of co-ordinate descent updates to
7691
#' be made to parameters at each iteration of the algorithm.}
@@ -196,7 +211,7 @@ fit_glmpca_pois <- function(
196211
# Check and process input argument "control".
197212
control <- modifyList(fit_glmpca_pois_control_default(),
198213
control,keep.null = TRUE)
199-
214+
200215
# Set up the internal fit.
201216
D <- sqrt(fit0$d)
202217
if (K == 1)
@@ -205,7 +220,7 @@ fit_glmpca_pois <- function(
205220
D <- diag(D)
206221
LL <- t(cbind(fit0$U %*% D,fit0$X,fit0$W))
207222
FF <- t(cbind(fit0$V %*% D,fit0$B,fit0$Z))
208-
223+
209224
# Determine which rows of LL and FF are "clamped".
210225
fixed_l <- numeric(0)
211226
fixed_f <- numeric(0)
@@ -217,9 +232,86 @@ fit_glmpca_pois <- function(
217232
fixed_f <- c(fixed_f,K + fit0$fixed_b_cols)
218233
if (nz > 0)
219234
fixed_f <- c(fixed_f,K + nx + seq(1,nz))
220-
221-
# Perform the updates.
222-
res <- fit_glmpca_pois_main_loop(LL,FF,Y,fixed_l,fixed_f,verbose,control)
235+
236+
if (control$training_frac == 1) {
237+
238+
# Perform the updates.
239+
res <- fit_glmpca_pois_main_loop(LL,FF,Y,fixed_l,fixed_f,verbose,control)
240+
241+
} else {
242+
243+
if (control$training_frac <= 0 || control$training_frac > 1)
244+
stop("control argument \"training_frac\" should be between 0 and 1")
245+
246+
train_idx <- sample(
247+
1:ncol(Y),
248+
size = ceiling(ncol(Y) * control$training_frac)
249+
)
250+
251+
browser()
252+
Y_train <- Y[, train_idx]
253+
254+
if (any(Matrix::rowSums(Y_train) == 0) || any(Matrix::colSums(Y_train) == 0)) {
255+
256+
stop(
257+
"After subsetting, the remaining values of \"Y\" ",
258+
"contain a row or a column where all counts are 0. This can cause ",
259+
"problems with optimization. Please either remove rows / columns ",
260+
"with few non-zero counts from \"Y\", or set \"training_frac\" to ",
261+
"a larger value."
262+
)
263+
264+
}
265+
266+
FF_train <- FF[, train_idx]
267+
FF_test <- FF[, -train_idx]
268+
Y_test <- Y[, -train_idx]
269+
270+
test_idx <- 1:ncol(Y)
271+
test_idx <- test_idx[-train_idx]
272+
273+
# Perform the updates.
274+
res <- fit_glmpca_pois_main_loop(
275+
LL,
276+
FF_train,
277+
Y_train,
278+
fixed_l,
279+
fixed_f,
280+
verbose,
281+
control
282+
)
283+
284+
update_indices_f <- sort(setdiff(1:K,fixed_f))
285+
286+
# now, I just need to project the results back
287+
update_factors_faster_parallel(
288+
L_T = t(res$fit$LL),
289+
FF = FF_test,
290+
M = as.matrix(res$fit$LL[update_indices_f,,drop = FALSE] %*% Y_test),
291+
update_indices = update_indices_f - 1,
292+
num_iter = control$num_projection_ccd_iter,
293+
line_search = control$line_search,
294+
alpha = control$ls_alpha,
295+
beta = control$ls_beta
296+
)
297+
298+
# now, I need to reconstruct FF, and hopefully compute the log-likelihood
299+
FF[, train_idx] <- res$fit$FF
300+
FF[, test_idx] <- FF_test
301+
res$fit$FF <- FF
302+
303+
if (inherits(Y,"sparseMatrix")) {
304+
test_loglik_const <- sum(mapSparse(Y_test,lfactorial))
305+
loglik_func <- lik_glmpca_pois_log_sp
306+
} else {
307+
test_loglik_const <- sum(lfactorial(Y_test))
308+
loglik_func <- lik_glmpca_pois_log
309+
}
310+
311+
test_loglik <- loglik_func(Y_test,res$fit$LL,FF_test,test_loglik_const)
312+
res$loglik <- res$loglik + test_loglik
313+
314+
}
223315

224316
# Prepare the final output.
225317
res$progress$iter <- max(fit0$progress$iter) + res$progress$iter
@@ -258,9 +350,12 @@ fit_glmpca_pois <- function(
258350
dimnames(fit$W) <- dimnames(fit0$W)
259351
}
260352
class(fit) <- c("glmpca_pois_fit","list")
353+
261354
return(fit)
355+
262356
}
263357

358+
264359
# This implements the core part of fit_glmpca_pois.
265360
#
266361
#' @importFrom Matrix t
@@ -358,6 +453,8 @@ fit_glmpca_pois_control_default <- function()
358453
list(use_daarem = FALSE,
359454
maxiter = 100,
360455
tol = 1e-4,
456+
training_frac = 1,
457+
num_projection_ccd_iter = 10,
361458
mon.tol = 0.05,
362459
convtype = "objfn",
363460
line_search = TRUE,

inst/.DS_Store

10 KB
Binary file not shown.

inst/scratch/test_projection_method.R

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
library(fastglmpca)
2+
3+
set.seed(1)
4+
cc <- pbmc_facs$counts[Matrix::rowSums(pbmc_facs$counts) > 10, ]
5+
6+
fit1 <- fit_glmpca_pois(
7+
Y = cc,
8+
K = 2,
9+
control = list(training_frac = 0.99, maxiter = 10)
10+
)
11+
12+
# for some reason the calculated log-likelihood and the expected
13+
# are not matching up
14+
set.seed(1)
15+
fit2 <- fit_glmpca_pois(
16+
Y = pbmc_facs$counts,
17+
K = 2,
18+
control = list(training_frac = 0.25, maxiter = 10, num_projection_ccd_iter = 25)
19+
)
20+
21+
set.seed(1)
22+
fit3 <- fit_glmpca_pois(
23+
Y = pbmc_facs$counts,
24+
K = 2,
25+
control = list(training_frac = 0.25, maxiter = 10, num_projection_ccd_iter = 5)
26+
)
27+
#
28+
# df1 <- data.frame(
29+
# celltype = pbmc_facs$samples$celltype,
30+
# PC1 = fit1$V[,1],
31+
# PC2 = fit1$V[,2]
32+
# )
33+
#
34+
# library(ggplot2)
35+
#
36+
# ggplot(data = df1) +
37+
# geom_point(aes(x = PC1, y = PC2, color = celltype))
38+
#
39+
# df2 <- data.frame(
40+
# celltype = pbmc_facs$samples$celltype,
41+
# PC1 = fit2$V[,1],
42+
# PC2 = fit2$V[,2]
43+
# )
44+
#
45+
# library(ggplot2)
46+
#
47+
# ggplot(data = df2) +
48+
# geom_point(aes(x = PC1, y = PC2, color = celltype))

man/fit_glmpca_pois.Rd

+15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)