refactor: weights (#1065)
* update docs for a fine granularity of weight columns

* add deprecation warning and rename old weight role

* implemented new weights in task

* added weights everywhere, docs, tests

* changed defaults

* improve docs

* rename role weights_train to weights_learner for consistency

* fix printer test

* improve docs

* cleanup

* better soft deprecate

* fix typos

* next try

* hard deprecate weights of learner, all workarounds failed

* adapt autotest

* update news

* cleanup

* fix autotest

* fix error message

* change defaults

* fix test

* fix: properties

* refactor: add use weight parameter in base learner

* chore: deprecated

* tests: col roles

* fix: use weights

* test: use weights

* fix

---------

Co-authored-by: be-marc <marcbecker@posteo.de>
mllg and be-marc authored Aug 23, 2024
1 parent 27531f6 commit 30639e7
Showing 61 changed files with 650 additions and 217 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -44,7 +44,7 @@ Depends:
R (>= 3.1.0)
Imports:
R6 (>= 2.4.1),
backports,
backports (>= 1.5.0),
checkmate (>= 2.0.0),
data.table (>= 1.15.0),
evaluate,
16 changes: 10 additions & 6 deletions NEWS.md
@@ -1,24 +1,28 @@
# mlr3 (development version)

* Deprecated `data_format` and `data_formats` for Learners, Tasks, and DataBackends.
* refactor: Weights can now also be used when scoring predictions via measures and when resampling, to sample observations with unequal probability.
The weights must be stored in the task and assigned the column role `weights_measure` or `weights_resampling`, respectively.
The weights used by the Learner during training are renamed to `weights_learner`; the previous column role `weight` is no longer functional.
Additionally, the use of weights can now be disabled via the new hyperparameter `use_weights`.
Note that this is a breaking change, but it appears to be the less error-prone solution in the long run.
* refactor: Deprecated `data_format` and `data_formats` for Learners, Tasks, and DataBackends.
* feat: The `partition()` function creates training, test and validation sets.
* refactor: Optimize runtime of fixing factor levels.
* refactor: Optimize runtime of setting row roles.
* refactor: Optimize runtime of marshalling.
* refactor: Optimize runtime of `Task$col_info`.
* fix: column info is now checked for compatibility during `Learner$predict` (#943).
* BREAKING CHANGE: the predict time of the learner now stores the cumulative duration for all predict sets (#992).
* fix: Column info is now checked for compatibility during `Learner$predict` (#943).
* BREAKING CHANGE: The predict time of the learner now stores the cumulative duration for all predict sets (#992).
* feat: `$internal_valid_task` can now be set to an `integer` vector.
* feat: Measures can now have an empty `$predict_sets` (#1094).
this is relevant for measures that only extract information from
the model of a learner (such as internal validation scores or AIC / BIC)
This is relevant for measures that only extract information from the model of a learner (such as internal validation scores or AIC / BIC)
* refactor: Deprecated the `$divide()` method
* fix: `Task$cbind()` now works with non-standard primary keys for `data.frames` (#961).
* fix: Triggering of fallback learner now has log-level `"info"` instead of `"debug"` (#972).
* feat: Added new measure `pinballs `.
* feat: Added new measure `mu_auc`.
* feat: Add option to calculate the mean of the true values on the train set in `msr("regr.rsq")`.
* feat: default fallback learner is set when encapsulation is activated.
* feat: Default fallback learner is set when encapsulation is activated.

# mlr3 0.20.2

16 changes: 14 additions & 2 deletions R/Learner.R
@@ -70,6 +70,14 @@
#' Only available for [`Learner`]s with the `"internal_tuning"` property.
#' If the learner is not trained yet, this returns `NULL`.
#'
#' @section Weights:
#'
#' Many learners support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task] where the column role `weights_learner` needs to be assigned to a single numeric column.
#' The weights are automatically used if found in the task; this can be disabled by setting the hyperparameter `use_weights` to `FALSE`.
#' If the learner is set up to use weights but the task does not have a designated weight column, an unweighted version is calculated instead.
#' The weights do not need to sum to 1; they are passed down to the learner.
#'
#' @section Setting Hyperparameters:
#'
#' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet].
@@ -215,7 +223,6 @@ Learner = R6Class("Learner",
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
private$.param_set = assert_param_set(param_set)
self$feature_types = assert_ordered_set(feature_types, mlr_reflections$task_feature_types, .var.name = "feature_types")
self$predict_types = assert_ordered_set(predict_types, names(mlr_reflections$learner_predict_types[[task_type]]),
empty.ok = FALSE, .var.name = "predict_types")
@@ -225,6 +232,11 @@ Learner = R6Class("Learner",
self$packages = union("mlr3", assert_character(packages, any.missing = FALSE, min.chars = 1L))
self$man = assert_string(man, na.ok = TRUE)

if ("weights" %in% self$properties) {
param_set = c(param_set, ps(use_weights = p_lgl(default = TRUE, tags = "train")))
}
private$.param_set = assert_param_set(param_set)

check_packages_installed(packages, msg = sprintf("Package '%%s' required but not installed for Learner '%s'", id))
},

@@ -405,7 +417,7 @@ Learner = R6Class("Learner",
assert_names(newdata$colnames, must.include = task$feature_names)

# the following columns are automatically set to NA if missing
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weight")], use.names = FALSE)
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weights_learner", "weights_measure", "weights_resampling")], use.names = FALSE)
impute = setdiff(impute, newdata$colnames)
if (length(impute)) {
# create list with correct NA types and cbind it to the backend
10 changes: 4 additions & 6 deletions R/LearnerClassifRpart.R
@@ -8,6 +8,8 @@
#'
#' @section Initial parameter values:
#' * Parameter `xval` is initialized to 0 in order to save some computation time.
#' * Parameter `use_weights` can be set to `FALSE` to ignore observation weights with column role `weights_learner`,
#' if present.
#'
#' @section Custom mlr3 parameters:
#' * Parameter `model` has been renamed to `keep_model`.
@@ -35,9 +37,8 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "classif.rpart",
@@ -77,10 +78,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv = get_weights(task$weights_learner$weight, pv)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},

10 changes: 4 additions & 6 deletions R/LearnerRegrRpart.R
@@ -8,6 +8,8 @@
#'
#' @section Initial parameter values:
#' * Parameter `xval` is initialized to 0 in order to save some computation time.
#' * Parameter `use_weights` can be set to `FALSE` to ignore observation weights with column role `weights_learner`,
#' if present.
#'
#' @section Custom mlr3 parameters:
#' * Parameter `model` has been renamed to `keep_model`.
@@ -35,9 +37,8 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
minsplit = p_int(1L, default = 20L, tags = "train"),
surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
xval = p_int(0L, default = 10L, init = 0L, tags = "train")
)
ps$values = list(xval = 0L)

super$initialize(
id = "regr.rpart",
@@ -77,10 +78,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
if ("weights" %in% task$properties) {
pv = insert_named(pv, list(weights = task$weights$weight))
}

pv = get_weights(task$weights_learner$weight, pv)
invoke(rpart::rpart, formula = task$formula(), data = task$data(), .args = pv, .opts = allow_partial_matching)
},
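
Both rpart learners now delegate weight handling to `get_weights()`, whose body is not part of the hunks shown here. The following base R sketch is only a guess at what such a helper might do, based on the documented behavior of `use_weights`. The name `get_weights_sketch()` is hypothetical, and dropping `use_weights` from the argument list before calling `rpart::rpart()` is an assumption.

```r
# Hypothetical stand-in for get_weights(); the real implementation is not shown in this diff.
get_weights_sketch = function(weights, pv) {
  # `use_weights` has default TRUE, so an unset value is read as "use weights".
  flag = pv[["use_weights"]]
  use_weights = if (is.null(flag)) TRUE else isTRUE(flag)

  # Assumption: `use_weights` is an mlr3-level switch and should not be
  # forwarded to the upstream fitting function.
  pv[["use_weights"]] = NULL

  # Only pass weights on if they are enabled and the task provides them.
  if (use_weights && !is.null(weights)) {
    pv[["weights"]] = weights
  }
  pv
}

with_weights = get_weights_sketch(c(1, 2, 3), list(minsplit = 20L))
disabled = get_weights_sketch(c(1, 2, 3), list(minsplit = 20L, use_weights = FALSE))
no_weight_column = get_weights_sketch(NULL, list(use_weights = TRUE))
```

With this shape, a task without a `weights_learner` column and a learner with `use_weights = FALSE` both end up in the same unweighted call, which matches the fallback described in the Weights section of `R/Learner.R`.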

41 changes: 29 additions & 12 deletions R/Measure.R
@@ -24,6 +24,14 @@
#' In such cases it is necessary to overwrite the public methods `$aggregate()` and/or `$score()` to return a named `numeric()`
#' where at least one of its names corresponds to the `id` of the measure itself.
#'
#' @section Weights:
#'
#' Many measures support observation weights, indicated by their property `"weights"`.
#' The weights are stored in the [Task] where the column role `weights_measure` needs to be assigned to a single numeric column.
#' The weights are automatically used if found in the task; this can be disabled by setting the hyperparameter `use_weights` to `FALSE`.
#' If the measure is set up to use weights but the task does not have a designated weight column, an unweighted version is calculated instead.
#' The weights do not need to sum to 1; they are normalized by dividing by the sum of weights.
#'
#' @template param_id
#' @template param_param_set
#' @template param_range
@@ -94,10 +102,6 @@ Measure = R6Class("Measure",
#' Lower and upper bound of possible performance scores.
range = NULL,

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = NULL,

#' @field minimize (`logical(1)`)\cr
#' If `TRUE`, good predictions correspond to small values of performance scores.
minimize = NULL,
@@ -117,7 +121,6 @@ Measure = R6Class("Measure",
predict_sets = "test", task_properties = character(), packages = character(),
label = NA_character_, man = NA_character_, trafo = NULL) {

self$properties = unique(properties)
self$id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = task_type
@@ -140,6 +143,8 @@ Measure = R6Class("Measure",
assert_subset(task_properties, mlr_reflections$task_properties[[task_type]])
}


self$properties = unique(properties)
self$predict_type = predict_type
self$predict_sets = predict_sets
self$task_properties = task_properties
@@ -195,24 +200,25 @@ Measure = R6Class("Measure",
#' @return `numeric(1)`.
score = function(prediction, task = NULL, learner = NULL, train_set = NULL) {
assert_measure(self, task = task, learner = learner)
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% self$properties)
properties = self$properties
assert_prediction(prediction, null.ok = "requires_no_prediction" %nin% properties)

if ("requires_task" %in% self$properties && is.null(task)) {
if ("requires_task" %in% properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}

if ("requires_learner" %in% self$properties && is.null(learner)) {
if ("requires_learner" %in% properties && is.null(learner)) {
stopf("Measure '%s' requires a learner", self$id)
}

if ("requires_model" %in% self$properties && (is.null(learner) || is.null(learner$model))) {
if ("requires_model" %in% properties && (is.null(learner) || is.null(learner$model))) {
stopf("Measure '%s' requires the trained model", self$id)
}
if ("requires_model" %in% self$properties && is_marshaled_model(learner$model)) {
if ("requires_model" %in% properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", self$id)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
if ("requires_train_set" %in% properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
}

@@ -231,7 +237,6 @@ Measure = R6Class("Measure",
#'
#' @return `numeric(1)`.
aggregate = function(rr) {

switch(self$average,
"macro" = {
aggregator = self$aggregator %??% mean
@@ -275,6 +280,17 @@ Measure = R6Class("Measure",
self$predict_sets, mget(private$.extra_hash, envir = self))
},

#' @field properties (`character()`)\cr
#' Properties of this measure.
properties = function(rhs) {
if (!missing(rhs)) {
props = if (is.na(self$task_type)) unique(unlist(mlr_reflections$measure_properties), use.names = FALSE) else mlr_reflections$measure_properties[[self$task_type]]
private$.properties = assert_subset(rhs, props)
} else {
private$.properties
}
},

#' @field average (`character(1)`)\cr
#' Method for aggregation:
#'
@@ -307,6 +323,7 @@ Measure = R6Class("Measure",
),

private = list(
.properties = character(),
.predict_sets = NULL,
.extra_hash = character(),
.average = NULL,
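
The normalization described in the Weights section of `R/Measure.R` can be illustrated in a few lines of base R. This is a minimal sketch, not mlr3 code: `weighted_score()` and its arguments are hypothetical names, and it assumes a measure that is a plain mean of per-observation losses.

```r
# Hypothetical helper, not part of mlr3: aggregate per-observation losses.
weighted_score = function(losses, weights = NULL, use_weights = TRUE) {
  stopifnot(is.numeric(losses))

  # Fall back to the unweighted mean if weights are disabled or
  # the task has no designated weight column.
  if (!isTRUE(use_weights) || is.null(weights)) {
    return(mean(losses))
  }

  stopifnot(
    is.numeric(weights),
    length(weights) == length(losses),
    all(weights >= 0),
    sum(weights) > 0
  )

  # Weights need not sum to 1: divide by their sum before aggregating.
  sum(losses * weights / sum(weights))
}

losses = c(1, 0, 0, 1)                                   # e.g. 0/1 classification errors
unweighted = weighted_score(losses)
weighted = weighted_score(losses, weights = c(3, 1, 1, 1))
disabled = weighted_score(losses, weights = c(3, 1, 1, 1), use_weights = FALSE)
```

Because the weights are divided by their sum, rescaling them by a constant (for example `c(30, 10, 10, 10)`) leaves the score unchanged.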