add parametric model from survivalmodels #345

Merged
merged 8 commits into from
Mar 13, 2024
Changes from 2 commits
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -101,7 +101,7 @@ Suggests:
sm,
stats,
survival,
survivalmodels,
survivalmodels (>= 0.1.19),
survivalsvm,
tensorflow (>= 2.0.0),
testthat,
@@ -122,5 +122,5 @@ Config/testthat/edition: 3
Encoding: UTF-8
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3.9000
RoxygenNote: 7.3.1.9000
Config/Needs/website: rmarkdown
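
Because `survivalmodels` is listed under `Suggests` with a minimum version, code that relies on it may want to check at runtime that a recent enough version is installed. The following is a minimal sketch of such a check, not part of this PR; `has_min_version()` is a hypothetical helper name, and only base R and `utils` functions are used.

```r
# Hypothetical helper (not from the PR): TRUE if `pkg` is installed
# at version `min_version` or later, FALSE otherwise.
has_min_version = function(pkg, min_version) {
  # requireNamespace() returns FALSE instead of erroring when `pkg` is missing,
  # so the version comparison is only evaluated for installed packages.
  requireNamespace(pkg, quietly = TRUE) &&
    utils::packageVersion(pkg) >= min_version
}

# Check the constraint declared in DESCRIPTION.
has_min_version("survivalmodels", "0.1.19")
```

The result depends on the local installation, so no expected output is shown.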
299 changes: 46 additions & 253 deletions R/learner_survival_surv_parametric.R
@@ -4,7 +4,7 @@
#'
#' @description
#' Parametric survival model.
#' Calls [survival::survreg()] from \CRANpkg{survival}.
#' Calls [survivalmodels::parametric()] from 'survivalmodels'.
#'
#' @section Custom mlr3 parameters:
#' - `discrete` determines the class of the returned survival probability
@@ -14,32 +14,42 @@
#'
#' @template learner
#' @templateVar id surv.parametric
#' @template install_survivalmodels
#'
#' @details
#' This learner allows you to choose a distribution and a model form to compose a predicted
#' survival probability distribution.
#' This learner allows you to choose a distribution and a model form to compose
#' a predicted survival probability distribution.
#'
#' The internal predict method is implemented in this package as our implementation is more
#' efficient for composition to distributions than [survival::predict.survreg()].
#' The predict method is implemented in [survivalmodels::predict.parametric()].
#' Our implementation is more efficient for composition to distributions than
#' [survival::predict.survreg()].
#'
#' `lp` is predicted using the formula \eqn{lp = X\beta} where \eqn{X} are the variables in the test
#' data set and \eqn{\beta} are the fitted coefficients.
#' Three types of prediction are returned for this learner:
Member

Maybe we want to generate this programmatically, i.e. something like `r surv_predict_types("lp", "crank")`

Collaborator Author

It's a good idea. Since I plan to go through all survival learners and change the prediction order (#331), I can see how many need that prediction info and refactor it?

Collaborator Author

see #347

Collaborator Author

or better a doc template?

Member

Both is fine, whatever works :)
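
As a rough illustration of the programmatic option discussed in this thread, a helper could build the numbered list of prediction types from a lookup table. This is only a sketch: the function name is taken from the suggestion above, but its body, the lookup table, and the error handling are assumptions, and the description strings are copied from the roxygen text in this diff (for other learners, `crank` need not equal `lp`).

```r
# Hypothetical implementation; only the function name appears in the review thread.
surv_predict_types = function(...) {
  types = c(...)
  # Descriptions copied from the roxygen block of surv.parametric.
  descriptions = c(
    lp = "a vector of linear predictors (relative risk scores), one per observation.",
    crank = "same as `lp`.",
    distr = paste(
      "a survival matrix in two dimensions, where observations are",
      "represented in rows and time points in columns."
    )
  )
  unknown = setdiff(types, names(descriptions))
  if (length(unknown) > 0L) {
    stop("Unknown prediction type(s): ", paste(unknown, collapse = ", "))
  }
  # One numbered markdown list item per requested type, in the order given.
  items = sprintf("%d. `%s`: %s", seq_along(types), types, descriptions[types])
  paste(items, collapse = "\n")
}

cat(surv_predict_types("lp", "crank"))
```

With markdown enabled in roxygen2, inline R code such as the call suggested above could then insert this text into the documentation at build time; a shared doc template would be the alternative mentioned in the thread.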

#' 1. `lp`: a vector of linear predictors (relative risk scores), one per
#' observation.
#' `lp` is predicted using the formula \eqn{lp = X\beta} where \eqn{X} are the
#' variables in the test data set and \eqn{\beta} are the fitted coefficients.
#' 2. `crank`: same as `lp`.
#' 3. `distr`: a survival matrix in two dimensions, where observations are
#' represented in rows and time points in columns.
#' The distribution `distr` is composed using the `lp` and specifying a model
#' form in the `type` hyper-parameter. These are as follows, with respective
#' survival functions:
#'
#' The distribution `distr` is composed using the `lp` and specifying a model form in the
#' `type` hyper-parameter. These are as follows, with respective survival functions,
#' * Accelerated Failure Time (`aft`) \deqn{S(t) = S_0(\frac{t}{exp(lp)})}{S(t) = S0(t/exp(lp))}
#' * Proportional Hazards (`ph`) \deqn{S(t) = S_0(t)^{exp(lp)}}{S(t) = S0(t)^exp(lp)}
#' * Proportional Odds (`po`) \deqn{S(t) =
#' - Accelerated Failure Time (`aft`) \deqn{S(t) = S_0(\frac{t}{exp(lp)})}{S(t) = S0(t/exp(lp))}
#' - Proportional Hazards (`ph`) \deqn{S(t) = S_0(t)^{exp(lp)}}{S(t) = S0(t)^exp(lp)}
#' - Proportional Odds (`po`) \deqn{S(t) =
#' \frac{S_0(t)}{exp(-lp) + (1-exp(-lp)) S_0(t)}}{S(t) = S0(t) / [exp(-lp) + S0(t) (1-exp(-lp))]}
#' * Tobit (`tobit`) \deqn{S(t) = 1 - F((t - lp)/s)}
#' - Tobit (`tobit`) \deqn{S(t) = 1 - \Phi((t - lp)/s)}
#'
#' where \eqn{S_0}{S0} is the estimated baseline survival distribution (in this case
#' with a given parametric form), \eqn{lp} is the predicted linear predictor, \eqn{F} is the cdf
#' of a N(0, 1) distribution, and \eqn{s} is the fitted scale parameter.
#' where \eqn{S_0}{S0} is the estimated baseline survival distribution (in
#' this case with a given parametric form), \eqn{lp} is the predicted linear
#' predictor, \eqn{\Phi} is the cdf of a N(0, 1) distribution, and \eqn{s} is
#' the fitted scale parameter.
#'
#' Whilst any combination of distribution and model form is possible, this does not mean it will
#' necessarily create a sensible or interpretable prediction. The following combinations are
#' 'sensible':
#' Whilst any combination of distribution and model form is possible, this does
#' not mean it will necessarily create a sensible or interpretable prediction.
#' The following combinations are 'sensible':
#'
#' * dist = "gaussian"; type = "tobit";
#' * dist = "weibull"; type = "ph";
@@ -100,256 +110,39 @@ LearnerSurvParametric = R6Class("LearnerSurvParametric",

private = list(
.train = function(task) {

pv = self$param_set$get_values(tags = "train")

if ("weights" %in% task$properties) {
pv$weights = task$weights$weight
}

fit = invoke(survival::survreg, formula = task$formula(), data = task$data(),
.args = pv)

# Fits the baseline distribution by reparameterising the fitted coefficients.
# These were mostly derived numerically as precise documentation on the parameterisations is
# hard to find.
location = as.numeric(fit$coefficients[1])
scale = fit$scale
eps = 1e-15

if (scale < eps) {
scale = eps
} else if (scale > .Machine$double.xmax) {
scale = .Machine$double.xmax
}

if (location < -709 & fit$dist %in% c("weibull", "exponential", "loglogistic")) {
location = -709
}

basedist = switch(fit$dist,
"weibull" = distr6::Weibull$new(shape = 1 / scale, scale = exp(location),
decorators = "ExoticStatistics"),
"exponential" = distr6::Exponential$new(scale = exp(location),
decorators = "ExoticStatistics"),
"gaussian" = distr6::Normal$new(mean = location, sd = scale,
decorators = "ExoticStatistics"),
"lognormal" = distr6::Lognormal$new(meanlog = location, sdlog = scale,
decorators = "ExoticStatistics"),
"loglogistic" = distr6::Loglogistic$new(scale = exp(location),
shape = 1 / scale,
decorators = "ExoticStatistics")
invoke(
survivalmodels::parametric,
data = data.table::setDF(task$data()),
time_variable = task$target_names[1L],
status_variable = task$target_names[2L],
.args = pv
)

set_class(list(fit = fit, basedist = basedist), "surv.parametric")
},

.predict = function(task) {
pv = self$param_set$get_values(tags = "predict")
if (pv$discrete) {
pred = invoke(.predict_survreg_discrete, object = self$model, task = task,
learner = self, type = pv$type)
} else {
pred = invoke(.predict_survreg_continuous, object = self$model, task = task,
learner = self, type = pv$type)
}
# lp is aft-style, where higher value = lower risk, opposite needed for crank
list(distr = pred$distr, crank = -pred$lp, lp = -pred$lp)
}
)
)


.predict_survreg_continuous = function(object, task, learner, type = "aft") {
feature_names = intersect(names(learner$state$data_prototype) %??% learner$state$feature_names, task$feature_names)
# Extracts baseline distribution and the model fit, performs assertions
basedist = object$basedist
fit = object$fit
distr6::assertDistribution(basedist)
assertClass(fit, "survreg")

# define newdata from the supplied task and convert to model matrix
newdata = ordered_features(task, learner)
if (any(is.na(newdata))) {
stopf("Learner %s on task %s failed to predict: Missing values in new data (line(s) %s)\n", learner$id, task$id)
}
x = stats::model.matrix(formulate(rhs = feature_names), data = newdata,
xlev = task$levels())[, -1]

# linear predictor defined by the fitted cofficients multiplied by the model matrix
# (i.e. covariates)
lp = matrix(fit$coefficients[-1], nrow = 1) %*% t(x)

# checks and parameterises the chosen model type: proportional hazard (ph), accelerated failure
# time (aft), odds.
# PH: h(t) = h0(t)exp(lp)
# AFT: h(t) = exp(-lp)h0(t/exp(lp))
# PO: h(t)/h0(t) = {1 + (exp(lp)-1)S0(t)}^-1

dist = toproper(fit$dist)

if (type == "tobit") {
name = paste(dist, "Tobit Model")
short_name = paste0(dist, "Tobit")
description = paste(dist, "Tobit Model with negative log-likelihood",
-fit$loglik[2])
} else if (type == "ph") {
name = paste(dist, "Proportional Hazards Model")
short_name = paste0(dist, "PH")
description = paste(dist, "Proportional Hazards Model with negative log-likelihood",
-fit$loglik[2])
} else if (type == "aft") {
name = paste(dist, "Accelerated Failure Time Model")
short_name = paste0(dist, "AFT")
description = paste(dist, "Accelerated Failure Time Model with negative log-likelihood",
-fit$loglik[2])
} else if (type == "po") {
name = paste(dist, "Proportional Odds Model")
short_name = paste0(dist, "PO")
description = paste(dist, "Proportional Odds Model with negative log-likelihood",
-fit$loglik[2])
}

params = list(list(name = name,
short_name = short_name,
type = set6::PosReals$new(),
support = set6::PosReals$new(),
valueSupport = "continuous",
variateForm = "univariate",
description = description,
.suppressChecks = TRUE,
pdf = function() {
},
cdf = function() {
},
parameters = param6::pset()
))

params = rep(params, length(lp))

pdf = function(x) {} # nolint
cdf = function(x) {} # nolint
quantile = function(p) {} # nolint
newdata = as.data.frame(ordered_features(task, self))

if (type == "tobit") {
for (i in seq_along(lp)) {
body(pdf) = substitute(pnorm((x - y) / scale), list(
y = lp[i] + fit$coefficients[1],
scale = basedist$stdev()
))
body(cdf) = substitute(pnorm((x - y) / scale), list(
y = lp[i] + fit$coefficients[1],
scale = basedist$stdev()
))
body(quantile) = substitute(qnorm(p) * scale + y, list(
y = lp[i] + fit$coefficients[1],
scale = basedist$stdev()
))
params[[i]]$pdf = pdf
params[[i]]$cdf = cdf
params[[i]]$quantile = quantile
}
} else if (type == "ph") {
for (i in seq_along(lp)) {
body(pdf) = substitute((exp(y) * basedist$hazard(x)) * (1 - self$cdf(x)), list(y = -lp[i]))
body(cdf) = substitute(1 - (basedist$survival(x)^exp(y)), list(y = -lp[i]))
body(quantile) = substitute(
basedist$quantile(1 - exp(exp(-y) * log(1 - p))), # nolint
list(y = -lp[i])
pred = invoke(
predict,
self$model,
newdata = newdata,
distr6 = !pv$discrete,
type = "all",
.args = pars
)
params[[i]]$pdf = pdf
params[[i]]$cdf = cdf
params[[i]]$quantile = quantile
}
} else if (type == "aft") {
for (i in seq_along(lp)) {
body(pdf) = substitute((exp(-y) * basedist$hazard(x / exp(y))) * (1 - self$cdf(x)),
list(y = lp[i]))
body(cdf) = substitute(1 - (basedist$survival(x / exp(y))), list(y = lp[i]))
body(quantile) = substitute(exp(y) * basedist$quantile(p), list(y = lp[i]))
params[[i]]$pdf = pdf
params[[i]]$cdf = cdf
params[[i]]$quantile = quantile
}
} else if (type == "po") {
for (i in seq_along(lp)) {
body(pdf) = substitute((basedist$hazard(x) *
(1 - (basedist$survival(x) /
(((exp(y) - 1)^-1) + basedist$survival(x))))) *
(1 - self$cdf(x)), list(y = lp[i]))
body(cdf) = substitute(1 - (basedist$survival(x) *
(exp(-y) + (1 - exp(-y)) * basedist$survival(x))^-1), # nolint
list(y = lp[i]))
body(quantile) = substitute(basedist$quantile(-p / ((exp(-y) * (p - 1)) - p)), # nolint
list(y = lp[i]))
params[[i]]$pdf = pdf
params[[i]]$cdf = cdf
params[[i]]$quantile = quantile
}
}

distlist = lapply(params, function(.x) do.call(distr6::Distribution$new, .x))
names(distlist) = paste0(short_name, seq_along(distlist))

distr = distr6::VectorDistribution$new(distlist,
decorators = c("CoreStatistics", "ExoticStatistics"))

lp = lp + fit$coefficients[1]

list(lp = as.numeric(lp), distr = distr)
}


.predict_survreg_discrete = function(object, task, learner, type = "aft") {
feature_names = intersect(names(learner$state$data_prototype), task$feature_names)

# Extracts baseline distribution and the model fit, performs assertions
basedist = object$basedist
fit = object$fit
distr6::assertDistribution(basedist)
assertClass(fit, "survreg")

# define newdata from the supplied task and convert to model matrix
newdata = ordered_features(task, learner)
if (any(is.na(newdata))) {
stopf("Learner %s on task %s failed to predict: Missing values in new data (line(s) %s)\n", learner$id, task$id)
}

times = task$unique_times()

# PH: h(t) = h0(t)exp(lp)
# AFT: h(t) = exp(-lp)h0(t/exp(lp))
# PO: h(t)/h0(t) = {1 + (exp(lp)-1)S0(t)}^-1
if (type == "tobit") {
fun = function(y) stats::pnorm((times - y - fit$coefficients[1]) / basedist$stdev())
} else if (type == "ph") {
fun = function(y) 1 - (basedist$survival(times)^exp(-y))
} else if (type == "aft") {
fun = function(y) 1 - (basedist$survival(times / exp(y)))
} else if (type == "po") {
fun = function(y) {
surv = basedist$survival(times)
1 - (surv * (exp(-y) + (1 - exp(-y)) * surv)^-1)
# lp is aft-style, where higher value = lower risk
list(crank = pred$risk, distr = pred$surv)
}
}

# linear predictor defined by the fitted cofficients multiplied by the model matrix
# (i.e. covariates)
x = stats::model.matrix(mlr3misc::formulate(rhs = feature_names), data = newdata,
xlev = task$levels())[, -1]
lp = matrix(fit$coefficients[-1], nrow = 1) %*% t(x)

if (length(times) == 1) { # edge case
mat = as.matrix(vapply(lp, fun, numeric(1)), ncol = 1)
} else {
mat = t(vapply(lp, fun, numeric(length(times))))
}
colnames(mat) = times

list(
lp = as.numeric(lp + fit$coefficients[1]),
distr = distr6::as.Distribution(mat, fun = "cdf")
)
}
)

.extralrns_dict$add("surv.parametric", LearnerSurvParametric)
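
The four model forms listed in the `@details` section of this diff can be sketched in base R as follows. This is an illustration of the documented survival functions only, not the `survivalmodels` implementation: the Weibull baseline, its parameter values, and the names `s0()` and `surv_form()` are assumptions chosen for the example, and `lp` is used with the sign convention of the documented formulas.

```r
# Assumed baseline for illustration: Weibull survival function S0(t).
s0 = function(t, shape = 1.5, scale = 10) {
  stats::pweibull(t, shape = shape, scale = scale, lower.tail = FALSE)
}

# Survival probabilities at times `t` for one observation with linear
# predictor `lp`; `s` is the scale parameter, used only by the tobit form.
surv_form = function(t, lp, type = c("aft", "ph", "po", "tobit"), s = 1) {
  type = match.arg(type)
  switch(type,
    aft   = s0(t / exp(lp)),                                # S0(t / exp(lp))
    ph    = s0(t)^exp(lp),                                  # S0(t)^exp(lp)
    po    = s0(t) / (exp(-lp) + (1 - exp(-lp)) * s0(t)),    # proportional odds
    tobit = 1 - stats::pnorm((t - lp) / s)                  # 1 - Phi((t - lp) / s)
  )
}

times = c(1, 5, 10, 20)

# One column per model form, one row per time point.
sapply(c(aft = "aft", ph = "ph", po = "po"), function(type) {
  surv_form(times, lp = 0.5, type = type)
})
```

With `lp = 0` the `aft`, `ph`, and `po` forms all reduce to the baseline `S0(t)`, which is a quick sanity check on the formulas; the `tobit` form gives 0.5 at `t = lp`.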