Skip to content

Commit

Permalink
Merge branch 'main' into weka_learners
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc authored Dec 14, 2023
2 parents e63925b + 7546845 commit 97c12d6
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 34 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# mlr3extralearners 0.7.1-9000

# Added method `selected_features()` to CoxBoost survival learners (thanks to @bblodfon)
* Fix: Replace hardcoded `VectorDistribution`s from partykit survival learners with survival matrices (`Matdist`) (thanks to @bblodfon)
* Added method `selected_features()` to CoxBoost survival learners (thanks to @bblodfon)
* Added the Random Planted Forest Learner (thanks to @jemus42)
* Re-added the catboost learner on request (it was previously removed
because of installation issues)
Expand Down
25 changes: 12 additions & 13 deletions R/learner_partykit_surv_cforest.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,20 @@ LearnerSurvCForest = R6Class("LearnerSurvCForest",
preds = invoke(predict, object = self$model, newdata = newdata,
type = "prob", .args = pars)

# Define WeightedDiscrete distr6 distribution from the survival function
x = lapply(preds, function(z) {
time = c(0, z$time, max(z$time) + 1e-3)
surv = c(1, z$surv, 0)
data.frame(x = time, cdf = 1 - surv)
})
distr = distr6::VectorDistribution$new(
distribution = "WeightedDiscrete",
params = x,
decorators = c("CoreStatistics", "ExoticStatistics"))
times = lapply(preds, function(p) p$time)
utimes = sort(unique(unlist(times)))

# Define crank as the mean of the survival distribution
crank = -vapply(x, function(z) sum(z[, 1] * c(z[, 2][1], diff(z[, 2]))), numeric(1))
# to use non-exported function from `distr6`
extend_times = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
res = lapply(preds, function(p) {
# p is a `survfit` object
cdf = matrix(data = 1 - p$surv, ncol = 1) # 1 observation (column), rows => times
# extend cdf to 'utimes', return survival
extend_times(utimes, p$time, cdf = cdf, FALSE, FALSE)
})
surv = do.call(cbind, res) # rows => times, columns => obs

list(crank = crank, distr = distr)
.surv_return(times = utimes, surv = t(surv))
}
)
)
Expand Down
25 changes: 12 additions & 13 deletions R/learner_partykit_surv_ctree.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,20 @@ LearnerSurvCTree = R6Class("LearnerSurvCTree",
.args = pars
)

# Define WeightedDiscrete distr6 distribution from the survival function
x = lapply(preds, function(z) {
time = c(0, z$time, max(z$time) + 1e-3)
surv = c(1, z$surv, 0)
data.frame(x = time, cdf = 1 - surv)
})
distr = distr6::VectorDistribution$new(
distribution = "WeightedDiscrete",
params = x,
decorators = c("CoreStatistics", "ExoticStatistics"))
times = lapply(preds, function(p) p$time)
utimes = sort(unique(unlist(times)))

# Define crank as the mean of the survival distribution
crank = -vapply(x, function(z) sum(z[, 1] * c(z[, 2][1], diff(z[, 2]))), numeric(1))
# to use non-exported function from `distr6`
extend_times = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
res = lapply(preds, function(p) {
# p is a `survfit` object
cdf = matrix(data = 1 - p$surv, ncol = 1) # 1 observation (column), rows => times
# extend cdf to 'utimes', return survival
extend_times(utimes, p$time, cdf = cdf, FALSE, FALSE)
})
surv = do.call(cbind, res) # rows => times, columns => obs

list(crank = crank, distr = distr)
.surv_return(times = utimes, surv = t(surv))
}
)
)
Expand Down
20 changes: 18 additions & 2 deletions tests/testthat/test_partykit_surv_cforest.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
# Calls `run_autotest()` on the cforest survival learner and expects the
# result to be TRUE; `result$error` is passed as the failure info.
test_that("autotest", {
# NOTE(review): `learner` is assigned twice in a row, so the object created by
# `LearnerSurvCForest$new()` is immediately replaced by the `lrn()` one. This
# looks like a before/after pair from a diff view -- confirm which is intended.
learner = LearnerSurvCForest$new()
learner = lrn("surv.cforest", ntree = 5)
expect_learner(learner)
# NOTE(review): `run_autotest()` is likewise called twice; only the second
# result (the call with `exclude = "sanity"`) reaches `expect_true()`.
result = run_autotest(learner, check_replicable = FALSE, N = 100)
result = run_autotest(learner, check_replicable = FALSE, exclude = "sanity")
expect_true(result, info = result$error)
})

# Checks the shape of the predictions produced by the cforest survival
# learner: `p$data$distr` must be a matrix with one row per test observation
# and at most as many columns as there are unique times in the training rows,
# and `p$crank` must be a numeric vector with one entry per test observation.
test_that("correct prediction types", {
  with_seed(42, {
    # Random 50-row subset of the rats task, split with `ratio = 0.9`.
    task = tsk("rats")$filter(sample(1:300, 50))
    part = partition(task, ratio = 0.9)
    train_rows = part$train
    test_rows = part$test
    unique_times = task$unique_times(train_rows)

    learner = lrn("surv.cforest", ntree = 5)
    # Train on `train_rows` (the same ids `unique_times` was computed from)
    # rather than re-reading `part$train`, matching the ctree test.
    p = learner$train(task, train_rows)$predict(task, test_rows)
    expect_matrix(p$data$distr, nrows = length(test_rows),
      max.cols = length(unique_times))
    expect_numeric(p$crank, len = length(test_rows))
  })
})
27 changes: 22 additions & 5 deletions tests/testthat/test_partykit_surv_ctree.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
# Calls `run_autotest()` on the ctree survival learner and expects the result
# to be TRUE; `result$error` is passed as the failure info.
test_that("autotest", {
# NOTE(review): the same construct/check/autotest sequence appears twice in
# this block -- first after `set.seed(1)` with `LearnerSurvCTree$new()`, then
# inside `with_seed(42, ...)` with `lrn("surv.ctree")`. This looks like a
# before/after pair from a diff view -- confirm which version is intended.
set.seed(1)
learner = LearnerSurvCTree$new()
expect_learner(learner)
result = run_autotest(learner, check_replicable = FALSE)
expect_true(result, info = result$error)
with_seed(42, {
learner = lrn("surv.ctree")
expect_learner(learner)
result = run_autotest(learner, check_replicable = FALSE)
expect_true(result, info = result$error)
})
})

# Verifies the prediction container returned by the ctree survival learner:
# the `distr` entry is a matrix with one row per predicted observation and no
# more columns than unique times among the training rows, and `crank` is a
# numeric vector with one value per predicted observation.
test_that("correct prediction types", {
  with_seed(42, {
    # Keep a random 50-row subset of the rats task, then split it
    # with `ratio = 0.9`.
    rats = tsk("rats")
    rats$filter(sample(1:300, 50))
    split = partition(rats, ratio = 0.9)
    train_rows = split$train
    test_rows = split$test
    n_test = length(test_rows)
    n_times = length(rats$unique_times(train_rows))

    # Fit on the training rows, predict on the held-out rows.
    ctree = lrn("surv.ctree")
    ctree$train(rats, train_rows)
    pred = ctree$predict(rats, test_rows)

    expect_matrix(pred$data$distr, nrows = n_test, max.cols = n_times)
    expect_numeric(pred$crank, len = n_test)
  })
})

0 comments on commit 97c12d6

Please sign in to comment.